From 95716b106b81382619a813dc5c8f4b66e44ae238 Mon Sep 17 00:00:00 2001 From: yanhool <752612542@qq.com> Date: Tue, 10 Mar 2026 16:34:11 +0800 Subject: [PATCH] feat(web_search): add load balance and failover for api keys (#982) * feat(web_search): add load balance and failover for api keys * feat(web_search): add load balance and failover for api keys * lint * new iter to get api key * deleted conflicts --- config/config.example.json | 8 +- pkg/agent/loop.go | 10 +- pkg/config/config.go | 47 +- pkg/config/config_test.go | 2 +- pkg/config/defaults.go | 8 + .../sources/openclaw/openclaw_config.go | 24 +- pkg/tools/web.go | 483 +++++++++++------- pkg/tools/web_test.go | 129 ++++- 8 files changed, 492 insertions(+), 219 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 3a33b3caf..49658b9f2 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -284,6 +284,9 @@ "brave": { "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", + "api_keys": [ + "YOUR_BRAVE_API_KEY" + ], "max_results": 5 }, "tavily": { @@ -298,7 +301,10 @@ }, "perplexity": { "enabled": false, - "api_key": "", + "api_key": "pplx-xxx", + "api_keys": [ + "pplx-xxx" + ], "max_results": 5 }, "searxng": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d5f661293..235d42fcc 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -120,19 +120,21 @@ func registerSharedTools( continue } - // Web tools if cfg.Tools.IsToolEnabled("web") { searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey, cfg.Tools.Web.Brave.APIKeys), BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, BraveEnabled: cfg.Tools.Web.Brave.Enabled, - TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, + TavilyAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Tavily.APIKey, cfg.Tools.Web.Tavily.APIKeys), TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityAPIKeys: config.MergeAPIKeys( + cfg.Tools.Web.Perplexity.APIKey, + cfg.Tools.Web.Perplexity.APIKeys, + ), PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL, diff --git a/pkg/config/config.go b/pkg/config/config.go index 76d312dae..e3520faaf 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "os" + "strings" "sync/atomic" "github.com/caarlos0/env/v11" @@ -593,16 +594,18 @@ type ToolConfig struct { } type BraveConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` + APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"` } type TavilyConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"` - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"` - BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"` + APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEYS"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"` } type DuckDuckGoConfig struct { @@ -611,9 +614,10 @@ type DuckDuckGoConfig struct { } type PerplexityConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"` - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"` + APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEYS"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } type SearXNGConfig struct { @@ -933,6 +937,29 @@ func (c *Config) ValidateModelList() error { return nil } +func MergeAPIKeys(apiKey string, apiKeys []string) []string { + seen := make(map[string]struct{}) + var all []string + + if k := strings.TrimSpace(apiKey); k != "" { + if _, exists := seen[k]; !exists { + seen[k] = struct{}{} + all = append(all, k) + } + } + + for _, k := range apiKeys { + if trimmed := strings.TrimSpace(k); trimmed != "" { + if _, exists := seen[trimmed]; !exists { + seen[trimmed] = struct{}{} + all = append(all, trimmed) + } + } + } + + return all +} + func (t *ToolsConfig) IsToolEnabled(name string) bool { switch name { case "web": diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 47f79c6f0..8baf3e6fd 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -296,7 +296,7 @@ func TestDefaultConfig_WebTools(t *testing.T) { if cfg.Tools.Web.Brave.MaxResults != 5 { t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults) } - if cfg.Tools.Web.Brave.APIKey != "" { + if len(cfg.Tools.Web.Brave.APIKeys) != 0 { t.Error("Brave API key should be empty by default") } if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 0892d45f4..88cb254ad 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -384,6 +384,13 @@ func DefaultConfig() *Config { Brave: BraveConfig{ Enabled: false, APIKey: "", + APIKeys: nil, + MaxResults: 5, + }, + Tavily: TavilyConfig{ + Enabled: false, + APIKey: "", + APIKeys: nil, MaxResults: 5, }, DuckDuckGo: DuckDuckGoConfig{ @@ -393,6 +400,7 @@ func DefaultConfig() *Config { Perplexity: PerplexityConfig{ Enabled: false, APIKey: "", + APIKeys: nil, MaxResults: 5, }, SearXNG: SearXNGConfig{ diff --git a/pkg/migrate/sources/openclaw/openclaw_config.go b/pkg/migrate/sources/openclaw/openclaw_config.go index 19d63bb77..e272d17a9 100644 --- a/pkg/migrate/sources/openclaw/openclaw_config.go +++ b/pkg/migrate/sources/openclaw/openclaw_config.go @@ -733,16 +733,18 @@ type WebToolsConfig struct { } type BraveConfig struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - MaxResults int `json:"max_results"` + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + APIKeys []string `json:"api_keys"` + MaxResults int `json:"max_results"` } type TavilyConfig struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - BaseURL string `json:"base_url"` - MaxResults int `json:"max_results"` + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + APIKeys []string `json:"api_keys"` + BaseURL string `json:"base_url"` + MaxResults int `json:"max_results"` } type DuckDuckGoConfig struct { @@ -751,9 +753,10 @@ type DuckDuckGoConfig struct { } type PerplexityConfig struct { - Enabled bool `json:"enabled"` - APIKey string `json:"api_key"` - MaxResults int `json:"max_results"` + Enabled bool `json:"enabled"` + APIKey string `json:"api_key"` + APIKeys []string `json:"api_keys"` + MaxResults int `json:"max_results"` } type CronConfig struct { @@ -1082,6 +1085,7 @@ func (c ToolsConfig) ToStandardTools() config.ToolsConfig { Brave: config.BraveConfig{ Enabled: c.Web.Brave.Enabled, APIKey: c.Web.Brave.APIKey, + APIKeys: c.Web.Brave.APIKeys, MaxResults: c.Web.Brave.MaxResults, }, Tavily: config.TavilyConfig{ diff --git a/pkg/tools/web.go b/pkg/tools/web.go index eeceabd98..e248ea966 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -11,6 +11,7 @@ import ( "net/url" "regexp" "strings" + "sync/atomic" "time" ) @@ -76,81 +77,140 @@ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, err return client, nil } +type APIKeyPool struct { + keys []string + current uint32 +} + +func NewAPIKeyPool(keys []string) *APIKeyPool { + return &APIKeyPool{ + keys: keys, + } +} + +type APIKeyIterator struct { + pool *APIKeyPool + startIdx uint32 + attempt uint32 +} + +func (p *APIKeyPool) NewIterator() *APIKeyIterator { + if len(p.keys) == 0 { + return &APIKeyIterator{pool: p} + } + idx := atomic.AddUint32(&p.current, 1) - 1 + return &APIKeyIterator{ + pool: p, + startIdx: idx, + } +} + +func (it *APIKeyIterator) Next() (string, bool) { + length := uint32(len(it.pool.keys)) + if length == 0 || it.attempt >= length { + return "", false + } + key := it.pool.keys[(it.startIdx+it.attempt)%length] + it.attempt++ + return key, true +} + type SearchProvider interface { Search(ctx context.Context, query string, count int) (string, error) } type BraveSearchProvider struct { - apiKey string - proxy string - client *http.Client + keyPool *APIKeyPool + proxy string + client *http.Client } func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", url.QueryEscape(query), count) - req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } + var lastErr error + iter := p.keyPool.NewIterator() - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Subscription-Token", p.apiKey) - - resp, err := p.client.Do(req) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body)) - } - - var searchResp struct { - Web struct { - Results []struct { - Title string `json:"title"` - URL string `json:"url"` - Description string `json:"description"` - } `json:"results"` - } `json:"web"` - } - - if err := json.Unmarshal(body, &searchResp); err != nil { - // Log error body for debugging - fmt.Printf("Brave API Error Body: %s\n", string(body)) - return "", fmt.Errorf("failed to parse response: %w", err) - } - - results := searchResp.Web.Results - if len(results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - var lines []string - lines = append(lines, fmt.Sprintf("Results for: %s", query)) - for i, item := range results { - if i >= count { + for { + apiKey, ok := iter.Next() + if !ok { break } - lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) - if item.Description != "" { - lines = append(lines, fmt.Sprintf(" %s", item.Description)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) } + + req.Header.Set("Accept", "application/json") + req.Header.Set("X-Subscription-Token", apiKey) + + resp, err := p.client.Do(req) + if err != nil { + lastErr = fmt.Errorf("request failed: %w", err) + continue + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + lastErr = fmt.Errorf("failed to read response: %w", err) + continue + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + if resp.StatusCode == http.StatusTooManyRequests || + resp.StatusCode == http.StatusUnauthorized || + resp.StatusCode == http.StatusForbidden || + resp.StatusCode >= 500 { + continue + } + return "", lastErr + } + + var searchResp struct { + Web struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + } `json:"results"` + } `json:"web"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + // Log error body for debugging + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.Web.Results + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) + if item.Description != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Description)) + } + } + + return strings.Join(lines, "\n"), nil } - return strings.Join(lines, "\n"), nil + return "", fmt.Errorf("all api keys failed, last error: %w", lastErr) } type TavilySearchProvider struct { - apiKey string + keyPool *APIKeyPool baseURL string proxy string client *http.Client @@ -162,74 +222,96 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i searchURL = "https://api.tavily.com/search" } - payload := map[string]any{ - "api_key": p.apiKey, - "query": query, - "search_depth": "advanced", - "include_answer": false, - "include_images": false, - "include_raw_content": false, - "max_results": count, - } + var lastErr error + iter := p.keyPool.NewIterator() - bodyBytes, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal payload: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes)) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", userAgent) - - resp, err := p.client.Do(req) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body)) - } - - var searchResp struct { - Results []struct { - Title string `json:"title"` - URL string `json:"url"` - Content string `json:"content"` - } `json:"results"` - } - - if err := json.Unmarshal(body, &searchResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) - } - - results := searchResp.Results - if len(results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - var lines []string - lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query)) - for i, item := range results { - if i >= count { + for { + apiKey, ok := iter.Next() + if !ok { break } - lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) - if item.Content != "" { - lines = append(lines, fmt.Sprintf(" %s", item.Content)) + + payload := map[string]any{ + "api_key": apiKey, + "query": query, + "search_depth": "advanced", + "include_answer": false, + "include_images": false, + "include_raw_content": false, + "max_results": count, } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", userAgent) + + resp, err := p.client.Do(req) + if err != nil { + lastErr = fmt.Errorf("request failed: %w", err) + continue + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + lastErr = fmt.Errorf("failed to read response: %w", err) + continue + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body)) + if resp.StatusCode == http.StatusTooManyRequests || + resp.StatusCode == http.StatusUnauthorized || + resp.StatusCode == http.StatusForbidden || + resp.StatusCode >= 500 { + continue + } + return "", lastErr + } + + var searchResp struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + } `json:"results"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.Results + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) + if item.Content != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Content)) + } + } + + return strings.Join(lines, "\n"), nil } - return strings.Join(lines, "\n"), nil + return "", fmt.Errorf("all api keys failed, last error: %w", lastErr) } type DuckDuckGoSearchProvider struct { @@ -324,75 +406,97 @@ func stripTags(content string) string { } type PerplexitySearchProvider struct { - apiKey string - proxy string - client *http.Client + keyPool *APIKeyPool + proxy string + client *http.Client } func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { searchURL := "https://api.perplexity.ai/chat/completions" - payload := map[string]any{ - "model": "sonar", - "messages": []map[string]string{ - { - "role": "system", - "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.", + var lastErr error + iter := p.keyPool.NewIterator() + + for { + apiKey, ok := iter.Next() + if !ok { + break + } + + payload := map[string]any{ + "model": "sonar", + "messages": []map[string]string{ + { + "role": "system", + "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.", + }, + { + "role": "user", + "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count), + }, }, - { - "role": "user", - "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count), - }, - }, - "max_tokens": 1000, + "max_tokens": 1000, + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes))) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("User-Agent", userAgent) + + resp, err := p.client.Do(req) + if err != nil { + lastErr = fmt.Errorf("request failed: %w", err) + continue + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + + if err != nil { + lastErr = fmt.Errorf("failed to read response: %w", err) + continue + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("Perplexity API error: %s", string(body)) + if resp.StatusCode == http.StatusTooManyRequests || + resp.StatusCode == http.StatusUnauthorized || + resp.StatusCode == http.StatusForbidden || + resp.StatusCode >= 500 { + continue + } + return "", lastErr + } + + var searchResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if len(searchResp.Choices) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes))) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+p.apiKey) - req.Header.Set("User-Agent", userAgent) - - resp, err := p.client.Do(req) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("Perplexity API error: %s", string(body)) - } - - var searchResp struct { - Choices []struct { - Message struct { - Content string `json:"content"` - } `json:"message"` - } `json:"choices"` - } - - if err := json.Unmarshal(body, &searchResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) - } - - if len(searchResp.Choices) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil + return "", fmt.Errorf("all api keys failed, last error: %w", lastErr) } type SearXNGSearchProvider struct { @@ -545,16 +649,16 @@ type WebSearchTool struct { } type WebSearchToolOptions struct { - BraveAPIKey string + BraveAPIKeys []string BraveMaxResults int BraveEnabled bool - TavilyAPIKey string + TavilyAPIKeys []string TavilyBaseURL string TavilyMaxResults int TavilyEnabled bool DuckDuckGoMaxResults int DuckDuckGoEnabled bool - PerplexityAPIKey string + PerplexityAPIKeys []string PerplexityMaxResults int PerplexityEnabled bool SearXNGBaseURL string @@ -571,23 +675,26 @@ type WebSearchToolOptions struct { func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 - // Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search - if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { + if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 { client, err := createHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err) } - provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client} + provider = &PerplexitySearchProvider{ + keyPool: NewAPIKeyPool(opts.PerplexityAPIKeys), + proxy: opts.Proxy, + client: client, + } if opts.PerplexityMaxResults > 0 { maxResults = opts.PerplexityMaxResults } - } else if opts.BraveEnabled && opts.BraveAPIKey != "" { + } else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 { client, err := createHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err) } - provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client} + provider = &BraveSearchProvider{keyPool: NewAPIKeyPool(opts.BraveAPIKeys), proxy: opts.Proxy, client: client} if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } @@ -596,13 +703,13 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.SearXNGMaxResults > 0 { maxResults = opts.SearXNGMaxResults } - } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" { + } else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 { client, err := createHTTPClient(opts.Proxy, searchTimeout) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err) } provider = &TavilySearchProvider{ - apiKey: opts.TavilyAPIKey, + keyPool: NewAPIKeyPool(opts.TavilyAPIKeys), baseURL: opts.TavilyBaseURL, proxy: opts.Proxy, client: client, diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index bdd30d385..188fb8adb 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -249,7 +249,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) { // TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -269,7 +269,11 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) { // TestWebTool_WebSearch_MissingQuery verifies error handling for missing query func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + tool, err := NewWebSearchTool(WebSearchToolOptions{ + BraveEnabled: true, + BraveAPIKeys: []string{"test-key"}, + BraveMaxResults: 5, + }) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -553,7 +557,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("perplexity", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ PerplexityEnabled: true, - PerplexityAPIKey: "k", + PerplexityAPIKeys: []string{"k"}, PerplexityMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) @@ -572,7 +576,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("brave", func(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ BraveEnabled: true, - BraveAPIKey: "k", + BraveAPIKeys: []string{"k"}, BraveMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) @@ -650,7 +654,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { tool, err := NewWebSearchTool(WebSearchToolOptions{ TavilyEnabled: true, - TavilyAPIKey: "test-key", + TavilyAPIKeys: []string{"test-key"}, TavilyBaseURL: server.URL, TavilyMaxResults: 5, }) @@ -682,6 +686,121 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { } } +func TestAPIKeyPool(t *testing.T) { + pool := NewAPIKeyPool([]string{"key1", "key2", "key3"}) + if len(pool.keys) != 3 { + t.Fatalf("expected 3 keys, got %d", len(pool.keys)) + } + if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" { + t.Fatalf("unexpected keys: %v", pool.keys) + } + + // Test Iterator: each iterator should cover all keys exactly once + iter := pool.NewIterator() + expected := []string{"key1", "key2", "key3"} + for i, want := range expected { + k, ok := iter.Next() + if !ok { + t.Fatalf("iter.Next() returned false at step %d", i) + } + if k != want { + t.Errorf("step %d: expected %s, got %s", i, want, k) + } + } + // Should be exhausted + if _, ok := iter.Next(); ok { + t.Errorf("expected iterator exhausted after all keys") + } + + // Second iterator starts at next position (load balancing) + iter2 := pool.NewIterator() + k, ok := iter2.Next() + if !ok { + t.Fatal("iter2.Next() returned false") + } + if k != "key2" { + t.Errorf("expected key2 (round-robin), got %s", k) + } + + // Empty pool + emptyPool := NewAPIKeyPool([]string{}) + emptyIter := emptyPool.NewIterator() + if _, ok := emptyIter.Next(); ok { + t.Errorf("expected false for empty pool") + } + + // Single key pool + singlePool := NewAPIKeyPool([]string{"single"}) + singleIter := singlePool.NewIterator() + if k, ok := singleIter.Next(); !ok || k != "single" { + t.Errorf("expected single, got %s (ok=%v)", k, ok) + } + if _, ok := singleIter.Next(); ok { + t.Errorf("expected exhausted after single key") + } +} + +func TestWebTool_TavilySearch_Failover(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var payload map[string]any + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("failed to decode payload: %v", err) + } + + apiKey := payload["api_key"].(string) + + if apiKey == "key1" { + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("Rate limited")) + return + } + + if apiKey == "key2" { + // Success + response := map[string]any{ + "results": []map[string]any{ + { + "title": "Success Result", + "url": "https://example.com/success", + "content": "Success content", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + return + } + + w.WriteHeader(http.StatusBadRequest) + })) + defer server.Close() + + tool, err := NewWebSearchTool(WebSearchToolOptions{ + TavilyEnabled: true, + TavilyAPIKeys: []string{"key1", "key2"}, + TavilyBaseURL: server.URL, + TavilyMaxResults: 5, + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + ctx := context.Background() + args := map[string]any{ + "query": "test query", + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success, got Error: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "Success Result") { + t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser) + } +} + func TestWebTool_GLMSearch_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" {