mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(web_search): add load balance and failover for api keys (#982)
* feat(web_search): add load balance and failover for api keys * feat(web_search): add load balance and failover for api keys * lint * new iter to get api key * deleted conflicts
This commit is contained in:
@@ -284,6 +284,9 @@
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"api_keys": [
|
||||
"YOUR_BRAVE_API_KEY"
|
||||
],
|
||||
"max_results": 5
|
||||
},
|
||||
"tavily": {
|
||||
@@ -298,7 +301,10 @@
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "",
|
||||
"api_key": "pplx-xxx",
|
||||
"api_keys": [
|
||||
"pplx-xxx"
|
||||
],
|
||||
"max_results": 5
|
||||
},
|
||||
"searxng": {
|
||||
|
||||
+6
-4
@@ -120,19 +120,21 @@ func registerSharedTools(
|
||||
continue
|
||||
}
|
||||
|
||||
// Web tools
|
||||
if cfg.Tools.IsToolEnabled("web") {
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Brave.APIKey, cfg.Tools.Web.Brave.APIKeys),
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyAPIKeys: config.MergeAPIKeys(cfg.Tools.Web.Tavily.APIKey, cfg.Tools.Web.Tavily.APIKeys),
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityAPIKeys: config.MergeAPIKeys(
|
||||
cfg.Tools.Web.Perplexity.APIKey,
|
||||
cfg.Tools.Web.Perplexity.APIKeys,
|
||||
),
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
|
||||
|
||||
+37
-10
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -593,16 +594,18 @@ type ToolConfig struct {
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEYS"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type TavilyConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEYS"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
@@ -611,9 +614,10 @@ type DuckDuckGoConfig struct {
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
APIKeys []string `json:"api_keys" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEYS"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type SearXNGConfig struct {
|
||||
@@ -933,6 +937,29 @@ func (c *Config) ValidateModelList() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func MergeAPIKeys(apiKey string, apiKeys []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var all []string
|
||||
|
||||
if k := strings.TrimSpace(apiKey); k != "" {
|
||||
if _, exists := seen[k]; !exists {
|
||||
seen[k] = struct{}{}
|
||||
all = append(all, k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, k := range apiKeys {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; !exists {
|
||||
seen[trimmed] = struct{}{}
|
||||
all = append(all, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
switch name {
|
||||
case "web":
|
||||
|
||||
@@ -296,7 +296,7 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
if cfg.Tools.Web.Brave.MaxResults != 5 {
|
||||
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
|
||||
}
|
||||
if cfg.Tools.Web.Brave.APIKey != "" {
|
||||
if len(cfg.Tools.Web.Brave.APIKeys) != 0 {
|
||||
t.Error("Brave API key should be empty by default")
|
||||
}
|
||||
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
|
||||
|
||||
@@ -384,6 +384,13 @@ func DefaultConfig() *Config {
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Tavily: TavilyConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
@@ -393,6 +400,7 @@ func DefaultConfig() *Config {
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
APIKeys: nil,
|
||||
MaxResults: 5,
|
||||
},
|
||||
SearXNG: SearXNGConfig{
|
||||
|
||||
@@ -733,16 +733,18 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type TavilyConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
BaseURL string `json:"base_url"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
BaseURL string `json:"base_url"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
@@ -751,9 +753,10 @@ type DuckDuckGoConfig struct {
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
MaxResults int `json:"max_results"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"api_key"`
|
||||
APIKeys []string `json:"api_keys"`
|
||||
MaxResults int `json:"max_results"`
|
||||
}
|
||||
|
||||
type CronConfig struct {
|
||||
@@ -1082,6 +1085,7 @@ func (c ToolsConfig) ToStandardTools() config.ToolsConfig {
|
||||
Brave: config.BraveConfig{
|
||||
Enabled: c.Web.Brave.Enabled,
|
||||
APIKey: c.Web.Brave.APIKey,
|
||||
APIKeys: c.Web.Brave.APIKeys,
|
||||
MaxResults: c.Web.Brave.MaxResults,
|
||||
},
|
||||
Tavily: config.TavilyConfig{
|
||||
|
||||
+295
-188
@@ -11,6 +11,7 @@ import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -76,81 +77,140 @@ func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, err
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type APIKeyPool struct {
|
||||
keys []string
|
||||
current uint32
|
||||
}
|
||||
|
||||
func NewAPIKeyPool(keys []string) *APIKeyPool {
|
||||
return &APIKeyPool{
|
||||
keys: keys,
|
||||
}
|
||||
}
|
||||
|
||||
type APIKeyIterator struct {
|
||||
pool *APIKeyPool
|
||||
startIdx uint32
|
||||
attempt uint32
|
||||
}
|
||||
|
||||
func (p *APIKeyPool) NewIterator() *APIKeyIterator {
|
||||
if len(p.keys) == 0 {
|
||||
return &APIKeyIterator{pool: p}
|
||||
}
|
||||
idx := atomic.AddUint32(&p.current, 1) - 1
|
||||
return &APIKeyIterator{
|
||||
pool: p,
|
||||
startIdx: idx,
|
||||
}
|
||||
}
|
||||
|
||||
func (it *APIKeyIterator) Next() (string, bool) {
|
||||
length := uint32(len(it.pool.keys))
|
||||
if length == 0 || it.attempt >= length {
|
||||
return "", false
|
||||
}
|
||||
key := it.pool.keys[(it.startIdx+it.attempt)%length]
|
||||
it.attempt++
|
||||
return key, true
|
||||
}
|
||||
|
||||
type SearchProvider interface {
|
||||
Search(ctx context.Context, query string, count int) (string, error)
|
||||
}
|
||||
|
||||
type BraveSearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
keyPool *APIKeyPool
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
|
||||
url.QueryEscape(query), count)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
// Log error body for debugging
|
||||
fmt.Printf("Brave API Error Body: %s\n", string(body))
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
// Log error body for debugging
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Description != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Description))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type TavilySearchProvider struct {
|
||||
apiKey string
|
||||
keyPool *APIKeyPool
|
||||
baseURL string
|
||||
proxy string
|
||||
client *http.Client
|
||||
@@ -162,74 +222,96 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
|
||||
searchURL = "https://api.tavily.com/search"
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": p.apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
|
||||
payload := map[string]any{
|
||||
"api_key": apiKey,
|
||||
"query": query,
|
||||
"search_depth": "advanced",
|
||||
"include_answer": false,
|
||||
"include_images": false,
|
||||
"include_raw_content": false,
|
||||
"max_results": count,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
results := searchResp.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
var lines []string
|
||||
lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
|
||||
for i, item := range results {
|
||||
if i >= count {
|
||||
break
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
|
||||
if item.Content != "" {
|
||||
lines = append(lines, fmt.Sprintf(" %s", item.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type DuckDuckGoSearchProvider struct {
|
||||
@@ -324,75 +406,97 @@ func stripTags(content string) string {
|
||||
}
|
||||
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
proxy string
|
||||
client *http.Client
|
||||
keyPool *APIKeyPool
|
||||
proxy string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
payload := map[string]any{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
|
||||
var lastErr error
|
||||
iter := p.keyPool.NewIterator()
|
||||
|
||||
for {
|
||||
apiKey, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
|
||||
},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
|
||||
},
|
||||
},
|
||||
"max_tokens": 1000,
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("request failed: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read response: %w", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
if resp.StatusCode == http.StatusTooManyRequests ||
|
||||
resp.StatusCode == http.StatusUnauthorized ||
|
||||
resp.StatusCode == http.StatusForbidden ||
|
||||
resp.StatusCode >= 500 {
|
||||
continue
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
return "", fmt.Errorf("all api keys failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
type SearXNGSearchProvider struct {
|
||||
@@ -545,16 +649,16 @@ type WebSearchTool struct {
|
||||
}
|
||||
|
||||
type WebSearchToolOptions struct {
|
||||
BraveAPIKey string
|
||||
BraveAPIKeys []string
|
||||
BraveMaxResults int
|
||||
BraveEnabled bool
|
||||
TavilyAPIKey string
|
||||
TavilyAPIKeys []string
|
||||
TavilyBaseURL string
|
||||
TavilyMaxResults int
|
||||
TavilyEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
PerplexityAPIKeys []string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
SearXNGBaseURL string
|
||||
@@ -571,23 +675,26 @@ type WebSearchToolOptions struct {
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
if opts.PerplexityEnabled && len(opts.PerplexityAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
|
||||
}
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
|
||||
provider = &PerplexitySearchProvider{
|
||||
keyPool: NewAPIKeyPool(opts.PerplexityAPIKeys),
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
} else if opts.BraveEnabled && len(opts.BraveAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
|
||||
}
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
|
||||
provider = &BraveSearchProvider{keyPool: NewAPIKeyPool(opts.BraveAPIKeys), proxy: opts.Proxy, client: client}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
@@ -596,13 +703,13 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.SearXNGMaxResults > 0 {
|
||||
maxResults = opts.SearXNGMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
} else if opts.TavilyEnabled && len(opts.TavilyAPIKeys) > 0 {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
|
||||
}
|
||||
provider = &TavilySearchProvider{
|
||||
apiKey: opts.TavilyAPIKey,
|
||||
keyPool: NewAPIKeyPool(opts.TavilyAPIKeys),
|
||||
baseURL: opts.TavilyBaseURL,
|
||||
proxy: opts.Proxy,
|
||||
client: client,
|
||||
|
||||
+124
-5
@@ -249,7 +249,7 @@ func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -269,7 +269,11 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKeys: []string{"test-key"},
|
||||
BraveMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -553,7 +557,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("perplexity", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
PerplexityEnabled: true,
|
||||
PerplexityAPIKey: "k",
|
||||
PerplexityAPIKeys: []string{"k"},
|
||||
PerplexityMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
@@ -572,7 +576,7 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
||||
t.Run("brave", func(t *testing.T) {
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
BraveEnabled: true,
|
||||
BraveAPIKey: "k",
|
||||
BraveAPIKeys: []string{"k"},
|
||||
BraveMaxResults: 3,
|
||||
Proxy: "http://127.0.0.1:7890",
|
||||
})
|
||||
@@ -650,7 +654,7 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKey: "test-key",
|
||||
TavilyAPIKeys: []string{"test-key"},
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
@@ -682,6 +686,121 @@ func TestWebTool_TavilySearch_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyPool(t *testing.T) {
|
||||
pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
|
||||
if len(pool.keys) != 3 {
|
||||
t.Fatalf("expected 3 keys, got %d", len(pool.keys))
|
||||
}
|
||||
if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" {
|
||||
t.Fatalf("unexpected keys: %v", pool.keys)
|
||||
}
|
||||
|
||||
// Test Iterator: each iterator should cover all keys exactly once
|
||||
iter := pool.NewIterator()
|
||||
expected := []string{"key1", "key2", "key3"}
|
||||
for i, want := range expected {
|
||||
k, ok := iter.Next()
|
||||
if !ok {
|
||||
t.Fatalf("iter.Next() returned false at step %d", i)
|
||||
}
|
||||
if k != want {
|
||||
t.Errorf("step %d: expected %s, got %s", i, want, k)
|
||||
}
|
||||
}
|
||||
// Should be exhausted
|
||||
if _, ok := iter.Next(); ok {
|
||||
t.Errorf("expected iterator exhausted after all keys")
|
||||
}
|
||||
|
||||
// Second iterator starts at next position (load balancing)
|
||||
iter2 := pool.NewIterator()
|
||||
k, ok := iter2.Next()
|
||||
if !ok {
|
||||
t.Fatal("iter2.Next() returned false")
|
||||
}
|
||||
if k != "key2" {
|
||||
t.Errorf("expected key2 (round-robin), got %s", k)
|
||||
}
|
||||
|
||||
// Empty pool
|
||||
emptyPool := NewAPIKeyPool([]string{})
|
||||
emptyIter := emptyPool.NewIterator()
|
||||
if _, ok := emptyIter.Next(); ok {
|
||||
t.Errorf("expected false for empty pool")
|
||||
}
|
||||
|
||||
// Single key pool
|
||||
singlePool := NewAPIKeyPool([]string{"single"})
|
||||
singleIter := singlePool.NewIterator()
|
||||
if k, ok := singleIter.Next(); !ok || k != "single" {
|
||||
t.Errorf("expected single, got %s (ok=%v)", k, ok)
|
||||
}
|
||||
if _, ok := singleIter.Next(); ok {
|
||||
t.Errorf("expected exhausted after single key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_TavilySearch_Failover(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
|
||||
apiKey := payload["api_key"].(string)
|
||||
|
||||
if apiKey == "key1" {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("Rate limited"))
|
||||
return
|
||||
}
|
||||
|
||||
if apiKey == "key2" {
|
||||
// Success
|
||||
response := map[string]any{
|
||||
"results": []map[string]any{
|
||||
{
|
||||
"title": "Success Result",
|
||||
"url": "https://example.com/success",
|
||||
"content": "Success content",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(response)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
||||
TavilyEnabled: true,
|
||||
TavilyAPIKeys: []string{"key1", "key2"},
|
||||
TavilyBaseURL: server.URL,
|
||||
TavilyMaxResults: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewWebSearchTool() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "test query",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got Error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Success Result") {
|
||||
t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebTool_GLMSearch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
|
||||
Reference in New Issue
Block a user