fix(tools): close resp.Body on retry cancel and cache http.Client instances (#940)

* fix(tools): close resp.Body on retry cancel and cache http.Client instances

Fix resp.Body leak in DoRequestWithRetry where req.Body (request) was
incorrectly closed instead of resp.Body (response) on context cancel.
Cache http.Client on web search/fetch provider structs and channel
adapters (WeCom, LINE) to avoid per-call allocation overhead.

* fix(channels): preserve original http client timeouts for LINE and WeCom

Split LINE single 60s client into infoClient (10s) for bot info lookups
and apiClient (30s) for messaging API calls. Lower WeCom cached client
base timeout from 60s to 30s (matching uploadMedia), and ensure it is
always >= the configured ReplyTimeout so the per-request context
deadline remains the effective limit.

* refactor(tools): extract timeout consts and deduplicate WebFetchTool constructors

Address PR review feedback from xiaket:
- Define searchTimeout, perplexityTimeout, fetchTimeout, defaultMaxChars,
  and maxRedirects as package-level consts instead of magic numbers.
- Remove misleading "No proxy" comment in NewWebFetchTool.
- Deduplicate NewWebFetchTool by delegating to NewWebFetchToolWithProxy.

* test(utils): add context cancellation test for DoRequestWithRetry

Verify that resp.Body is properly closed when the context is canceled
during retry sleep, covering the C8 resp.Body leak fix.

* fix(utils): close resp in test to satisfy bodyclose linter

* fix(utils): eliminate flakiness in context cancellation retry test

Synchronize cancellation using an onRoundTrip callback from the
transport wrapper instead of a timing-based context timeout. This
ensures the first client.Do completes before cancel fires, so
cancellation always hits during sleepWithCtx.
This commit is contained in:
Tong Niu
2026-03-01 16:55:46 +11:00
committed by GitHub
parent b3c3b02666
commit 44a52c0cf6
8 changed files with 230 additions and 79 deletions
+11 -3
View File
@@ -99,7 +99,7 @@ func registerSharedTools(
}
// Web tools
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
@@ -113,10 +113,18 @@ func registerSharedTools(
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
Proxy: cfg.Tools.Web.Proxy,
}); searchTool != nil {
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
+11 -9
View File
@@ -45,11 +45,13 @@ type replyTokenEntry struct {
type LINEChannel struct {
*channels.BaseChannel
config config.LINEConfig
botUserID string // Bot's user ID
botBasicID string // Bot's basic ID (e.g. @216ru...)
botDisplayName string // Bot's display name for text-based mention detection
replyTokens sync.Map // chatID -> replyTokenEntry
quoteTokens sync.Map // chatID -> quoteToken (string)
infoClient *http.Client // for bot info lookups (short timeout)
apiClient *http.Client // for messaging API calls
botUserID string // Bot's user ID
botBasicID string // Bot's basic ID (e.g. @216ru...)
botDisplayName string // Bot's display name for text-based mention detection
replyTokens sync.Map // chatID -> replyTokenEntry
quoteTokens sync.Map // chatID -> quoteToken (string)
ctx context.Context
cancel context.CancelFunc
}
@@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
return &LINEChannel{
BaseChannel: base,
config: cfg,
infoClient: &http.Client{Timeout: 10 * time.Second},
apiClient: &http.Client{Timeout: 30 * time.Second},
}, nil
}
@@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
}
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
resp, err := c.infoClient.Do(req)
if err != nil {
return err
}
@@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
resp, err := c.apiClient.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+12 -6
View File
@@ -32,6 +32,7 @@ const (
type WeComAppChannel struct {
*channels.BaseChannel
config config.WeComAppConfig
client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
@@ -129,10 +130,18 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
@@ -306,8 +315,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
}
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return "", channels.ClassifyNetError(err)
}
@@ -364,8 +372,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
@@ -746,8 +753,7 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+10 -2
View File
@@ -25,6 +25,7 @@ import (
type WeComBotChannel struct {
*channels.BaseChannel
config config.WeComConfig
client *http.Client
ctx context.Context
cancel context.CancelFunc
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
@@ -93,10 +94,18 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
// Client timeout must be >= the configured ReplyTimeout so the
// per-request context deadline is always the effective limit.
clientTimeout := 30 * time.Second
if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
clientTimeout = d
}
ctx, cancel := context.WithCancel(context.Background())
return &WeComBotChannel{
BaseChannel: base,
config: cfg,
client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
processedMsgs: make(map[string]bool),
@@ -450,8 +459,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
+59 -50
View File
@@ -15,6 +15,14 @@ import (
const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
// HTTP client timeouts for web tool providers.
searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
fetchTimeout = 60 * time.Second // WebFetchTool
defaultMaxChars = 50000
maxRedirects = 5
)
// Pre-compiled regexes for HTML text extraction
@@ -74,6 +82,7 @@ type SearchProvider interface {
type BraveSearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -88,11 +97,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -143,6 +148,7 @@ type TavilySearchProvider struct {
apiKey string
baseURL string
proxy string
client *http.Client
}
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -174,11 +180,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -226,7 +228,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
}
type DuckDuckGoSearchProvider struct {
proxy string
proxy string
client *http.Client
}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -239,11 +242,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 10*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -322,6 +321,7 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
proxy string
client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -356,11 +356,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(p.proxy, 30*time.Second)
if err != nil {
return "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := client.Do(req)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -415,43 +411,60 @@ type WebSearchToolOptions struct {
Proxy string
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
client: client,
}
if opts.TavilyMaxResults > 0 {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
}
provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
return nil
return nil, nil
}
return &WebSearchTool{
provider: provider,
maxResults: maxResults,
}
}, nil
}
func (t *WebSearchTool) Name() string {
@@ -508,25 +521,34 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
type WebFetchTool struct {
maxChars int
proxy string
client *http.Client
}
func NewWebFetchTool(maxChars int) *WebFetchTool {
if maxChars <= 0 {
maxChars = 50000
}
return &WebFetchTool{
maxChars: maxChars,
}
// createHTTPClient cannot fail with an empty proxy string.
tool, _ := NewWebFetchToolWithProxy(maxChars, "")
return tool
}
func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
func NewWebFetchToolWithProxy(maxChars int, proxy string) (*WebFetchTool, error) {
if maxChars <= 0 {
maxChars = 50000
maxChars = defaultMaxChars
}
client, err := createHTTPClient(proxy, fetchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= maxRedirects {
return fmt.Errorf("stopped after %d redirects", maxRedirects)
}
return nil
}
return &WebFetchTool{
maxChars: maxChars,
proxy: proxy,
}
client: client,
}, nil
}
func (t *WebFetchTool) Name() string {
@@ -588,20 +610,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
req.Header.Set("User-Agent", userAgent)
client, err := createHTTPClient(t.proxy, 60*time.Second)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
}
// Configure redirect handling
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("stopped after 5 redirects")
}
return nil
}
resp, err := client.Do(req)
resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
+36 -9
View File
@@ -176,13 +176,19 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
tool, err = NewWebSearchTool(WebSearchToolOptions{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
@@ -190,7 +196,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ctx := context.Background()
args := map[string]any{}
@@ -438,7 +447,10 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
}
func TestNewWebFetchToolWithProxy(t *testing.T) {
tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
if err != nil {
t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
}
if tool.maxChars != 1024 {
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
}
@@ -446,7 +458,10 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
}
tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
if err != nil {
t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
}
if tool.maxChars != 50000 {
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
}
@@ -454,12 +469,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*PerplexitySearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
@@ -470,12 +488,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("brave", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*BraveSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
@@ -486,11 +507,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
})
t.Run("duckduckgo", func(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
DuckDuckGoEnabled: true,
DuckDuckGoMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
if !ok {
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
@@ -542,12 +566,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}))
defer server.Close()
tool := NewWebSearchTool(WebSearchToolOptions{
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
ctx := context.Background()
args := map[string]any{
+3
View File
@@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response,
if i < maxRetries-1 {
if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil {
if resp != nil {
resp.Body.Close()
}
return nil, fmt.Errorf("failed to sleep: %w", err)
}
}
+88
View File
@@ -1,8 +1,11 @@
package utils
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) {
}
}
func TestDoRequestWithRetry_ContextCancel(t *testing.T) {
// Use a long retry delay so cancellation always hits during sleepWithCtx.
retryDelayUnit = 10 * time.Second
t.Cleanup(func() { retryDelayUnit = time.Second })
bodyClosed := false
firstRoundTripDone := make(chan struct{}, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("error"))
}))
defer server.Close()
client := server.Client()
client.Timeout = 30 * time.Second
client.Transport = &bodyCloseTracker{
rt: client.Transport,
onClose: func() { bodyClosed = true },
// Signal after the first round-trip response is fully constructed on the client side.
onRoundTrip: func() {
select {
case firstRoundTripDone <- struct{}{}:
default:
}
},
trackURL: server.URL,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Cancel the context after the first round-trip completes on the client side.
// This ensures client.Do has returned a valid resp (with body) and the retry
// loop is about to enter sleepWithCtx, where the cancel will be detected.
go func() {
<-firstRoundTripDone
cancel()
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := DoRequestWithRetry(client, req)
if resp != nil {
resp.Body.Close()
}
require.Error(t, err, "expected error from context cancellation")
assert.Nil(t, resp, "expected nil response when context is canceled")
assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation")
}
// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed.
type bodyCloseTracker struct {
rt http.RoundTripper
onClose func()
onRoundTrip func() // called after each successful round-trip
trackURL string
}
func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := t.rt.RoundTrip(req)
if err != nil {
return resp, err
}
if strings.HasPrefix(req.URL.String(), t.trackURL) {
resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose}
if t.onRoundTrip != nil {
t.onRoundTrip()
}
}
return resp, nil
}
// closeNotifier wraps an io.ReadCloser to detect Close calls.
type closeNotifier struct {
io.ReadCloser
onClose func()
}
func (c *closeNotifier) Close() error {
c.onClose()
return c.ReadCloser.Close()
}
func TestDoRequestWithRetry_Delay(t *testing.T) {
retryDelayUnit = time.Millisecond
t.Cleanup(func() { retryDelayUnit = time.Second })