diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 43b1c1402..116f0ed60 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -768,22 +768,26 @@ func (p *ExaSearchProvider) Search(ctx context.Context, query string, count int) return "", fmt.Errorf("exa: parse error: %w", err) } - var sb strings.Builder + if len(result.Results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via Exa)", query)) maxResults := count if maxResults > len(result.Results) { maxResults = len(result.Results) } for i, r := range result.Results[:maxResults] { - sb.WriteString(fmt.Sprintf("%d. %s\n URL: %s\n", i+1, r.Title, r.URL)) + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, r.Title, r.URL)) if r.Text != "" { snippet := r.Text if len(snippet) > 200 { snippet = snippet[:200] + "..." } - sb.WriteString(fmt.Sprintf(" %s\n", snippet)) + lines = append(lines, fmt.Sprintf(" %s", snippet)) } - sb.WriteString("\n") } - return sb.String(), nil + return strings.Join(lines, "\n"), nil } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 8a8b88131..896b39a33 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -681,3 +682,218 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) } } + +func TestNewWebSearchTool_ExaPriority(t *testing.T) { + // Exa should be selected when enabled with API key + tool, err := NewWebSearchTool(WebSearchToolOptions{ + ExaEnabled: true, + ExaAPIKey: "exa-key", + ExaMaxResults: 3, + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if tool == nil { + t.Fatal("Expected non-nil tool when Exa is enabled with API key") + } + if _, ok := tool.provider.(*ExaSearchProvider); !ok { + t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider) + } + if tool.maxResults != 3 { + t.Fatalf("maxResults = %d, want 3", tool.maxResults) + } + + // Exa enabled but no API key should fall through + tool, err = NewWebSearchTool(WebSearchToolOptions{ + ExaEnabled: true, + ExaAPIKey: "", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if tool != nil { + t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled") + } + + // Perplexity should take priority over Exa + tool, err = NewWebSearchTool(WebSearchToolOptions{ + PerplexityEnabled: true, + PerplexityAPIKey: "perp-key", + ExaEnabled: true, + ExaAPIKey: "exa-key", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if _, ok := tool.provider.(*PerplexitySearchProvider); !ok { + t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider) + } + + // Exa should take priority over Brave + tool, err = NewWebSearchTool(WebSearchToolOptions{ + ExaEnabled: true, + ExaAPIKey: "exa-key", + BraveEnabled: true, + BraveAPIKey: "brave-key", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if _, ok := tool.provider.(*ExaSearchProvider); !ok { + t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider) + } +} + +func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) { + tool, err := NewWebSearchTool(WebSearchToolOptions{ + ExaEnabled: true, + ExaAPIKey: "k", + Proxy: "http://127.0.0.1:7890", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + p, ok := tool.provider.(*ExaSearchProvider) + if !ok { + t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider) + } + if p.proxy != "http://127.0.0.1:7890" { + t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890") + } +} + +func TestExaSearchProvider_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + if r.Header.Get("x-api-key") != "test-exa-key" { + t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key")) + } + + // Verify payload + body, _ := io.ReadAll(r.Body) + var payload map[string]any + json.Unmarshal(body, &payload) + if payload["query"] != "test query" { + t.Errorf("Expected query 'test query', got %v", payload["query"]) + } + if payload["type"] != "neural" { + t.Errorf("Expected type 'neural', got %v", payload["type"]) + } + + response := map[string]any{ + "results": []map[string]any{ + {"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"}, + {"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"}, + {"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + provider := &ExaSearchProvider{ + apiKey: "test-exa-key", + client: &http.Client{}, + } + + // Temporarily override the API URL by using a custom transport + provider.client.Transport = rewriteHostTransport(server.URL) + + result, err := provider.Search(context.Background(), "test query", 5) + if err != nil { + t.Fatalf("Search() error: %v", err) + } + + if !strings.Contains(result, "via Exa") { + t.Errorf("Expected '(via Exa)' attribution, got: %s", result) + } + if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") { + t.Errorf("Expected results in output, got: %s", result) + } + if !strings.Contains(result, "First result text") { + t.Errorf("Expected snippet text in output, got: %s", result) + } +} + +func TestExaSearchProvider_EmptyResults(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]any{"results": []map[string]any{}} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + provider := &ExaSearchProvider{ + apiKey: "test-key", + client: &http.Client{Transport: rewriteHostTransport(server.URL)}, + } + + result, err := provider.Search(context.Background(), "no results query", 5) + if err != nil { + t.Fatalf("Search() error: %v", err) + } + if !strings.Contains(result, "No results for: no results query") { + t.Errorf("Expected 'No results' message, got: %s", result) + } +} + +func TestExaSearchProvider_MaxResultsCapping(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 5 results + results := make([]map[string]any, 5) + for i := range results { + results[i] = map[string]any{ + "title": fmt.Sprintf("Result %d", i+1), + "url": fmt.Sprintf("https://exa.ai/%d", i+1), + "text": fmt.Sprintf("Text %d", i+1), + } + } + response := map[string]any{"results": results} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + provider := &ExaSearchProvider{ + apiKey: "test-key", + client: &http.Client{Transport: rewriteHostTransport(server.URL)}, + } + + // Request only 2 results even though API returns 5 + result, err := provider.Search(context.Background(), "test", 2) + if err != nil { + t.Fatalf("Search() error: %v", err) + } + + if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") { + t.Errorf("Expected first 2 results, got: %s", result) + } + if strings.Contains(result, "Result 3") { + t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result) + } +} + +// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL. +func rewriteHostTransport(target string) http.RoundTripper { + return roundTripFunc(func(req *http.Request) (*http.Response, error) { + newURL := target + req.URL.Path + newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body) + if err != nil { + return nil, err + } + newReq.Header = req.Header + return http.DefaultClient.Do(newReq) + }) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}