mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
2ecdb893d5
RFC 2544 benchmark addresses (198.18.0.0/15) are not globally routable but were missing from the isPrivateOrRestrictedIP blocklist, allowing SSRF bypasses via literal IPv4. Fixes #3077
2393 lines
72 KiB
Go
2393 lines
72 KiB
Go
package integrationtools
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
const (
|
|
testFetchLimit = int64(10 * 1024 * 1024)
|
|
format = "plaintext"
|
|
)
|
|
|
|
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
|
func TestWebTool_WebFetch_Success(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain the fetched content (full JSON result)
|
|
if !strings.Contains(result.ForLLM, "Test Page") {
|
|
t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForUser should contain summary
|
|
if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
|
|
t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_JSON verifies JSON content handling
|
|
func TestWebTool_WebFetch_JSON(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
testData := map[string]string{"key": "value", "number": "123"}
|
|
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(expectedJSON)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain formatted JSON
|
|
if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
|
|
t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
|
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": "not-a-valid-url",
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Should return error result
|
|
if !result.IsError {
|
|
t.Errorf("Expected error for invalid URL")
|
|
}
|
|
|
|
// Should contain error message (either "invalid URL" or scheme error)
|
|
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
|
|
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
|
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": "ftp://example.com/file.txt",
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Should return error result
|
|
if !result.IsError {
|
|
t.Errorf("Expected error for unsupported URL scheme")
|
|
}
|
|
|
|
// Should mention only http/https allowed
|
|
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
|
|
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
|
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Should return error result
|
|
if !result.IsError {
|
|
t.Errorf("Expected error when URL is missing")
|
|
}
|
|
|
|
// Should mention URL is required
|
|
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
|
|
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_Truncation verifies content truncation
|
|
func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
longContent := strings.Repeat("x", 20000)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(longContent))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(1000, format, testFetchLimit) // Limit to 1000 chars
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain truncated content (not the full 20000 chars)
|
|
resultMap := make(map[string]any)
|
|
json.Unmarshal([]byte(result.ForLLM), &resultMap)
|
|
if text, ok := resultMap["text"].(string); ok {
|
|
if len(text) > 1100 { // Allow some margin
|
|
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
|
}
|
|
}
|
|
|
|
// Should be marked as truncated
|
|
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
|
t.Errorf("Expected 'truncated' to be true in result")
|
|
}
|
|
|
|
// Text should end with the truncation notice
|
|
if text, ok := resultMap["text"].(string); ok {
|
|
if !strings.HasSuffix(text, "[Content truncated due to size limit]") {
|
|
t.Errorf("Expected text to end with truncation notice, got: %q", text[max(0, len(text)-60):])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_TruncationNotice verifies the truncation notice is appended
|
|
// for all content formats (text/plain, text/html, markdown, application/json).
|
|
func TestWebTool_WebFetch_TruncationNotice(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
const truncationNotice = "[Content truncated due to size limit]"
|
|
const maxChars = 100
|
|
|
|
tests := []struct {
|
|
name string
|
|
contentType string
|
|
body string
|
|
format string
|
|
}{
|
|
{
|
|
name: "plain text",
|
|
contentType: "text/plain",
|
|
body: strings.Repeat("a", 500),
|
|
format: "plaintext",
|
|
},
|
|
{
|
|
name: "html plaintext extractor",
|
|
contentType: "text/html",
|
|
body: "<html><body>" + strings.Repeat("b", 500) + "</body></html>",
|
|
format: "plaintext",
|
|
},
|
|
{
|
|
name: "html markdown extractor",
|
|
contentType: "text/html",
|
|
body: "<html><body>" + strings.Repeat("c", 500) + "</body></html>",
|
|
format: "markdown",
|
|
},
|
|
{
|
|
name: "json",
|
|
contentType: "application/json",
|
|
body: `"` + strings.Repeat("d", 500) + `"`,
|
|
format: "plaintext",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", tt.contentType)
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(tt.body))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(maxChars, tt.format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("NewWebFetchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
|
if result.IsError {
|
|
t.Fatalf("unexpected error: %s", result.ForLLM)
|
|
}
|
|
|
|
var resultMap map[string]any
|
|
if err := json.Unmarshal([]byte(result.ForLLM), &resultMap); err != nil {
|
|
t.Fatalf("failed to unmarshal result JSON: %v", err)
|
|
}
|
|
|
|
text, ok := resultMap["text"].(string)
|
|
if !ok {
|
|
t.Fatal("missing 'text' field in result")
|
|
}
|
|
|
|
if !strings.HasSuffix(text, truncationNotice) {
|
|
t.Errorf("expected text to end with %q, got suffix: %q", truncationNotice, text[max(0, len(text)-60):])
|
|
}
|
|
|
|
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
|
t.Errorf("expected truncated=true in result")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_NoTruncationNoticeWhenFitsInLimit verifies that the notice
|
|
// is NOT appended when the content fits within the limit.
|
|
func TestWebTool_WebFetch_NoTruncationNoticeWhenFitsInLimit(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
const truncationNotice = "[Content truncated due to size limit]"
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("short content"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("NewWebFetchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
|
if result.IsError {
|
|
t.Fatalf("unexpected error: %s", result.ForLLM)
|
|
}
|
|
|
|
var resultMap map[string]any
|
|
if err := json.Unmarshal([]byte(result.ForLLM), &resultMap); err != nil {
|
|
t.Fatalf("failed to unmarshal result JSON: %v", err)
|
|
}
|
|
|
|
text, _ := resultMap["text"].(string)
|
|
if strings.Contains(text, truncationNotice) {
|
|
t.Errorf("expected no truncation notice for content within limit, got: %q", text)
|
|
}
|
|
|
|
if truncated, _ := resultMap["truncated"].(bool); truncated {
|
|
t.Errorf("expected truncated=false for content within limit")
|
|
}
|
|
}
|
|
|
|
func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
// Create a mock HTTP server
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Generate a payload intentionally larger than our limit.
|
|
// Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
|
|
largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
|
|
|
|
w.Write(largeData)
|
|
}))
|
|
// Ensure the server is shut down at the end of the test
|
|
defer ts.Close()
|
|
|
|
// Initialize the tool
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
// Prepare the arguments pointing to the URL of our local mock server
|
|
args := map[string]any{
|
|
"url": ts.URL,
|
|
}
|
|
|
|
// Execute the tool
|
|
ctx := context.Background()
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Assuming ErrorResult sets the ForLLM field with the error text.
|
|
if result == nil {
|
|
t.Fatal("expected a ToolResult, got nil")
|
|
}
|
|
|
|
// Search for the exact error string we set earlier in the Execute method
|
|
expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
|
|
|
|
if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
|
|
t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebSearch_NoApiKey verifies providers without required credentials are not registered.
|
|
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKeys: nil})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if tool != nil {
|
|
t.Fatalf("Expected nil tool when only enabled provider is missing credentials")
|
|
}
|
|
|
|
// Also nil when nothing is enabled
|
|
tool, err = NewWebSearchTool(WebSearchToolOptions{})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if tool != nil {
|
|
t.Errorf("Expected nil tool when no provider is enabled")
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
|
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
BraveEnabled: true,
|
|
BraveAPIKeys: []string{"test-key"},
|
|
BraveMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
ctx := context.Background()
|
|
args := map[string]any{}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Should return error result
|
|
if !result.IsError {
|
|
t.Errorf("Expected error when query is missing")
|
|
}
|
|
}
|
|
|
|
func TestNormalizeSearchRange(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{name: "empty", input: "", want: ""},
|
|
{name: "day", input: "d", want: "d"},
|
|
{name: "week uppercase trimmed", input: " W ", want: "w"},
|
|
{name: "month", input: "m", want: "m"},
|
|
{name: "year", input: "y", want: "y"},
|
|
{name: "invalid", input: "q", wantErr: true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := normalizeSearchRange(tt.input)
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatalf("expected error, got nil")
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if got != tt.want {
|
|
t.Fatalf("normalizeSearchRange(%q) = %q, want %q", tt.input, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSearchRangeMappings(t *testing.T) {
|
|
if got := mapBraveFreshness("d"); got != "pd" {
|
|
t.Fatalf("mapBraveFreshness(d) = %q, want pd", got)
|
|
}
|
|
if got := mapBraveFreshness("y"); got != "py" {
|
|
t.Fatalf("mapBraveFreshness(y) = %q, want py", got)
|
|
}
|
|
if got := mapTavilyTimeRange("w"); got != "week" {
|
|
t.Fatalf("mapTavilyTimeRange(w) = %q, want week", got)
|
|
}
|
|
if got := mapPerplexityRecencyFilter("m"); got != "month" {
|
|
t.Fatalf("mapPerplexityRecencyFilter(m) = %q, want month", got)
|
|
}
|
|
if got := mapDuckDuckGoDateFilter("y"); got != "t" {
|
|
t.Fatalf("mapDuckDuckGoDateFilter(y) = %q, want t", got)
|
|
}
|
|
if got := mapSearXNGTimeRange("d"); got != "day" {
|
|
t.Fatalf("mapSearXNGTimeRange(d) = %q, want day", got)
|
|
}
|
|
if got := mapGLMRecencyFilter("w"); got != "oneWeek" {
|
|
t.Fatalf("mapGLMRecencyFilter(w) = %q, want oneWeek", got)
|
|
}
|
|
if got := mapGLMRecencyFilter(""); got != "noLimit" {
|
|
t.Fatalf("mapGLMRecencyFilter(\"\") = %q, want noLimit", got)
|
|
}
|
|
if got := mapBaiduRecencyFilter("d"); got != "week" {
|
|
t.Fatalf("mapBaiduRecencyFilter(d) = %q, want week", got)
|
|
}
|
|
if got := mapBaiduRecencyFilter("m"); got != "month" {
|
|
t.Fatalf("mapBaiduRecencyFilter(m) = %q, want month", got)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_WebSearch_InvalidRange(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
BraveEnabled: true,
|
|
BraveAPIKeys: []string{"test-key"},
|
|
BraveMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"range": "invalid",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Fatalf("expected invalid range to return error")
|
|
}
|
|
if !strings.Contains(result.ForLLM, "range must be one of: d, w, m, y") {
|
|
t.Fatalf("unexpected error message: %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
|
|
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(
|
|
[]byte(
|
|
`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`,
|
|
),
|
|
)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": server.URL,
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForLLM should contain extracted text (without script/style tags)
|
|
if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
|
|
t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
|
|
}
|
|
|
|
// Should NOT contain script or style tags in ForLLM
|
|
if strings.Contains(result.ForLLM, "<script>") || strings.Contains(result.ForLLM, "<style>") {
|
|
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebFetchTool_extractText verifies text extraction preserves newlines
|
|
func TestWebFetchTool_extractText(t *testing.T) {
|
|
tool := &WebFetchTool{}
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
wantFunc func(t *testing.T, got string)
|
|
}{
|
|
{
|
|
name: "preserves newlines between block elements",
|
|
input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
|
|
wantFunc: func(t *testing.T, got string) {
|
|
lines := strings.Split(got, "\n")
|
|
if len(lines) < 2 {
|
|
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
|
|
}
|
|
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") ||
|
|
!strings.Contains(got, "Paragraph 2") {
|
|
t.Errorf("Missing expected text: %q", got)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "removes script and style tags",
|
|
input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
|
|
wantFunc: func(t *testing.T, got string) {
|
|
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
|
|
t.Errorf("Expected script/style content removed, got: %q", got)
|
|
}
|
|
if !strings.Contains(got, "Keep this") {
|
|
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "collapses excessive blank lines",
|
|
input: "<p>A</p>\n\n\n\n\n<p>B</p>",
|
|
wantFunc: func(t *testing.T, got string) {
|
|
if strings.Contains(got, "\n\n\n") {
|
|
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "collapses horizontal whitespace",
|
|
input: "<p>hello world</p>",
|
|
wantFunc: func(t *testing.T, got string) {
|
|
if strings.Contains(got, " ") {
|
|
t.Errorf("Expected spaces collapsed, got: %q", got)
|
|
}
|
|
if !strings.Contains(got, "hello world") {
|
|
t.Errorf("Expected 'hello world', got: %q", got)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "empty input",
|
|
input: "",
|
|
wantFunc: func(t *testing.T, got string) {
|
|
if got != "" {
|
|
t.Errorf("Expected empty string, got: %q", got)
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := tool.extractText(tt.input)
|
|
tt.wantFunc(t, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func withPrivateWebFetchHostsAllowed(t *testing.T) {
|
|
t.Helper()
|
|
previous := allowPrivateWebFetchHosts.Load()
|
|
allowPrivateWebFetchHosts.Store(true)
|
|
t.Cleanup(func() {
|
|
allowPrivateWebFetchHosts.Store(previous)
|
|
})
|
|
}
|
|
|
|
func serverHostAndPort(t *testing.T, rawURL string) (string, string) {
|
|
t.Helper()
|
|
hostPort := strings.TrimPrefix(rawURL, "http://")
|
|
hostPort = strings.TrimPrefix(hostPort, "https://")
|
|
host, port, err := net.SplitHostPort(hostPort)
|
|
if err != nil {
|
|
t.Fatalf("failed to split host/port from %q: %v", rawURL, err)
|
|
}
|
|
return host, port
|
|
}
|
|
|
|
func singleHostCIDR(t *testing.T, host string) string {
|
|
t.Helper()
|
|
ip := net.ParseIP(host)
|
|
if ip == nil {
|
|
t.Fatalf("failed to parse IP %q", host)
|
|
}
|
|
if ip.To4() != nil {
|
|
return ip.String() + "/32"
|
|
}
|
|
return ip.String() + "/128"
|
|
}
|
|
|
|
func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://127.0.0.1:0",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Errorf("expected error for private host URL, got success")
|
|
}
|
|
if !strings.Contains(result.ForLLM, "private or local network") &&
|
|
!strings.Contains(result.ForUser, "private or local network") {
|
|
t.Errorf("expected private host block message, got %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_WebFetch_PrivateHostAllowedByExactWhitelist(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("exact whitelist ok"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
host, _ := serverHostAndPort(t, server.URL)
|
|
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{host})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": server.URL,
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success for exact whitelisted private IP, got %q", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForLLM, "exact whitelist ok") {
|
|
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_WebFetch_PrivateHostAllowedByCIDRWhitelist(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("cidr whitelist ok"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
host, _ := serverHostAndPort(t, server.URL)
|
|
tool, err := NewWebFetchToolWithConfig(50000, "", format, testFetchLimit, []string{singleHostCIDR(t, host)})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": server.URL,
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success for CIDR-whitelisted private IP, got %q", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForLLM, "cidr whitelist ok") {
|
|
t.Fatalf("expected fetched content, got %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("ok"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": server.URL,
|
|
})
|
|
|
|
if result.IsError {
|
|
t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_WebFetch_AllowsLoopbackProxy(t *testing.T) {
|
|
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.String() != "http://example.com/proxied" {
|
|
t.Fatalf("proxy received URL %q, want %q", r.URL.String(), "http://example.com/proxied")
|
|
}
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("proxied content"))
|
|
}))
|
|
defer proxy.Close()
|
|
|
|
tool, err := NewWebFetchToolWithProxy(50000, proxy.URL, format, testFetchLimit, nil)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://example.com/proxied",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success through loopback proxy, got %q", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForLLM, "proxied content") {
|
|
t.Fatalf("expected proxied content, got %q", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked
|
|
func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://[::ffff:127.0.0.1]:0",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success")
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked
|
|
func TestWebFetch_BlocksMetadataIP(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://169.254.169.254/latest/meta-data",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Error("expected error for cloud metadata IP, got success")
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked
|
|
func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://[fd00::1]:0",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Error("expected error for IPv6 unique local address, got success")
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked
|
|
func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
// 2002:7f00:0001::1 embeds 127.0.0.1
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://[2002:7f00:0001::1]:0",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Error("expected error for 6to4 with private embedded IPv4, got success")
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked
|
|
func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
// 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass,
|
|
// connection will fail (no listener) but that's after the SSRF check.
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": "http://[2002:0801:0101::1]:0",
|
|
})
|
|
|
|
// Should NOT be blocked by SSRF check — error should be connection failure, not "private"
|
|
if result.IsError && strings.Contains(result.ForLLM, "private") {
|
|
t.Error("6to4 with public embedded IPv4 should not be blocked as private")
|
|
}
|
|
}
|
|
|
|
// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked
|
|
func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Redirect to a private IP
|
|
http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Temporarily disable private host allowance for the redirect check
|
|
allowPrivateWebFetchHosts.Store(false)
|
|
defer allowPrivateWebFetchHosts.Store(true)
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create web fetch tool: %v", err)
|
|
}
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"url": server.URL,
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Error("expected error when redirecting to private IP, got success")
|
|
}
|
|
}
|
|
|
|
func TestNewSafeDialContext_BlocksPrivateDNSResolutionWithoutWhitelist(t *testing.T) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("failed to listen on loopback: %v", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
_, port, err := net.SplitHostPort(listener.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("failed to split listener address: %v", err)
|
|
}
|
|
|
|
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, nil)
|
|
_, err = dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
|
if err == nil {
|
|
t.Fatal("expected localhost DNS resolution to be blocked without whitelist")
|
|
}
|
|
if !strings.Contains(err.Error(), "private") && !strings.Contains(err.Error(), "whitelisted") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewSafeDialContext_AllowsWhitelistedPrivateDNSResolution(t *testing.T) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("failed to listen on loopback: %v", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
accepted := make(chan struct{}, 1)
|
|
go func() {
|
|
conn, acceptErr := listener.Accept()
|
|
if acceptErr != nil {
|
|
return
|
|
}
|
|
conn.Close()
|
|
accepted <- struct{}{}
|
|
}()
|
|
|
|
_, port, err := net.SplitHostPort(listener.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("failed to split listener address: %v", err)
|
|
}
|
|
|
|
whitelist, err := newPrivateHostWhitelist([]string{"127.0.0.0/8"})
|
|
if err != nil {
|
|
t.Fatalf("failed to parse whitelist: %v", err)
|
|
}
|
|
|
|
dialContext := newSafeDialContext(&net.Dialer{Timeout: time.Second}, whitelist)
|
|
conn, err := dialContext(context.Background(), "tcp", net.JoinHostPort("localhost", port))
|
|
if err != nil {
|
|
t.Fatalf("expected localhost DNS resolution to succeed with whitelist, got %v", err)
|
|
}
|
|
conn.Close()
|
|
|
|
select {
|
|
case <-accepted:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected localhost listener to accept a connection")
|
|
}
|
|
}
|
|
|
|
// TestIsPrivateOrRestrictedIP_Table tests IP classification logic
|
|
func TestIsPrivateOrRestrictedIP_Table(t *testing.T) {
|
|
tests := []struct {
|
|
ip string
|
|
blocked bool
|
|
desc string
|
|
}{
|
|
{"127.0.0.1", true, "IPv4 loopback"},
|
|
{"10.0.0.1", true, "IPv4 private class A"},
|
|
{"172.16.0.1", true, "IPv4 private class B"},
|
|
{"192.168.1.1", true, "IPv4 private class C"},
|
|
{"169.254.169.254", true, "link-local / cloud metadata"},
|
|
{"100.64.0.1", true, "carrier-grade NAT"},
|
|
{"198.18.0.1", true, "RFC 2544 benchmark"},
|
|
{"198.19.255.1", true, "RFC 2544 benchmark end"},
|
|
{"198.17.0.1", false, "just before 198.18.0.0/15"},
|
|
{"198.20.0.1", false, "just after 198.19.255.255"},
|
|
{"0.0.0.0", true, "unspecified"},
|
|
{"8.8.8.8", false, "public DNS"},
|
|
{"1.1.1.1", false, "public DNS"},
|
|
{"::1", true, "IPv6 loopback"},
|
|
{"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"},
|
|
{"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"},
|
|
{"fc00::1", true, "IPv6 unique local"},
|
|
{"fd00::1", true, "IPv6 unique local"},
|
|
{"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"},
|
|
{"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"},
|
|
{"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"},
|
|
{"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"},
|
|
{"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"},
|
|
{"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.desc, func(t *testing.T) {
|
|
ip := net.ParseIP(tt.ip)
|
|
if ip == nil {
|
|
t.Fatalf("failed to parse IP: %s", tt.ip)
|
|
}
|
|
got := isPrivateOrRestrictedIP(ip)
|
|
if got != tt.blocked {
|
|
t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
|
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"url": "https://",
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Should return error result
|
|
if !result.IsError {
|
|
t.Errorf("Expected error for URL without domain")
|
|
}
|
|
|
|
// Should mention missing domain
|
|
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
|
|
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestNewWebFetchToolWithProxy(t *testing.T) {
|
|
tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
} else if tool.maxChars != 1024 {
|
|
t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
|
|
}
|
|
|
|
if tool.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
|
|
tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", format, testFetchLimit, nil)
|
|
if err != nil {
|
|
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
|
}
|
|
|
|
if tool.maxChars != 50000 {
|
|
t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
|
|
}
|
|
}
|
|
|
|
func TestNewWebFetchToolWithConfig_InvalidPrivateHostWhitelist(t *testing.T) {
|
|
_, err := NewWebFetchToolWithConfig(1024, "", format, testFetchLimit, []string{"not-an-ip-or-cidr"})
|
|
if err == nil {
|
|
t.Fatal("expected invalid whitelist entry to fail")
|
|
}
|
|
if !strings.Contains(err.Error(), "invalid entry") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
|
|
t.Run("perplexity", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
PerplexityEnabled: true,
|
|
PerplexityAPIKeys: []string{"k"},
|
|
PerplexityMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*PerplexitySearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
|
|
t.Run("brave", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
BraveEnabled: true,
|
|
BraveAPIKeys: []string{"k"},
|
|
BraveMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*BraveSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
|
|
t.Run("duckduckgo", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*DuckDuckGoSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
|
|
t.Run("searxng", func(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
SearXNGEnabled: true,
|
|
SearXNGBaseURL: "https://searx.example.com",
|
|
SearXNGMaxResults: 3,
|
|
Proxy: "http://127.0.0.1:7890",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
p, ok := tool.provider.(*SearXNGSearchProvider)
|
|
if !ok {
|
|
t.Fatalf("provider type = %T, want *SearXNGSearchProvider", tool.provider)
|
|
}
|
|
if p.proxy != "http://127.0.0.1:7890" {
|
|
t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
|
|
}
|
|
tr, ok := p.client.Transport.(*http.Transport)
|
|
if !ok {
|
|
t.Fatalf("client.Transport type = %T, want *http.Transport", p.client.Transport)
|
|
}
|
|
req, err := http.NewRequest(http.MethodGet, "https://searx.example.com/search", nil)
|
|
if err != nil {
|
|
t.Fatalf("http.NewRequest() error: %v", err)
|
|
}
|
|
proxyURL, err := tr.Proxy(req)
|
|
if err != nil {
|
|
t.Fatalf("transport.Proxy(req) error: %v", err)
|
|
}
|
|
if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
|
|
t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestWebTool_TavilySearch_Success verifies successful Tavily search
|
|
func TestWebTool_TavilySearch_Success(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if r.Header.Get("Content-Type") != "application/json" {
|
|
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
|
}
|
|
|
|
// Verify payload
|
|
var payload map[string]any
|
|
json.NewDecoder(r.Body).Decode(&payload)
|
|
if payload["api_key"] != "test-key" {
|
|
t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
|
|
}
|
|
if payload["query"] != "test query" {
|
|
t.Errorf("Expected query 'test query', got %v", payload["query"])
|
|
}
|
|
|
|
// Return mock response
|
|
response := map[string]any{
|
|
"results": []map[string]any{
|
|
{
|
|
"title": "Test Result 1",
|
|
"url": "https://example.com/1",
|
|
"content": "Content for result 1",
|
|
},
|
|
{
|
|
"title": "Test Result 2",
|
|
"url": "https://example.com/2",
|
|
"content": "Content for result 2",
|
|
},
|
|
},
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(response)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
TavilyEnabled: true,
|
|
TavilyAPIKeys: []string{"test-key"},
|
|
TavilyBaseURL: server.URL,
|
|
TavilyMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"query": "test query",
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
// Success should not be an error
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
|
|
// ForUser should contain result titles and URLs
|
|
if !strings.Contains(result.ForUser, "Test Result 1") ||
|
|
!strings.Contains(result.ForUser, "https://example.com/1") {
|
|
t.Errorf("Expected results in output, got: %s", result.ForUser)
|
|
}
|
|
|
|
// Should mention via Tavily
|
|
if !strings.Contains(result.ForUser, "via Tavily") {
|
|
t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_TavilySearch_RangeMapping(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("failed to decode payload: %v", err)
|
|
}
|
|
if payload["time_range"] != "week" {
|
|
t.Fatalf("expected time_range=week, got %v", payload["time_range"])
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"results": []map[string]any{
|
|
{"title": "Recent result", "url": "https://example.com/recent", "content": "snippet"},
|
|
},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
TavilyEnabled: true,
|
|
TavilyAPIKeys: []string{"test-key"},
|
|
TavilyBaseURL: server.URL,
|
|
TavilyMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"range": "w",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_KagiSearch_SuccessRequestAndParsing(t *testing.T) {
|
|
var sawRequest bool
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sawRequest = true
|
|
if r.Method != http.MethodPost {
|
|
t.Fatalf("method = %s, want POST", r.Method)
|
|
}
|
|
if r.URL.Path != "/search" {
|
|
t.Fatalf("path = %s, want /search", r.URL.Path)
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
|
|
t.Fatalf("Authorization = %q, want Bearer test-key", got)
|
|
}
|
|
|
|
var payload map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("failed to decode payload: %v", err)
|
|
}
|
|
if payload["query"] != "test query" {
|
|
t.Fatalf("query = %v, want test query", payload["query"])
|
|
}
|
|
if payload["workflow"] != "search" {
|
|
t.Fatalf("workflow = %v, want search", payload["workflow"])
|
|
}
|
|
if payload["limit"] != float64(2) {
|
|
t.Fatalf("limit = %v, want 2", payload["limit"])
|
|
}
|
|
lens, ok := payload["lens"].(map[string]any)
|
|
if !ok || lens["time_relative"] != "week" {
|
|
t.Fatalf("lens = %v, want time_relative=week", payload["lens"])
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{
|
|
"data": {
|
|
"search": [
|
|
{
|
|
"title": "Kagi Result 1",
|
|
"url": "https://example.com/1",
|
|
"snippet": "<b>first</b> snippet",
|
|
"time": "2026-01-02T03:04:05Z"
|
|
},
|
|
{
|
|
"title": "Kagi Result 2",
|
|
"url": "https://example.com/2",
|
|
"snippet": "second snippet"
|
|
},
|
|
{
|
|
"title": "Kagi Result 3",
|
|
"url": "https://example.com/3",
|
|
"snippet": "third snippet"
|
|
}
|
|
]
|
|
}
|
|
}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
KagiEnabled: true,
|
|
KagiAPIKeys: []string{"test-key"},
|
|
KagiBaseURL: server.URL,
|
|
KagiMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"count": float64(2),
|
|
"range": "w",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
if !sawRequest {
|
|
t.Fatal("server did not receive request")
|
|
}
|
|
if !strings.Contains(result.ForUser, "via Kagi") ||
|
|
!strings.Contains(result.ForUser, "Kagi Result 1") ||
|
|
!strings.Contains(result.ForUser, "https://example.com/1") ||
|
|
!strings.Contains(result.ForUser, "first snippet") ||
|
|
!strings.Contains(result.ForUser, "Published: 2026-01-02T03:04:05Z") {
|
|
t.Fatalf("expected Kagi result fields in output, got: %s", result.ForUser)
|
|
}
|
|
if strings.Contains(result.ForUser, "Kagi Result 3") {
|
|
t.Fatalf("expected output truncated to requested count, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_KagiSearch_NoApiKey(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
KagiEnabled: true,
|
|
KagiMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if tool != nil {
|
|
t.Fatal("expected nil tool when Kagi is enabled without API keys")
|
|
}
|
|
}
|
|
|
|
func TestWebTool_KagiSearch_AuthErrorDoesNotLeakKey(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if got := r.Header.Get("Authorization"); got != "Bearer invalid-kagi-key" {
|
|
t.Fatalf("Authorization = %q, want Bearer invalid-kagi-key", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error":"unauthorized"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
KagiEnabled: true,
|
|
KagiAPIKeys: []string{"invalid-kagi-key"},
|
|
KagiBaseURL: server.URL + "/search",
|
|
KagiMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"query": "test query"})
|
|
if !result.IsError {
|
|
t.Fatal("expected auth error")
|
|
}
|
|
if !strings.Contains(result.ForLLM, "authentication failed") {
|
|
t.Fatalf("unexpected error message: %s", result.ForLLM)
|
|
}
|
|
if strings.Contains(result.ForLLM, "invalid-kagi-key") || strings.Contains(result.ForUser, "invalid-kagi-key") {
|
|
t.Fatalf("error leaked API key: %+v", result)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_KagiSearch_Non200Response(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
w.Write([]byte(`{"error":"temporary failure"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
KagiEnabled: true,
|
|
KagiAPIKeys: []string{"test-key"},
|
|
KagiBaseURL: server.URL,
|
|
KagiMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"query": "test query"})
|
|
if !result.IsError {
|
|
t.Fatal("expected non-200 error")
|
|
}
|
|
if !strings.Contains(result.ForLLM, "server error (status 500)") {
|
|
t.Fatalf("unexpected error message: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_KagiSearch_SkipsUnsupportedAndMalformedResults(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{
|
|
"data": {
|
|
"image": [
|
|
{"title": "Image Result", "url": "https://images.example.com/1"}
|
|
],
|
|
"search": [
|
|
{"title": "Missing URL"},
|
|
{"title": "Usable Result", "url": "https://example.com/usable", "snippet": "usable snippet", "extra": "ignored"}
|
|
]
|
|
}
|
|
}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
KagiEnabled: true,
|
|
KagiAPIKeys: []string{"test-key"},
|
|
KagiBaseURL: server.URL,
|
|
KagiMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"query": "test query"})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForUser, "Usable Result") ||
|
|
!strings.Contains(result.ForUser, "https://example.com/usable") {
|
|
t.Fatalf("expected usable result, got: %s", result.ForUser)
|
|
}
|
|
if strings.Contains(result.ForUser, "Image Result") || strings.Contains(result.ForUser, "Missing URL") {
|
|
t.Fatalf("unsupported or malformed result was included: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
// TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA verifies that a 403 response
|
|
// with cf-mitigated: challenge triggers a retry using the honest picoclaw User-Agent,
|
|
// and that the retry response is returned when it succeeds.
|
|
func TestWebFetchTool_CloudflareChallenge_RetryWithHonestUA(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
requestCount := 0
|
|
var receivedUAs []string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount++
|
|
receivedUAs = append(receivedUAs, r.Header.Get("User-Agent"))
|
|
|
|
if requestCount == 1 {
|
|
// First request: simulate Cloudflare challenge
|
|
w.Header().Set("Cf-Mitigated", "challenge")
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
w.Write([]byte("<html><body>Cloudflare challenge</body></html>"))
|
|
return
|
|
}
|
|
// Second request (honest UA retry): success
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("real content"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("NewWebFetchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
|
|
|
if result.IsError {
|
|
t.Fatalf("expected success after retry, got error: %s", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForLLM, "real content") {
|
|
t.Errorf("expected retry response content, got: %s", result.ForLLM)
|
|
}
|
|
if requestCount != 2 {
|
|
t.Errorf("expected exactly 2 requests, got %d", requestCount)
|
|
}
|
|
|
|
// First request must use the generic user agent
|
|
if receivedUAs[0] != userAgent {
|
|
t.Errorf("first request UA = %q, want %q", receivedUAs[0], userAgent)
|
|
}
|
|
// Second request must use the honest picoclaw user agent
|
|
if !strings.Contains(receivedUAs[1], "picoclaw") {
|
|
t.Errorf("retry request UA = %q, want it to contain 'picoclaw'", receivedUAs[1])
|
|
}
|
|
}
|
|
|
|
// TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors verifies that a plain 403
|
|
// (without cf-mitigated: challenge) does NOT trigger a retry.
|
|
func TestWebFetchTool_CloudflareChallenge_NoRetryOnOtherErrors(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
requestCount := 0
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount++
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
w.Write([]byte("plain forbidden"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("NewWebFetchTool() error: %v", err)
|
|
}
|
|
|
|
tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
|
|
|
if requestCount != 1 {
|
|
t.Errorf("expected exactly 1 request for plain 403, got %d", requestCount)
|
|
}
|
|
}
|
|
|
|
// TestWebFetchTool_CloudflareChallenge_RetryFailsToo verifies that if the honest-UA
|
|
// retry also fails (e.g. still blocked), the error from the retry is returned.
|
|
func TestWebFetchTool_CloudflareChallenge_RetryFailsToo(t *testing.T) {
|
|
withPrivateWebFetchHostsAllowed(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Always return CF challenge regardless of UA
|
|
w.Header().Set("Cf-Mitigated", "challenge")
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
w.Write([]byte("<html><body>still blocked</body></html>"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebFetchTool(50000, format, testFetchLimit)
|
|
if err != nil {
|
|
t.Fatalf("NewWebFetchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{"url": server.URL})
|
|
|
|
// Should not be an error — the retry response is used as-is (403 is a valid HTTP response)
|
|
if result.IsError {
|
|
t.Fatalf("expected non-error result even when retry is also blocked, got: %s", result.ForLLM)
|
|
}
|
|
// Status in the JSON result should reflect the 403
|
|
if !strings.Contains(result.ForLLM, "403") {
|
|
t.Errorf("expected status 403 in result, got: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestAPIKeyPool(t *testing.T) {
|
|
pool := NewAPIKeyPool([]string{"key1", "key2", "key3"})
|
|
if len(pool.keys) != 3 {
|
|
t.Fatalf("expected 3 keys, got %d", len(pool.keys))
|
|
}
|
|
if pool.keys[0] != "key1" || pool.keys[1] != "key2" || pool.keys[2] != "key3" {
|
|
t.Fatalf("unexpected keys: %v", pool.keys)
|
|
}
|
|
|
|
// Test Iterator: each iterator should cover all keys exactly once
|
|
iter := pool.NewIterator()
|
|
expected := []string{"key1", "key2", "key3"}
|
|
for i, want := range expected {
|
|
k, ok := iter.Next()
|
|
if !ok {
|
|
t.Fatalf("iter.Next() returned false at step %d", i)
|
|
}
|
|
if k != want {
|
|
t.Errorf("step %d: expected %s, got %s", i, want, k)
|
|
}
|
|
}
|
|
// Should be exhausted
|
|
if _, ok := iter.Next(); ok {
|
|
t.Errorf("expected iterator exhausted after all keys")
|
|
}
|
|
|
|
// Second iterator starts at next position (load balancing)
|
|
iter2 := pool.NewIterator()
|
|
k, ok := iter2.Next()
|
|
if !ok {
|
|
t.Fatal("iter2.Next() returned false")
|
|
}
|
|
if k != "key2" {
|
|
t.Errorf("expected key2 (round-robin), got %s", k)
|
|
}
|
|
|
|
// Empty pool
|
|
emptyPool := NewAPIKeyPool([]string{})
|
|
emptyIter := emptyPool.NewIterator()
|
|
if _, ok := emptyIter.Next(); ok {
|
|
t.Errorf("expected false for empty pool")
|
|
}
|
|
|
|
// Single key pool
|
|
singlePool := NewAPIKeyPool([]string{"single"})
|
|
singleIter := singlePool.NewIterator()
|
|
if k, ok := singleIter.Next(); !ok || k != "single" {
|
|
t.Errorf("expected single, got %s (ok=%v)", k, ok)
|
|
}
|
|
if _, ok := singleIter.Next(); ok {
|
|
t.Errorf("expected exhausted after single key")
|
|
}
|
|
}
|
|
|
|
func TestWebTool_TavilySearch_Failover(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("failed to decode payload: %v", err)
|
|
}
|
|
|
|
apiKey := payload["api_key"].(string)
|
|
|
|
if apiKey == "key1" {
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte("Rate limited"))
|
|
return
|
|
}
|
|
|
|
if apiKey == "key2" {
|
|
// Success
|
|
response := map[string]any{
|
|
"results": []map[string]any{
|
|
{
|
|
"title": "Success Result",
|
|
"url": "https://example.com/success",
|
|
"content": "Success content",
|
|
},
|
|
},
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(response)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
TavilyEnabled: true,
|
|
TavilyAPIKeys: []string{"key1", "key2"},
|
|
TavilyBaseURL: server.URL,
|
|
TavilyMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
args := map[string]any{
|
|
"query": "test query",
|
|
}
|
|
|
|
result := tool.Execute(ctx, args)
|
|
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got Error: %s", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForUser, "Success Result") {
|
|
t.Errorf("Expected failover to second key and success result, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_SearXNGSearch_RangeMapping(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if got := r.URL.Query().Get("time_range"); got != "year" {
|
|
t.Fatalf("expected time_range=year, got %q", got)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"results": []map[string]any{
|
|
{"title": "Recent result", "url": "https://example.com/1", "content": "snippet"},
|
|
},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
SearXNGEnabled: true,
|
|
SearXNGBaseURL: server.URL,
|
|
SearXNGMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"range": "y",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_GLMSearch_Success(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if r.Header.Get("Content-Type") != "application/json" {
|
|
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
|
}
|
|
if r.Header.Get("Authorization") != "Bearer test-glm-key" {
|
|
t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization"))
|
|
}
|
|
|
|
var payload map[string]any
|
|
json.NewDecoder(r.Body).Decode(&payload)
|
|
if payload["search_query"] != "test query" {
|
|
t.Errorf("Expected search_query 'test query', got %v", payload["search_query"])
|
|
}
|
|
if payload["search_engine"] != "search_std" {
|
|
t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"])
|
|
}
|
|
|
|
response := map[string]any{
|
|
"id": "web-search-test",
|
|
"created": 1709568000,
|
|
"search_result": []map[string]any{
|
|
{
|
|
"title": "Test GLM Result",
|
|
"content": "GLM search snippet",
|
|
"link": "https://example.com/glm",
|
|
"media": "Example",
|
|
"publish_date": "2026-03-04",
|
|
},
|
|
},
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(response)
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
GLMSearchEnabled: true,
|
|
GLMSearchAPIKey: "test-glm-key",
|
|
GLMSearchBaseURL: server.URL,
|
|
GLMSearchEngine: "search_std",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
})
|
|
|
|
if result.IsError {
|
|
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
|
}
|
|
if !strings.Contains(result.ForUser, "Test GLM Result") {
|
|
t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser)
|
|
}
|
|
if !strings.Contains(result.ForUser, "https://example.com/glm") {
|
|
t.Errorf("Expected URL in output, got: %s", result.ForUser)
|
|
}
|
|
if !strings.Contains(result.ForUser, "via GLM Search") {
|
|
t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_GLMSearch_RangeMapping(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("failed to decode payload: %v", err)
|
|
}
|
|
if payload["search_recency_filter"] != "oneMonth" {
|
|
t.Fatalf("expected search_recency_filter=oneMonth, got %v", payload["search_recency_filter"])
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"search_result": []map[string]any{
|
|
{"title": "Recent GLM Result", "content": "snippet", "link": "https://example.com/glm-range"},
|
|
},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
GLMSearchEnabled: true,
|
|
GLMSearchAPIKey: "test-glm-key",
|
|
GLMSearchBaseURL: server.URL,
|
|
GLMSearchEngine: "search_std",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"range": "m",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_BaiduSearch_RangeMapping(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var payload map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
t.Fatalf("failed to decode payload: %v", err)
|
|
}
|
|
if payload["search_recency_filter"] != "week" {
|
|
t.Fatalf("expected search_recency_filter=week for day fallback, got %v", payload["search_recency_filter"])
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"references": []map[string]any{
|
|
{"title": "Recent Baidu Result", "url": "https://example.com/baidu", "content": "snippet"},
|
|
},
|
|
})
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
BaiduSearchEnabled: true,
|
|
BaiduSearchAPIKey: "test-baidu-key",
|
|
BaiduSearchBaseURL: server.URL,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
"range": "d",
|
|
})
|
|
if result.IsError {
|
|
t.Fatalf("expected success, got %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_GLMSearch_APIError(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte(`{"error":"invalid api key"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
GLMSearchEnabled: true,
|
|
GLMSearchAPIKey: "bad-key",
|
|
GLMSearchBaseURL: server.URL,
|
|
GLMSearchEngine: "search_std",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
result := tool.Execute(context.Background(), map[string]any{
|
|
"query": "test query",
|
|
})
|
|
|
|
if !result.IsError {
|
|
t.Errorf("Expected IsError=true for 401 response")
|
|
}
|
|
if !strings.Contains(result.ForLLM, "status 401") {
|
|
t.Errorf("Expected status 401 in error, got: %s", result.ForLLM)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_GLMSearch_Priority(t *testing.T) {
|
|
// GLM Search should only be selected when all other providers are disabled
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 5,
|
|
GLMSearchEnabled: true,
|
|
GLMSearchAPIKey: "test-key",
|
|
GLMSearchBaseURL: "https://example.com",
|
|
GLMSearchEngine: "search_std",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
|
|
// DuckDuckGo should win over GLM Search
|
|
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
|
|
t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider)
|
|
}
|
|
|
|
// With DuckDuckGo disabled, GLM Search should be selected
|
|
tool2, err := NewWebSearchTool(WebSearchToolOptions{
|
|
DuckDuckGoEnabled: false,
|
|
GLMSearchEnabled: true,
|
|
GLMSearchAPIKey: "test-key",
|
|
GLMSearchBaseURL: "https://example.com",
|
|
GLMSearchEngine: "search_std",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool2.provider.(*GLMSearchProvider); !ok {
|
|
t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_SogouSearch_Success(t *testing.T) {
|
|
provider := &SogouSearchProvider{
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
rec := httptest.NewRecorder()
|
|
fmt.Fprint(rec, `<html><body>
|
|
<a class=resultLink href="/link?url=https%3A%2F%2Fexample.com%2Fa" id="sogou_vr_0_0">Result A</a>
|
|
<div class="clamp3">Snippet A</div>
|
|
<a class=resultLink href="/link?url=https%3A%2F%2Fexample.com%2Fb" id="sogou_vr_0_1">Result B</a>
|
|
<div class="clamp3">Snippet B</div>
|
|
</body></html>`)
|
|
return rec.Result(), nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
out, err := provider.Search(context.Background(), "test query", 2, "")
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
if !strings.Contains(out, "via Sogou") || !strings.Contains(out, "https://example.com/a") {
|
|
t.Fatalf("unexpected output: %s", out)
|
|
}
|
|
}
|
|
|
|
func TestApplySogouRangeHint(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
query string
|
|
rangeCode string
|
|
want string
|
|
}{
|
|
{name: "empty range", query: "golang", rangeCode: "", want: "golang"},
|
|
{name: "day", query: "golang", rangeCode: "d", want: "golang 最近一天"},
|
|
{name: "week", query: "golang", rangeCode: "w", want: "golang 最近一周"},
|
|
{name: "month", query: "golang", rangeCode: "m", want: "golang 最近一个月"},
|
|
{name: "year", query: "golang", rangeCode: "y", want: "golang 最近一年"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := applySogouRangeHint(tt.query, tt.rangeCode); got != tt.want {
|
|
t.Fatalf("applySogouRangeHint(%q, %q) = %q, want %q", tt.query, tt.rangeCode, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPrefersDuckDuckGoQuery(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
query string
|
|
want bool
|
|
}{
|
|
{name: "english words", query: "golang web search", want: true},
|
|
{name: "english with numbers", query: "OpenAI o3 price 2026", want: true},
|
|
{name: "chinese", query: "今天上海天气", want: false},
|
|
{name: "mixed with han", query: "golang 中文 教程", want: false},
|
|
{name: "numbers only", query: "2026 04 15", want: false},
|
|
{name: "blank", query: " ", want: false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if got := prefersDuckDuckGoQuery(tt.query); got != tt.want {
|
|
t.Fatalf("prefersDuckDuckGoQuery(%q) = %v, want %v", tt.query, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPrefersDuckDuckGoQuery_DoesNotUseGlobalLanguageFallback(t *testing.T) {
|
|
if prefersDuckDuckGoQuery("2026 04 15") {
|
|
t.Fatal("numeric query should default to Sogou when no script-specific hint is present")
|
|
}
|
|
}
|
|
|
|
func TestWebTool_SogouPriorityAndExplicitProvider(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider, got %T", tool.provider)
|
|
}
|
|
|
|
tool, err = NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "duckduckgo",
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok {
|
|
t.Fatalf("expected DuckDuckGoSearchProvider, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_AutoProviderPrefersConfiguredProvidersBeforeSogou(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
BraveEnabled: true,
|
|
BraveAPIKeys: []string{"brave-key"},
|
|
BraveMaxResults: 5,
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*BraveSearchProvider); !ok {
|
|
t.Fatalf("expected BraveSearchProvider, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_AutoProviderPrefersConfiguredProvidersBeforeGemini(t *testing.T) {
|
|
opts := WebSearchToolOptions{
|
|
GeminiEnabled: true,
|
|
GeminiAPIKey: "google-key",
|
|
GeminiModel: "gemini-2.5-flash",
|
|
GeminiMaxResults: 5,
|
|
BraveEnabled: true,
|
|
BraveAPIKeys: []string{"brave-key"},
|
|
BraveMaxResults: 5,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
DuckDuckGoEnabled: true,
|
|
DuckDuckGoMaxResults: 5,
|
|
}
|
|
|
|
name, err := ResolveWebSearchProviderName(opts, "best robotics companies")
|
|
if err != nil {
|
|
t.Fatalf("ResolveWebSearchProviderName() error: %v", err)
|
|
}
|
|
if name != "brave" {
|
|
t.Fatalf("provider = %q, want brave", name)
|
|
}
|
|
|
|
name, err = ResolveWebSearchProviderName(opts, "今天上海天气")
|
|
if err != nil {
|
|
t.Fatalf("ResolveWebSearchProviderName() error: %v", err)
|
|
}
|
|
if name != "brave" {
|
|
t.Fatalf("provider = %q, want brave", name)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_GeminiRequiresAPIKey(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "gemini",
|
|
GeminiEnabled: true,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider after missing Gemini API key fallback, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestGeminiSearchProvider_SearchSuccess(t *testing.T) {
|
|
provider := &GeminiSearchProvider{
|
|
apiKey: "google-key",
|
|
model: "gemini-2.5-flash",
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
if req.Method != http.MethodPost {
|
|
t.Fatalf("method = %s, want POST", req.Method)
|
|
}
|
|
if got := req.Header.Get("X-Goog-Api-Key"); got != "google-key" {
|
|
t.Fatalf("X-Goog-Api-Key = %q, want google-key", got)
|
|
}
|
|
if !strings.Contains(req.URL.String(), "/models/gemini-2.5-flash:generateContent") {
|
|
t.Fatalf("unexpected URL: %s", req.URL.String())
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
rec.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(rec, `{
|
|
"candidates": [
|
|
{
|
|
"content": {
|
|
"parts": [
|
|
{"text": "Answer paragraph one."},
|
|
{"text": "Answer paragraph two."}
|
|
]
|
|
},
|
|
"groundingMetadata": {
|
|
"groundingChunks": [
|
|
{"web": {"uri": "https://example.com/a", "title": "Result A"}},
|
|
{"web": {"uri": "https://example.com/b", "title": "Result B"}},
|
|
{"web": {"uri": "https://example.com/c", "title": "Result C"}}
|
|
]
|
|
}
|
|
}
|
|
]
|
|
}`)
|
|
return rec.Result(), nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
out, err := provider.Search(context.Background(), "robotics", 2, "")
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
if !strings.Contains(out, "Results for: robotics (via Gemini Google Search)") {
|
|
t.Fatalf("missing header in output: %s", out)
|
|
}
|
|
if !strings.Contains(out, "Answer paragraph one.") || !strings.Contains(out, "Answer paragraph two.") {
|
|
t.Fatalf("missing response text in output: %s", out)
|
|
}
|
|
if !strings.Contains(out, "1. Result A") || !strings.Contains(out, "2. Result B") {
|
|
t.Fatalf("missing citations in output: %s", out)
|
|
}
|
|
if strings.Contains(out, "Result C") {
|
|
t.Fatalf("expected citations to be limited to count=2, got: %s", out)
|
|
}
|
|
}
|
|
|
|
func TestGeminiSearchProvider_SearchIgnoresRange(t *testing.T) {
|
|
provider := &GeminiSearchProvider{
|
|
apiKey: "google-key",
|
|
model: "gemini-2.5-flash",
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
rec := httptest.NewRecorder()
|
|
rec.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(rec, `{
|
|
"candidates": [
|
|
{
|
|
"content": {
|
|
"parts": [
|
|
{"text": "Recent robotics result."}
|
|
]
|
|
}
|
|
}
|
|
]
|
|
}`)
|
|
return rec.Result(), nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
out, err := provider.Search(context.Background(), "robotics", 2, "d")
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
if !strings.Contains(out, "Recent robotics result.") {
|
|
t.Fatalf("missing response text in output: %s", out)
|
|
}
|
|
}
|
|
|
|
func TestGeminiSearchProvider_SearchAPIError(t *testing.T) {
|
|
provider := &GeminiSearchProvider{
|
|
apiKey: "google-key",
|
|
model: "gemini-2.5-flash",
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
rec := httptest.NewRecorder()
|
|
rec.WriteHeader(http.StatusTooManyRequests)
|
|
fmt.Fprint(rec, `{"error":"quota exceeded"}`)
|
|
return rec.Result(), nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
_, err := provider.Search(context.Background(), "robotics", 2, "")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !strings.Contains(err.Error(), "status 429") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestGeminiSearchProvider_SearchEmptyCandidates(t *testing.T) {
|
|
provider := &GeminiSearchProvider{
|
|
apiKey: "google-key",
|
|
model: "gemini-2.5-flash",
|
|
client: &http.Client{
|
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
rec := httptest.NewRecorder()
|
|
rec.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(rec, `{"candidates":[]}`)
|
|
return rec.Result(), nil
|
|
}),
|
|
},
|
|
}
|
|
|
|
out, err := provider.Search(context.Background(), "robotics", 2, "")
|
|
if err != nil {
|
|
t.Fatalf("Search() error: %v", err)
|
|
}
|
|
if out != "No results for: robotics" {
|
|
t.Fatalf("output = %q, want %q", out, "No results for: robotics")
|
|
}
|
|
}
|
|
|
|
func TestWebTool_ExplicitProviderFallsBackWhenMissingCredentials(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "brave",
|
|
BraveEnabled: true,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider after fallback, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_ExplicitProviderFallsBackWhenMissingBaseURL(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "searxng",
|
|
SearXNGEnabled: true,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider after fallback, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_AutoProviderSkipsEnabledButUnreadyProviders(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "auto",
|
|
BraveEnabled: true,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider when Brave has no API key, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestResolveWebSearchProviderName_FallsBackFromExplicitUnavailableProvider(t *testing.T) {
|
|
got, err := ResolveWebSearchProviderName(WebSearchToolOptions{
|
|
Provider: "brave",
|
|
BraveEnabled: true,
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
}, "")
|
|
if err != nil {
|
|
t.Fatalf("ResolveWebSearchProviderName() error: %v", err)
|
|
}
|
|
if got != "sogou" {
|
|
t.Fatalf("ResolveWebSearchProviderName() = %q, want sogou", got)
|
|
}
|
|
}
|
|
|
|
func TestWebTool_UnknownExplicitProviderFallsBackToAuto(t *testing.T) {
|
|
tool, err := NewWebSearchTool(WebSearchToolOptions{
|
|
Provider: "totally_unknown",
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("NewWebSearchTool() error: %v", err)
|
|
}
|
|
if _, ok := tool.provider.(*SogouSearchProvider); !ok {
|
|
t.Fatalf("expected SogouSearchProvider after fallback, got %T", tool.provider)
|
|
}
|
|
}
|
|
|
|
func TestResolveWebSearchProviderName_FallsBackFromUnknownProvider(t *testing.T) {
|
|
got, err := ResolveWebSearchProviderName(WebSearchToolOptions{
|
|
Provider: "totally_unknown",
|
|
SogouEnabled: true,
|
|
SogouMaxResults: 5,
|
|
}, "")
|
|
if err != nil {
|
|
t.Fatalf("ResolveWebSearchProviderName() error: %v", err)
|
|
}
|
|
if got != "sogou" {
|
|
t.Fatalf("ResolveWebSearchProviderName() = %q, want sogou", got)
|
|
}
|
|
}
|
|
|
|
type stubSearchProvider struct {
|
|
result string
|
|
calls []string
|
|
}
|
|
|
|
func (p *stubSearchProvider) Search(
|
|
_ context.Context,
|
|
query string,
|
|
_ int,
|
|
_ string,
|
|
) (string, error) {
|
|
p.calls = append(p.calls, query)
|
|
return p.result, nil
|
|
}
|
|
|
|
func TestWebTool_AutoProviderRoutesQueryLanguageBetweenSogouAndDuckDuckGo(t *testing.T) {
|
|
sogouProvider := &stubSearchProvider{result: "via sogou"}
|
|
duckProvider := &stubSearchProvider{result: "via duckduckgo"}
|
|
tool := &WebSearchTool{
|
|
provider: sogouProvider,
|
|
maxResults: 5,
|
|
providerResolver: func(query string) (SearchProvider, int) {
|
|
if prefersDuckDuckGoQuery(query) {
|
|
return duckProvider, 3
|
|
}
|
|
return sogouProvider, 5
|
|
},
|
|
}
|
|
|
|
enResult := tool.Execute(context.Background(), map[string]any{"query": "golang concurrency", "count": 10})
|
|
if enResult.IsError {
|
|
t.Fatalf("english Execute() returned error: %s", enResult.ForLLM)
|
|
}
|
|
if len(duckProvider.calls) != 1 || duckProvider.calls[0] != "golang concurrency" {
|
|
t.Fatalf("english query should use DuckDuckGo provider, calls=%v", duckProvider.calls)
|
|
}
|
|
if len(sogouProvider.calls) != 0 {
|
|
t.Fatalf("english query should not call Sogou provider, calls=%v", sogouProvider.calls)
|
|
}
|
|
|
|
zhResult := tool.Execute(context.Background(), map[string]any{"query": "今天上海天气"})
|
|
if zhResult.IsError {
|
|
t.Fatalf("chinese Execute() returned error: %s", zhResult.ForLLM)
|
|
}
|
|
if len(sogouProvider.calls) != 1 || sogouProvider.calls[0] != "今天上海天气" {
|
|
t.Fatalf("chinese query should use Sogou provider, calls=%v", sogouProvider.calls)
|
|
}
|
|
}
|
|
|
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
|
|
|
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return fn(req)
|
|
}
|