feat(web_search): add load balance and failover for api keys (#982)

* feat(web_search): add load balance and failover for api keys

* feat(web_search): add load balance and failover for api keys

* lint

* new iter to get api key

* deleted conflicts
This commit is contained in:
yanhool
2026-03-10 16:34:11 +08:00
committed by GitHub
parent 26f623ed32
commit 95716b106b
8 changed files with 492 additions and 219 deletions
+7 -1
View File
@@ -284,6 +284,9 @@
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"api_keys": [
"YOUR_BRAVE_API_KEY"
],
"max_results": 5
},
"tavily": {
@@ -298,7 +301,10 @@
},
"perplexity": {
"enabled": false,
"api_key": "",
"api_key": "pplx-xxx",
"api_keys": [
"pplx-xxx"
],
"max_results": 5
},
"searxng": {
+6 -4
View File
@@ -120,19 +120,21 @@ func registerSharedTools(
continue
}
// Web tools
if cfg.Tools.IsToolEnabled("web") {
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey, cfg.Tools.Web.Brave.APIKeys),
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Tavily.APIKey, cfg.Tools.Web.Tavily.APIKeys),
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityAPIKeys: config.MergeAPIKeys(
cfg.Tools.Web.Perplexity.APIKey,
cfg.Tools.Web.Perplexity.APIKeys,
),
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
+37 -10
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"os"
"strings"
"sync/atomic"
"github.com/caarlos0/env/v11"
@@ -593,16 +594,18 @@ type ToolConfig struct {
}
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
}
type TavilyConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEYS"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
}
type DuckDuckGoConfig struct {
@@ -611,9 +614,10 @@ type DuckDuckGoConfig struct {
}
type PerplexityConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEYS"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type SearXNGConfig struct {
@@ -933,6 +937,29 @@ func (c *Config) ValidateModelList() error {
return nil
}
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
seen := make(map[string]struct{})
var all []string
if k := strings.TrimSpace(apiKey); k != "" {
if _, exists := seen[k]; !exists {
seen[k] = struct{}{}
all = append(all, k)
}
}
for _, k := range apiKeys {
if trimmed := strings.TrimSpace(k); trimmed != "" {
if _, exists := seen[trimmed]; !exists {
seen[trimmed] = struct{}{}
all = append(all, trimmed)
}
}
}
return all
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
+1 -1
View File
@@ -296,7 +296,7 @@ func TestDefaultConfig_WebTools(t *testing.T) {
if cfg.Tools.Web.Brave.MaxResults != 5 {
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
}
if cfg.Tools.Web.Brave.APIKey != "" {
if len(cfg.Tools.Web.Brave.APIKeys) != 0 {
t.Error("Brave API key should be empty by default")
}
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
+8
View File
@@ -384,6 +384,13 @@ func DefaultConfig() *Config {
Brave: BraveConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
Tavily: TavilyConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
@@ -393,6 +400,7 @@ func DefaultConfig() *Config {
Perplexity: PerplexityConfig{
Enabled: false,
APIKey: "",
APIKeys: nil,
MaxResults: 5,
},
SearXNG: SearXNGConfig{
+14 -10
View File
@@ -733,16 +733,18 @@ type WebToolsConfig struct {
}
type BraveConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
MaxResults int `json:"max_results"`
}
type TavilyConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
BaseURL string `json:"base_url"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
BaseURL string `json:"base_url"`
MaxResults int `json:"max_results"`
}
type DuckDuckGoConfig struct {
@@ -751,9 +753,10 @@ type DuckDuckGoConfig struct {
}
type PerplexityConfig struct {
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
MaxResults int `json:"max_results"`
Enabled bool `json:"enabled"`
APIKey string `json:"api_key"`
APIKeys []string `json:"api_keys"`
MaxResults int `json:"max_results"`
}
type CronConfig struct {
@@ -1082,6 +1085,7 @@ func (c ToolsConfig) ToStandardTools() config.ToolsConfig {
Brave: config.BraveConfig{
Enabled: c.Web.Brave.Enabled,
APIKey: c.Web.Brave.APIKey,
APIKeys: c.Web.Brave.APIKeys,
MaxResults: c.Web.Brave.MaxResults,
},
Tavily: config.TavilyConfig{
+295 -188
View File
@@ -11,6 +11,7 @@ import (
"net/url"
"regexp"
"strings"
"sync/atomic"
"time"
)
@@ -76,81 +77,140 @@ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, err
return client, nil
}
type APIKeyPool struct {
keys []string
current uint32
}
func NewAPIKeyPool(keys []string) *APIKeyPool {
return &APIKeyPool{
keys: keys,
}
}
type APIKeyIterator struct {
pool *APIKeyPool
startIdx uint32
attempt uint32
}
func (p *APIKeyPool) NewIterator() *APIKeyIterator {
if len(p.keys) == 0 {
return &APIKeyIterator{pool: p}
}
idx := atomic.AddUint32(&p.current, 1) - 1
return &APIKeyIterator{
pool: p,
startIdx: idx,
}
}
func (it *APIKeyIterator) Next() (string, bool) {
length := uint32(len(it.pool.keys))
if length == 0 || it.attempt >= length {
return "", false
}
key := it.pool.keys[(it.startIdx+it.attempt)%length]
it.attempt++
return key, true
}
type SearchProvider interface {
Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
proxy string
client *http.Client
keyPool *APIKeyPool
proxy string
client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
var lastErr error
iter := p.keyPool.NewIterator()
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
fmt.Printf("Brave API Error Body: %s\n", string(body))
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
for {
apiKey, ok := iter.Next()
if !ok {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", apiKey)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
return strings.Join(lines, "\n"), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type TavilySearchProvider struct {
apiKey string
keyPool *APIKeyPool
baseURL string
proxy string
client *http.Client
@@ -162,74 +222,96 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
searchURL = "https://api.tavily.com/search"
}
payload := map[string]any{
"api_key": p.apiKey,
"query": query,
"search_depth": "advanced",
"include_answer": false,
"include_images": false,
"include_raw_content": false,
"max_results": count,
}
var lastErr error
iter := p.keyPool.NewIterator()
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
}
var searchResp struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
for i, item := range results {
if i >= count {
for {
apiKey, ok := iter.Next()
if !ok {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
payload := map[string]any{
"api_key": apiKey,
"query": query,
"search_depth": "advanced",
"include_answer": false,
"include_images": false,
"include_raw_content": false,
"max_results": count,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
} `json:"results"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Content != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Content))
}
}
return strings.Join(lines, "\n"), nil
}
return strings.Join(lines, "\n"), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type DuckDuckGoSearchProvider struct {
@@ -324,75 +406,97 @@ func stripTags(content string) string {
}
type PerplexitySearchProvider struct {
apiKey string
proxy string
client *http.Client
keyPool *APIKeyPool
proxy string
client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := "https://api.perplexity.ai/chat/completions"
payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
{
"role": "system",
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
var lastErr error
iter := p.keyPool.NewIterator()
for {
apiKey, ok := iter.Next()
if !ok {
break
}
payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
{
"role": "system",
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
},
{
"role": "user",
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
},
},
{
"role": "user",
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
},
},
"max_tokens": 1000,
"max_tokens": 1000,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
lastErr = fmt.Errorf("request failed: %w", err)
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
lastErr = fmt.Errorf("failed to read response: %w", err)
continue
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("Perplexity API error: %s", string(body))
if resp.StatusCode == http.StatusTooManyRequests ||
resp.StatusCode == http.StatusUnauthorized ||
resp.StatusCode == http.StatusForbidden ||
resp.StatusCode >= 500 {
continue
}
return "", lastErr
}
var searchResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(searchResp.Choices) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Perplexity API error: %s", string(body))
}
var searchResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(searchResp.Choices) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
}
type SearXNGSearchProvider struct {
@@ -545,16 +649,16 @@ type WebSearchTool struct {
}
type WebSearchToolOptions struct {
BraveAPIKey string
BraveAPIKeys []string
BraveMaxResults int
BraveEnabled bool
TavilyAPIKey string
TavilyAPIKeys []string
TavilyBaseURL string
TavilyMaxResults int
TavilyEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
PerplexityAPIKey string
PerplexityAPIKeys []string
PerplexityMaxResults int
PerplexityEnabled bool
SearXNGBaseURL string
@@ -571,23 +675,26 @@ type WebSearchToolOptions struct {
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
}
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
provider = &PerplexitySearchProvider{
keyPool: NewAPIKeyPool(opts.PerplexityAPIKeys),
proxy: opts.Proxy,
client: client,
}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
} else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
}
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
provider = &BraveSearchProvider{keyPool: NewAPIKeyPool(opts.BraveAPIKeys), proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
@@ -596,13 +703,13 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
} else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
}
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
keyPool: NewAPIKeyPool(opts.TavilyAPIKeys),
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
client: client,
+124 -5
View File
@@ -249,7 +249,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
@@ -269,7 +269,11 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKeys: []string{"test-key"},
BraveMaxResults: 5,
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
@@ -553,7 +557,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("perplexity", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
PerplexityEnabled: true,
PerplexityAPIKey: "k",
PerplexityAPIKeys: []string{"k"},
PerplexityMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
@@ -572,7 +576,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
t.Run("brave", func(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
BraveEnabled: true,
BraveAPIKey: "k",
BraveAPIKeys: []string{"k"},
BraveMaxResults: 3,
Proxy: "http://127.0.0.1:7890",
})
@@ -650,7 +654,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKey: "test-key",
TavilyAPIKeys: []string{"test-key"},
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
@@ -682,6 +686,121 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
}
}
func TestAPIKeyPool(t *testing.T) {
pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
if len(pool.keys) != 3 {
t.Fatalf("expected 3 keys, got %d", len(pool.keys))
}
if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" {
t.Fatalf("unexpected keys: %v", pool.keys)
}
// Test Iterator: each iterator should cover all keys exactly once
iter := pool.NewIterator()
expected := []string{"key1", "key2", "key3"}
for i, want := range expected {
k, ok := iter.Next()
if !ok {
t.Fatalf("iter.Next() returned false at step %d", i)
}
if k != want {
t.Errorf("step %d: expected %s, got %s", i, want, k)
}
}
// Should be exhausted
if _, ok := iter.Next(); ok {
t.Errorf("expected iterator exhausted after all keys")
}
// Second iterator starts at next position (load balancing)
iter2 := pool.NewIterator()
k, ok := iter2.Next()
if !ok {
t.Fatal("iter2.Next() returned false")
}
if k != "key2" {
t.Errorf("expected key2 (round-robin), got %s", k)
}
// Empty pool
emptyPool := NewAPIKeyPool([]string{})
emptyIter := emptyPool.NewIterator()
if _, ok := emptyIter.Next(); ok {
t.Errorf("expected false for empty pool")
}
// Single key pool
singlePool := NewAPIKeyPool([]string{"single"})
singleIter := singlePool.NewIterator()
if k, ok := singleIter.Next(); !ok || k != "single" {
t.Errorf("expected single, got %s (ok=%v)", k, ok)
}
if _, ok := singleIter.Next(); ok {
t.Errorf("expected exhausted after single key")
}
}
func TestWebTool_TavilySearch_Failover(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var payload map[string]any
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("failed to decode payload: %v", err)
}
apiKey := payload["api_key"].(string)
if apiKey == "key1" {
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("Rate limited"))
return
}
if apiKey == "key2" {
// Success
response := map[string]any{
"results": []map[string]any{
{
"title": "Success Result",
"url": "https://example.com/success",
"content": "Success content",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
defer server.Close()
tool, err := NewWebSearchTool(WebSearchToolOptions{
TavilyEnabled: true,
TavilyAPIKeys: []string{"key1", "key2"},
TavilyBaseURL: server.URL,
TavilyMaxResults: 5,
})
if err != nil {
t.Fatalf("NewWebSearchTool() error: %v", err)
}
ctx := context.Background()
args := map[string]any{
"query": "test query",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success, got Error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "Success Result") {
t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser)
}
}
func TestWebTool_GLMSearch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {