From 44a52c0cf646beee1522b85c67777df5fdba0386 Mon Sep 17 00:00:00 2001 From: Tong Niu Date: Sun, 1 Mar 2026 16:55:46 +1100 Subject: [PATCH] fix(tools): close resp.Body on retry cancel and cache http.Client instances (#940) * fix(tools): close resp.Body on retry cancel and cache http.Client instances Fix resp.Body leak in DoRequestWithRetry where req.Body (request) was incorrectly closed instead of resp.Body (response) on context cancel. Cache http.Client on web search/fetch provider structs and channel adapters (WeCom, LINE) to avoid per-call allocation overhead. * fix(channels): preserve original http client timeouts for LINE and WeCom Split LINE single 60s client into infoClient (10s) for bot info lookups and apiClient (30s) for messaging API calls. Lower WeCom cached client base timeout from 60s to 30s (matching uploadMedia), and ensure it is always >= the configured ReplyTimeout so the per-request context deadline remains the effective limit. * refactor(tools): extract timeout consts and deduplicate WebFetchTool constructors Address PR review feedback from xiaket: - Define searchTimeout, perplexityTimeout, fetchTimeout, defaultMaxChars, and maxRedirects as package-level consts instead of magic numbers. - Remove misleading "No proxy" comment in NewWebFetchTool. - Deduplicate NewWebFetchTool by delegating to NewWebFetchToolWithProxy. * test(utils): add context cancellation test for DoRequestWithRetry Verify that resp.Body is properly closed when the context is canceled during retry sleep, covering the C8 resp.Body leak fix. * fix(utils): close resp in test to satisfy bodyclose linter * fix(utils): eliminate flakiness in context cancellation retry test Synchronize cancellation using an onRoundTrip callback from the transport wrapper instead of a timing-based context timeout. This ensures the first client.Do completes before cancel fires, so cancellation always hits during sleepWithCtx. --- pkg/agent/loop.go | 14 ++++- pkg/channels/line/line.go | 20 ++++--- pkg/channels/wecom/app.go | 18 ++++-- pkg/channels/wecom/bot.go | 12 +++- pkg/tools/web.go | 109 +++++++++++++++++++---------------- pkg/tools/web_test.go | 45 ++++++++++++--- pkg/utils/http_retry.go | 3 + pkg/utils/http_retry_test.go | 88 ++++++++++++++++++++++++++++ 8 files changed, 230 insertions(+), 79 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8fd7328d1..a72f95bb1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -99,7 +99,7 @@ func registerSharedTools( } // Web tools - if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ BraveAPIKey: cfg.Tools.Web.Brave.APIKey, BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, BraveEnabled: cfg.Tools.Web.Brave.Enabled, @@ -113,10 +113,18 @@ func registerSharedTools( PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, Proxy: cfg.Tools.Web.Proxy, - }); searchTool != nil { + }) + if err != nil { + logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + } else if searchTool != nil { agent.Tools.Register(searchTool) } - agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) + fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy) + if err != nil { + logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + } else { + agent.Tools.Register(fetchTool) + } // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms agent.Tools.Register(tools.NewI2CTool()) diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 9fac2831c..398f12e6b 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -45,11 +45,13 @@ type replyTokenEntry struct { type LINEChannel struct { *channels.BaseChannel config config.LINEConfig - botUserID string // Bot's user ID - botBasicID string // Bot's basic ID (e.g. @216ru...) - botDisplayName string // Bot's display name for text-based mention detection - replyTokens sync.Map // chatID -> replyTokenEntry - quoteTokens sync.Map // chatID -> quoteToken (string) + infoClient *http.Client // for bot info lookups (short timeout) + apiClient *http.Client // for messaging API calls + botUserID string // Bot's user ID + botBasicID string // Bot's basic ID (e.g. @216ru...) + botDisplayName string // Bot's display name for text-based mention detection + replyTokens sync.Map // chatID -> replyTokenEntry + quoteTokens sync.Map // chatID -> quoteToken (string) ctx context.Context cancel context.CancelFunc } @@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return &LINEChannel{ BaseChannel: base, config: cfg, + infoClient: &http.Client{Timeout: 10 * time.Second}, + apiClient: &http.Client{Timeout: 30 * time.Second}, }, nil } @@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error { } req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) + resp, err := c.infoClient.Do(req) if err != nil { return err } @@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := c.apiClient.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 7a23f9617..292a71fd2 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -32,6 +32,7 @@ const ( type WeComAppChannel struct { *channels.BaseChannel config config.WeComAppConfig + client *http.Client accessToken string tokenExpiry time.Time tokenMu sync.RWMutex @@ -129,10 +130,18 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) + // Client timeout must be >= the configured ReplyTimeout so the + // per-request context deadline is always the effective limit. + clientTimeout := 30 * time.Second + if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { + clientTimeout = d + } + ctx, cancel := context.WithCancel(context.Background()) return &WeComAppChannel{ BaseChannel: base, config: cfg, + client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, processedMsgs: make(map[string]bool), @@ -306,8 +315,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp } req.Header.Set("Content-Type", writer.FormDataContentType()) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return "", channels.ClassifyNetError(err) } @@ -364,8 +372,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } @@ -746,8 +753,7 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 39f84d55c..0d0426c0d 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -25,6 +25,7 @@ import ( type WeComBotChannel struct { *channels.BaseChannel config config.WeComConfig + client *http.Client ctx context.Context cancel context.CancelFunc processedMsgs map[string]bool // Message deduplication: msg_id -> processed @@ -93,10 +94,18 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) + // Client timeout must be >= the configured ReplyTimeout so the + // per-request context deadline is always the effective limit. + clientTimeout := 30 * time.Second + if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { + clientTimeout = d + } + ctx, cancel := context.WithCancel(context.Background()) return &WeComBotChannel{ BaseChannel: base, config: cfg, + client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, processedMsgs: make(map[string]bool), @@ -450,8 +459,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 8ba2a723a..834e7bfc7 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -15,6 +15,14 @@ import ( const ( userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + + // HTTP client timeouts for web tool providers. + searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo + perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower) + fetchTimeout = 60 * time.Second // WebFetchTool + + defaultMaxChars = 50000 + maxRedirects = 5 ) // Pre-compiled regexes for HTML text extraction @@ -74,6 +82,7 @@ type SearchProvider interface { type BraveSearchProvider struct { apiKey string proxy string + client *http.Client } func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -88,11 +97,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in req.Header.Set("Accept", "application/json") req.Header.Set("X-Subscription-Token", p.apiKey) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -143,6 +148,7 @@ type TavilySearchProvider struct { apiKey string baseURL string proxy string + client *http.Client } func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -174,11 +180,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -226,7 +228,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i } type DuckDuckGoSearchProvider struct { - proxy string + proxy string + client *http.Client } func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -239,11 +242,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -322,6 +321,7 @@ func stripTags(content string) string { type PerplexitySearchProvider struct { apiKey string proxy string + client *http.Client } func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -356,11 +356,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("Authorization", "Bearer "+p.apiKey) req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 30*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -415,43 +411,60 @@ type WebSearchToolOptions struct { Proxy string } -func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { +func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 // Priority: Perplexity > Brave > Tavily > DuckDuckGo if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { - provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, perplexityTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err) + } + provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client} if opts.PerplexityMaxResults > 0 { maxResults = opts.PerplexityMaxResults } } else if opts.BraveEnabled && opts.BraveAPIKey != "" { - provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err) + } + provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client} if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" { + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err) + } provider = &TavilySearchProvider{ apiKey: opts.TavilyAPIKey, baseURL: opts.TavilyBaseURL, proxy: opts.Proxy, + client: client, } if opts.TavilyMaxResults > 0 { maxResults = opts.TavilyMaxResults } } else if opts.DuckDuckGoEnabled { - provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err) + } + provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client} if opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } } else { - return nil + return nil, nil } return &WebSearchTool{ provider: provider, maxResults: maxResults, - } + }, nil } func (t *WebSearchTool) Name() string { @@ -508,25 +521,34 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR type WebFetchTool struct { maxChars int proxy string + client *http.Client } func NewWebFetchTool(maxChars int) *WebFetchTool { - if maxChars <= 0 { - maxChars = 50000 - } - return &WebFetchTool{ - maxChars: maxChars, - } + // createHTTPClient cannot fail with an empty proxy string. + tool, _ := NewWebFetchToolWithProxy(maxChars, "") + return tool } -func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { +func NewWebFetchToolWithProxy(maxChars int, proxy string) (*WebFetchTool, error) { if maxChars <= 0 { - maxChars = 50000 + maxChars = defaultMaxChars + } + client, err := createHTTPClient(proxy, fetchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) + } + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + return nil } return &WebFetchTool{ maxChars: maxChars, proxy: proxy, - } + client: client, + }, nil } func (t *WebFetchTool) Name() string { @@ -588,20 +610,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(t.proxy, 60*time.Second) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) - } - - // Configure redirect handling - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if len(via) >= 5 { - return fmt.Errorf("stopped after 5 redirects") - } - return nil - } - - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ErrorResult(fmt.Sprintf("request failed: %v", err)) } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 2cd79eb24..db3c08ba6 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -176,13 +176,19 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { // TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } if tool != nil { t.Errorf("Expected nil tool when Brave API key is empty") } // Also nil when nothing is enabled - tool = NewWebSearchTool(WebSearchToolOptions{}) + tool, err = NewWebSearchTool(WebSearchToolOptions{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } if tool != nil { t.Errorf("Expected nil tool when no provider is enabled") } @@ -190,7 +196,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) { // TestWebTool_WebSearch_MissingQuery verifies error handling for missing query func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } ctx := context.Background() args := map[string]any{} @@ -438,7 +447,10 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { } func TestNewWebFetchToolWithProxy(t *testing.T) { - tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + if err != nil { + t.Fatalf("NewWebFetchToolWithProxy() error: %v", err) + } if tool.maxChars != 1024 { t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024) } @@ -446,7 +458,10 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890") } - tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + if err != nil { + t.Fatalf("NewWebFetchToolWithProxy() error: %v", err) + } if tool.maxChars != 50000 { t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) } @@ -454,12 +469,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("perplexity", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ PerplexityEnabled: true, PerplexityAPIKey: "k", PerplexityMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*PerplexitySearchProvider) if !ok { t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) @@ -470,12 +488,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { }) t.Run("brave", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ BraveEnabled: true, BraveAPIKey: "k", BraveMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*BraveSearchProvider) if !ok { t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) @@ -486,11 +507,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { }) t.Run("duckduckgo", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ DuckDuckGoEnabled: true, DuckDuckGoMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*DuckDuckGoSearchProvider) if !ok { t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) @@ -542,12 +566,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { })) defer server.Close() - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ TavilyEnabled: true, TavilyAPIKey: "test-key", TavilyBaseURL: server.URL, TavilyMaxResults: 5, }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } ctx := context.Background() args := map[string]any{ diff --git a/pkg/utils/http_retry.go b/pkg/utils/http_retry.go index e90fa2129..135ea0ef5 100644 --- a/pkg/utils/http_retry.go +++ b/pkg/utils/http_retry.go @@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response, if i < maxRetries-1 { if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil { + if resp != nil { + resp.Body.Close() + } return nil, fmt.Errorf("failed to sleep: %w", err) } } diff --git a/pkg/utils/http_retry_test.go b/pkg/utils/http_retry_test.go index 1c2dbe115..d64cd5eda 100644 --- a/pkg/utils/http_retry_test.go +++ b/pkg/utils/http_retry_test.go @@ -1,8 +1,11 @@ package utils import ( + "context" + "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) { } } +func TestDoRequestWithRetry_ContextCancel(t *testing.T) { + // Use a long retry delay so cancellation always hits during sleepWithCtx. + retryDelayUnit = 10 * time.Second + t.Cleanup(func() { retryDelayUnit = time.Second }) + + bodyClosed := false + firstRoundTripDone := make(chan struct{}, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("error")) + })) + defer server.Close() + + client := server.Client() + client.Timeout = 30 * time.Second + client.Transport = &bodyCloseTracker{ + rt: client.Transport, + onClose: func() { bodyClosed = true }, + // Signal after the first round-trip response is fully constructed on the client side. + onRoundTrip: func() { + select { + case firstRoundTripDone <- struct{}{}: + default: + } + }, + trackURL: server.URL, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Cancel the context after the first round-trip completes on the client side. + // This ensures client.Do has returned a valid resp (with body) and the retry + // loop is about to enter sleepWithCtx, where the cancel will be detected. + go func() { + <-firstRoundTripDone + cancel() + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + if resp != nil { + resp.Body.Close() + } + require.Error(t, err, "expected error from context cancellation") + assert.Nil(t, resp, "expected nil response when context is canceled") + assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation") +} + +// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed. +type bodyCloseTracker struct { + rt http.RoundTripper + onClose func() + onRoundTrip func() // called after each successful round-trip + trackURL string +} + +func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.rt.RoundTrip(req) + if err != nil { + return resp, err + } + if strings.HasPrefix(req.URL.String(), t.trackURL) { + resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose} + if t.onRoundTrip != nil { + t.onRoundTrip() + } + } + return resp, nil +} + +// closeNotifier wraps an io.ReadCloser to detect Close calls. +type closeNotifier struct { + io.ReadCloser + onClose func() +} + +func (c *closeNotifier) Close() error { + c.onClose() + return c.ReadCloser.Close() +} + func TestDoRequestWithRetry_Delay(t *testing.T) { retryDelayUnit = time.Millisecond t.Cleanup(func() { retryDelayUnit = time.Second })