mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
8219b5a26f
- Add explicit empty-results handling ("No results for: <query>")
- Add "Results for: <query> (via Exa)" header and align per-result
format with Brave/Tavily/DuckDuckGo/Perplexity
- Add tests: provider priority (Perplexity > Exa > Brave), proxy
propagation, successful search with header/attribution, empty
results, and max-results capping
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
900 lines
27 KiB
Go
package tools
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
// testFetchLimit is the maximum response size (10 MiB) passed to the web
// fetch tool constructors throughout these tests.
const testFetchLimit int64 = 10 << 20
|
|
|
|
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
|
func TestWebTool_WebFetch_Success(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain the fetched content (full JSON result)
|
|
if !strings.Contains(result.ForLLM, "Test Page") {
|
|
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForUser should contain summary
|
|
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
|
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_JSON verifies that a JSON response is fetched
// successfully and that both its key and its value reach the LLM output.
func TestWebTool_WebFetch_JSON(t *testing.T) {
	testData := map[string]string{"key": "value", "number": "123"}
	expectedJSON, err := json.MarshalIndent(testData, "", " ")
	if err != nil {
		t.Fatalf("Failed to marshal test data: %v", err)
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(expectedJSON)
	}))
	defer server.Close()

	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"url": server.URL,
	}

	result := tool.Execute(ctx, args)

	// Success should not be an error
	if result.IsError {
		t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
	}

	// Both the key and its value must appear in the LLM output; either one
	// missing means the JSON body was not passed through.
	if !strings.Contains(result.ForLLM, "key") || !strings.Contains(result.ForLLM, "value") {
		t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
	}
}
|
|
|
|
// TestWebTool_WebFetch_InvalidURL verifies that a string which is not a URL
// is rejected with an error result that mentions the URL.
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"url": "not-a-valid-url",
	}

	result := tool.Execute(ctx, args)

	// Should return error result
	if !result.IsError {
		t.Errorf("Expected error for invalid URL")
	}

	// Should contain error message (either "invalid URL" or scheme error)
	if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
		t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
	}
}
|
|
|
|
// TestWebTool_WebFetch_UnsupportedScheme verifies that non-http(s) URLs are
// rejected with an error that names the allowed schemes.
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"url": "ftp://example.com/file.txt",
	}

	result := tool.Execute(ctx, args)

	// Should return error result
	if !result.IsError {
		t.Errorf("Expected error for unsupported URL scheme")
	}

	// Should mention only http/https allowed
	if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
		t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
	}
}
|
|
|
|
// TestWebTool_WebFetch_MissingURL verifies that omitting the "url" argument
// produces an error result saying the URL is required.
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{}

	result := tool.Execute(ctx, args)

	// Should return error result
	if !result.IsError {
		t.Errorf("Expected error when URL is missing")
	}

	// Should mention URL is required
	if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
		t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
	}
}
|
|
|
|
// TestWebTool_WebFetch_Truncation verifies content truncation
|
|
func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
|
longContent := strings.Repeat("x", 20000)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(longContent))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain truncated content (not the full 20000 chars)
|
|
resultMap := make(map[string]any)
|
|
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
|
if text, ok := resultMap["text"].(string); ok {
|
|
if len(text) > 1100 { // Allow some margin
|
|
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
|
}
|
|
}
|
|
|
|
// Should be marked as truncated
|
|
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
|
t.Errorf("Expected 'truncated' to be true in result")
|
|
}
|
|
}
|
|
|
|
// TestWebFetchTool_PayloadTooLarge verifies that a response body larger than
// the configured fetch limit is rejected with a size-limit error result.
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
	// Mock server that always returns a body just over the fetch limit.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.WriteHeader(http.StatusOK)

		// testFetchLimit + 100 bytes of 'A', so the limit is exceeded.
		largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
		w.Write(largeData)
	}))
	defer ts.Close()

	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	args := map[string]any{
		"url": ts.URL,
	}

	ctx := context.Background()
	result := tool.Execute(ctx, args)
	if result == nil {
		t.Fatal("expected a ToolResult, got nil")
	}

	// An oversized payload is a failure, not a (truncated) success.
	if !result.IsError {
		t.Errorf("expected IsError=true for oversized payload, got: %+v", result)
	}

	// The message must name the exceeded limit.
	expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
	if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
		t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
	}
}
|
|
|
|
// TestWebTool_WebSearch_NoApiKey verifies that NewWebSearchTool returns no
// tool (and no error) when Brave is enabled without a key, and when no
// provider is enabled at all.
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
	cases := []struct {
		opts    WebSearchToolOptions
		nilHint string
	}{
		{WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}, "Expected nil tool when Brave API key is empty"},
		{WebSearchToolOptions{}, "Expected nil tool when no provider is enabled"},
	}

	for _, c := range cases {
		got, err := NewWebSearchTool(c.opts)
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if got != nil {
			t.Error(c.nilHint)
		}
	}
}
|
|
|
|
// TestWebTool_WebSearch_MissingQuery verifies that running the search tool
// without a "query" argument yields an error result.
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
	opts := WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}
	search, err := NewWebSearchTool(opts)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// An empty argument map means the required query is absent.
	if res := search.Execute(context.Background(), map[string]any{}); !res.IsError {
		t.Errorf("Expected error when query is missing")
	}
}
|
|
|
|
// TestWebTool_WebFetch_HTMLExtraction verifies that fetched HTML is reduced
// to its visible text: headings and paragraphs survive, script/style do not.
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/html")
		w.WriteHeader(http.StatusOK)
		w.Write(
			[]byte(
				`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`,
			),
		)
	}))
	defer server.Close()

	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"url": server.URL,
	}

	result := tool.Execute(ctx, args)

	// Success should not be an error
	if result.IsError {
		t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
	}

	// Both the heading and the paragraph text must be extracted.
	if !strings.Contains(result.ForLLM, "Title") || !strings.Contains(result.ForLLM, "Content") {
		t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
	}

	// Should NOT contain script or style tags in ForLLM
	if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
		t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
	}
}
|
|
|
|
// TestWebFetchTool_extractText verifies that HTML-to-text extraction keeps
// line structure and drops script/style content. It also checks that blank
// lines and runs of spaces are collapsed and that empty input stays empty.
func TestWebFetchTool_extractText(t *testing.T) {
	tool := &WebFetchTool{}

	tests := []struct {
		name     string
		input    string
		wantFunc func(t *testing.T, got string)
	}{
		{
			name:  "preserves newlines between block elements",
			input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
			wantFunc: func(t *testing.T, got string) {
				lines := strings.Split(got, "\n")
				if len(lines) < 2 {
					t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
				}
				if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") ||
					!strings.Contains(got, "Paragraph 2") {
					t.Errorf("Missing expected text: %q", got)
				}
			},
		},
		{
			name:  "removes script and style tags",
			input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
			wantFunc: func(t *testing.T, got string) {
				if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
					t.Errorf("Expected script/style content removed, got: %q", got)
				}
				if !strings.Contains(got, "Keep this") {
					t.Errorf("Expected 'Keep this' to remain, got: %q", got)
				}
			},
		},
		{
			name:  "collapses excessive blank lines",
			input: "<p>A</p>\n\n\n\n\n<p>B</p>",
			wantFunc: func(t *testing.T, got string) {
				if strings.Contains(got, "\n\n\n") {
					t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
				}
			},
		},
		{
			// The input needs a run of spaces and the check must look for a
			// double space; a single space is legitimately present in
			// "hello world", so checking for " " could never pass.
			name:  "collapses horizontal whitespace",
			input: "<p>hello     world</p>",
			wantFunc: func(t *testing.T, got string) {
				if strings.Contains(got, "  ") {
					t.Errorf("Expected spaces collapsed, got: %q", got)
				}
				if !strings.Contains(got, "hello world") {
					t.Errorf("Expected 'hello world', got: %q", got)
				}
			},
		},
		{
			name:  "empty input",
			input: "",
			wantFunc: func(t *testing.T, got string) {
				if got != "" {
					t.Errorf("Expected empty string, got: %q", got)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tool.extractText(tt.input)
			tt.wantFunc(t, got)
		})
	}
}
|
|
|
|
// TestWebTool_WebFetch_MissingDomain verifies that a URL with a scheme but no
// host is rejected with an error mentioning the missing domain.
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
	tool, err := NewWebFetchTool(50000, testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: continuing would call Execute on a nil tool and panic.
		t.Fatalf("Failed to create web fetch tool: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"url": "https://",
	}

	result := tool.Execute(ctx, args)

	// Should return error result
	if !result.IsError {
		t.Errorf("Expected error for URL without domain")
	}

	// Should mention missing domain
	if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
		t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
	}
}
|
|
|
|
func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
|
|
client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("createHTTPClient() error: %v", err)
|
|
}
|
|
if client.Timeout != 12*time.Second {
|
|
t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
|
|
}
|
|
|
|
tr, ok := client.Transport.(*http.Transport)
|
|
if !ok {
|
|
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
|
|
}
|
|
if tr.Proxy == nil {
|
|
t.Fatal("transport.Proxy is nil, want non-nil")
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", "https://example.com", nil)
|
|
if err != nil {
|
|
t.Fatalf("http.NewRequest() error: %v", err)
|
|
}
|
|
proxyURL, err := tr.Proxy(req)
|
|
if err != nil {
|
|
t.Fatalf("transport.Proxy(req) error: %v", err)
|
|
}
|
|
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
|
|
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
|
|
}
|
|
}
|
|
|
|
// TestCreateHTTPClient_InvalidProxy verifies that a proxy string that cannot
// be parsed as a URL is rejected.
func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
	if _, err := createHTTPClient("://bad-proxy", 10*time.Second); err == nil {
		t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
	}
}
|
|
|
|
// TestCreateHTTPClient_Socks5ProxyConfigured verifies that a socks5:// proxy
// is accepted and installed as the transport's proxy for outgoing requests.
func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
	client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
	if err != nil {
		t.Fatalf("createHTTPClient() error: %v", err)
	}

	tr, ok := client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
	}
	// Guard before calling tr.Proxy: a nil proxy func must fail the test
	// rather than panic it.
	if tr.Proxy == nil {
		t.Fatal("transport.Proxy is nil, want non-nil")
	}

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest() error: %v", err)
	}
	proxyURL, err := tr.Proxy(req)
	if err != nil {
		t.Fatalf("transport.Proxy(req) error: %v", err)
	}
	if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
		t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
	}
}
|
|
|
|
func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
|
|
_, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
|
|
if err == nil {
|
|
t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "unsupported proxy scheme") {
|
|
t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
|
|
}
|
|
}
|
|
|
|
func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
|
|
t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
|
|
t.Setenv("http_proxy", "http://127.0.0.1:8888")
|
|
t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
|
|
t.Setenv("https_proxy", "http://127.0.0.1:8888")
|
|
t.Setenv("ALL_PROXY", "")
|
|
t.Setenv("all_proxy", "")
|
|
t.Setenv("NO_PROXY", "")
|
|
t.Setenv("no_proxy", "")
|
|
|
|
client, err := createHTTPClient("", 10*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("createHTTPClient() error: %v", err)
|
|
}
|
|
|
|
tr, ok := client.Transport.(*http.Transport)
|
|
if !ok {
|
|
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
|
|
}
|
|
if tr.Proxy == nil {
|
|
t.Fatal("transport.Proxy is nil, want proxy function from environment")
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", "https://example.com", nil)
|
|
if err != nil {
|
|
t.Fatalf("http.NewRequest() error: %v", err)
|
|
}
|
|
if _, err := tr.Proxy(req); err != nil {
|
|
t.Fatalf("transport.Proxy(req) error: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestNewWebFetchToolWithProxy verifies that the proxy-aware constructor
// stores the given maxChars and proxy, and that passing maxChars == 0 falls
// back to the 50000-character default.
func TestNewWebFetchToolWithProxy(t *testing.T) {
	tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		// Stop here: the field checks below would dereference a nil tool.
		t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
	}
	if tool.maxChars != 1024 {
		t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
	}
	if tool.proxy != "http://127.0.0.1:7890" {
		t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
	}

	tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
	if err != nil {
		logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
		t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
	}
	if tool.maxChars != 50000 {
		t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
	}
}
|
|
|
|
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
|
t.Run("perplexity", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
PerplexityEnabled: true,
|
|
PerplexityAPIKey: "k",
|
|
PerplexityMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*PerplexitySearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
|
|
t.Run("brave", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
BraveEnabled: true,
|
|
BraveAPIKey: "k",
|
|
BraveMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*BraveSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
|
|
t.Run("duckduckgo", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestWebTool_TavilySearch_Success verifies a successful Tavily search
// against a mock server. It checks that the request is a JSON POST carrying
// the API key and query, and that the results are shown with a Tavily
// attribution.
func TestWebTool_TavilySearch_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		if r.Header.Get("Content-Type") != "application/json" {
			t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
		}

		// Verify payload. Errors are reported with t.Errorf (not Fatalf)
		// because this runs on the server's goroutine.
		var payload map[string]any
		if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
			t.Errorf("Failed to decode request payload: %v", err)
		}
		if payload["api_key"] != "test-key" {
			t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
		}
		if payload["query"] != "test query" {
			t.Errorf("Expected query 'test query', got %v", payload["query"])
		}

		// Return mock response
		response := map[string]any{
			"results": []map[string]any{
				{
					"title":   "Test Result 1",
					"url":     "https://example.com/1",
					"content": "Content for result 1",
				},
				{
					"title":   "Test Result 2",
					"url":     "https://example.com/2",
					"content": "Content for result 2",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("Failed to encode mock response: %v", err)
		}
	}))
	defer server.Close()

	tool, err := NewWebSearchTool(WebSearchToolOptions{
		TavilyEnabled:    true,
		TavilyAPIKey:     "test-key",
		TavilyBaseURL:    server.URL,
		TavilyMaxResults: 5,
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}

	ctx := context.Background()
	args := map[string]any{
		"query": "test query",
	}

	result := tool.Execute(ctx, args)

	// Success should not be an error
	if result.IsError {
		t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
	}

	// ForUser should contain result titles and URLs
	if !strings.Contains(result.ForUser, "Test Result 1") ||
		!strings.Contains(result.ForUser, "https://example.com/1") {
		t.Errorf("Expected results in output, got: %s", result.ForUser)
	}

	// Should mention via Tavily
	if !strings.Contains(result.ForUser, "via Tavily") {
		t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
	}
}
|
|
|
|
// TestNewWebSearchTool_ExaPriority verifies provider selection around Exa:
//   - Exa is chosen when it is enabled with a key.
//   - Exa is skipped when its key is empty.
//   - Perplexity outranks Exa.
//   - Exa outranks Brave.
func TestNewWebSearchTool_ExaPriority(t *testing.T) {
	// Exa should be selected when enabled with API key
	tool, err := NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled:    true,
		ExaAPIKey:     "exa-key",
		ExaMaxResults: 3,
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if tool == nil {
		t.Fatal("Expected non-nil tool when Exa is enabled with API key")
	}
	if _, ok := tool.provider.(*ExaSearchProvider); !ok {
		t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
	}
	if tool.maxResults != 3 {
		t.Fatalf("maxResults = %d, want 3", tool.maxResults)
	}

	// Exa enabled but no API key should fall through
	tool, err = NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled: true,
		ExaAPIKey:  "",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if tool != nil {
		t.Errorf("Expected nil tool when Exa API key is empty and no other provider enabled")
	}

	// Perplexity should take priority over Exa. The nil guard turns a broken
	// selection into a test failure instead of a nil-pointer panic.
	tool, err = NewWebSearchTool(WebSearchToolOptions{
		PerplexityEnabled: true,
		PerplexityAPIKey:  "perp-key",
		ExaEnabled:        true,
		ExaAPIKey:         "exa-key",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if tool == nil {
		t.Fatal("Expected non-nil tool when Perplexity and Exa are enabled with API keys")
	}
	if _, ok := tool.provider.(*PerplexitySearchProvider); !ok {
		t.Fatalf("provider type = %T, want *PerplexitySearchProvider (Perplexity should outrank Exa)", tool.provider)
	}

	// Exa should take priority over Brave
	tool, err = NewWebSearchTool(WebSearchToolOptions{
		ExaEnabled:   true,
		ExaAPIKey:    "exa-key",
		BraveEnabled: true,
		BraveAPIKey:  "brave-key",
	})
	if err != nil {
		t.Fatalf("NewWebSearchTool() error: %v", err)
	}
	if tool == nil {
		t.Fatal("Expected non-nil tool when Exa and Brave are enabled with API keys")
	}
	if _, ok := tool.provider.(*ExaSearchProvider); !ok {
		t.Fatalf("provider type = %T, want *ExaSearchProvider (Exa should outrank Brave)", tool.provider)
	}
}
|
|
|
|
func TestNewWebSearchTool_ExaProxyPropagation(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
ExaEnabled: true,
|
|
ExaAPIKey: "k",
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*ExaSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *ExaSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
}
|
|
|
|
// TestExaSearchProvider_Success verifies a successful Exa search against a
// mock server. It checks that the request is a JSON POST with the x-api-key
// header, the query, and type "neural". It also checks that the output has
// the Exa attribution, titles, URLs and snippet text.
func TestExaSearchProvider_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		if r.Header.Get("Content-Type") != "application/json" {
			t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
		}
		if r.Header.Get("x-api-key") != "test-exa-key" {
			t.Errorf("Expected x-api-key test-exa-key, got %s", r.Header.Get("x-api-key"))
		}

		// Verify payload. Errors are reported with t.Errorf (not Fatalf)
		// because this runs on the server's goroutine.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("Failed to read request body: %v", err)
		}
		var payload map[string]any
		if err := json.Unmarshal(body, &payload); err != nil {
			t.Errorf("Failed to decode request payload %q: %v", body, err)
		}
		if payload["query"] != "test query" {
			t.Errorf("Expected query 'test query', got %v", payload["query"])
		}
		if payload["type"] != "neural" {
			t.Errorf("Expected type 'neural', got %v", payload["type"])
		}

		response := map[string]any{
			"results": []map[string]any{
				{"title": "Exa Result 1", "url": "https://exa.ai/1", "text": "First result text"},
				{"title": "Exa Result 2", "url": "https://exa.ai/2", "text": "Second result text"},
				{"title": "Exa Result 3", "url": "https://exa.ai/3", "text": "Third result text"},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(response); err != nil {
			t.Errorf("Failed to encode mock response: %v", err)
		}
	}))
	defer server.Close()

	provider := &ExaSearchProvider{
		apiKey: "test-exa-key",
		client: &http.Client{},
	}

	// Route the provider's requests to the mock server instead of the real API.
	provider.client.Transport = rewriteHostTransport(server.URL)

	result, err := provider.Search(context.Background(), "test query", 5)
	if err != nil {
		t.Fatalf("Search() error: %v", err)
	}

	if !strings.Contains(result, "via Exa") {
		t.Errorf("Expected '(via Exa)' attribution, got: %s", result)
	}
	if !strings.Contains(result, "Exa Result 1") || !strings.Contains(result, "https://exa.ai/1") {
		t.Errorf("Expected results in output, got: %s", result)
	}
	if !strings.Contains(result, "First result text") {
		t.Errorf("Expected snippet text in output, got: %s", result)
	}
}
|
|
|
|
func TestExaSearchProvider_EmptyResults(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := map[string]any{"results": []map[string]any{}}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider := &ExaSearchProvider{
|
|
apiKey: "test-key",
|
|
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
|
}
|
|
|
|
result, err := provider.Search(context.Background(), "no results query", 5)
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
if !strings.Contains(result, "No results for: no results query") {
|
|
t.Errorf("Expected 'No results' message, got: %s", result)
|
|
}
|
|
}
|
|
|
|
func TestExaSearchProvider_MaxResultsCapping(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Return 5 results
|
|
results := make([]map[string]any, 5)
|
|
for i := range results {
|
|
results[i] = map[string]any{
|
|
"title": fmt.Sprintf("Result %d", i+1),
|
|
"url": fmt.Sprintf("https://exa.ai/%d", i+1),
|
|
"text": fmt.Sprintf("Text %d", i+1),
|
|
}
|
|
}
|
|
response := map[string]any{"results": results}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(response)
|
|
}))
|
|
defer server.Close()
|
|
|
|
provider := &ExaSearchProvider{
|
|
apiKey: "test-key",
|
|
client: &http.Client{Transport: rewriteHostTransport(server.URL)},
|
|
}
|
|
|
|
// Request only 2 results even though API returns 5
|
|
result, err := provider.Search(context.Background(), "test", 2)
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
|
|
if !strings.Contains(result, "Result 1") || !strings.Contains(result, "Result 2") {
|
|
t.Errorf("Expected first 2 results, got: %s", result)
|
|
}
|
|
if strings.Contains(result, "Result 3") {
|
|
t.Errorf("Expected results capped at 2, but got Result 3 in output: %s", result)
|
|
}
|
|
}
|
|
|
|
// rewriteHostTransport returns an http.RoundTripper that redirects all requests to the given target URL.
|
|
func rewriteHostTransport(target string) http.RoundTripper {
|
|
return roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
newURL := target + req.URL.Path
|
|
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, newURL, req.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newReq.Header = req.Header
|
|
return http.DefaultClient.Do(newReq)
|
|
})
|
|
}
|
|
|
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
|
|
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return f(req)
|
|
}
|