web fetch: cap response body size with a configurable byte limit

Adds tools.web.fetch_limit_bytes (env PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES,
default 10 * 1024 * 1024) and threads it through NewWebFetchTool and
NewWebFetchToolWithProxy. A non-positive limit falls back to
10 * 1024 * 1024 inside NewWebFetchToolWithProxy. Execute wraps the response
body in http.MaxBytesReader and returns the error text
"failed to read response: size exceeded %d bytes limit" when the cap is hit.

NewWebFetchTool now returns (*WebFetchTool, error). The tests abort with
t.Fatalf when construction fails, so a constructor error is reported as a
test failure instead of a nil-pointer panic on the following line; the test
file therefore does not import pkg/logger.

diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index a72f95bb1..00b0f096a 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -119,7 +119,7 @@ func registerSharedTools(
 		} else if searchTool != nil {
 			agent.Tools.Register(searchTool)
 		}
-		fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)
+		fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
 		if err != nil {
 			logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
 		} else {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index d84772d2b..55d0cfb2c 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -523,7 +523,8 @@ type WebToolsConfig struct {
 	Perplexity PerplexityConfig `json:"perplexity"`
 	// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
 	// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
-	Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+	Proxy           string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+	FetchLimitBytes int64  `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
 }
 
 type CronToolsConfig struct {
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 9313623d1..44f4de7e9 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -315,7 +315,8 @@ func DefaultConfig() *Config {
 				Interval: 5,
 			},
 			Web: WebToolsConfig{
-				Proxy: "",
+				Proxy:           "",
+				FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
 				Brave: BraveConfig{
 					Enabled: false,
 					APIKey:  "",
diff --git a/pkg/tools/web.go b/pkg/tools/web.go
index bf9144f18..10498126b 100644
--- a/pkg/tools/web.go
+++ b/pkg/tools/web.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -519,18 +520,18 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
 }
 
 type WebFetchTool struct {
-	maxChars int
-	proxy    string
-	client   *http.Client
+	maxChars        int
+	proxy           string
+	client          *http.Client
+	fetchLimitBytes int64
 }
 
-func NewWebFetchTool(maxChars int) *WebFetchTool {
+func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
 	// createHTTPClient cannot fail with an empty proxy string.
-	tool, _ := NewWebFetchToolWithProxy(maxChars, "")
-	return tool
+	return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
 }
 
-func NewWebFetchToolWithProxy(maxChars int, proxy string) (*WebFetchTool, error) {
+func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
 	if maxChars <= 0 {
 		maxChars = defaultMaxChars
 	}
@@ -544,10 +545,14 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string) (*WebFetchTool, error)
 		}
 		return nil
 	}
+	if fetchLimitBytes <= 0 {
+		fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
+	}
 	return &WebFetchTool{
-		maxChars: maxChars,
-		proxy:    proxy,
-		client:   client,
+		maxChars:        maxChars,
+		proxy:           proxy,
+		client:          client,
+		fetchLimitBytes: fetchLimitBytes,
 	}, nil
 }
 
@@ -614,10 +619,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
 	if err != nil {
 		return ErrorResult(fmt.Sprintf("request failed: %v", err))
 	}
+
+	resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
+
 	defer resp.Body.Close()
 
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
+		var maxBytesErr *http.MaxBytesError
+		if errors.As(err, &maxBytesErr) {
+			return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
+		}
 		return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
 	}
 
diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go
--- a/pkg/tools/web_test.go
+++ b/pkg/tools/web_test.go
@@ -1,15 +1,19 @@
 package tools
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
+	"fmt"
 	"net/http"
 	"net/http/httptest"
 	"strings"
 	"testing"
 	"time"
 )
 
+const testFetchLimit = int64(10 * 1024 * 1024)
+
 // TestWebTool_WebFetch_Success verifies successful URL fetching
 func TestWebTool_WebFetch_Success(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -19,7 +23,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
 	}))
 	defer server.Close()
 
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": server.URL,
@@ -55,7 +63,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
 	}))
 	defer server.Close()
 
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": server.URL,
@@ -76,7 +88,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
 
 // TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
 func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": "not-a-valid-url",
@@ -97,7 +113,11 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
 
 // TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
 func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": "ftp://example.com/file.txt",
@@ -118,7 +138,11 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
 
 // TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
 func TestWebTool_WebFetch_MissingURL(t *testing.T) {
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{}
 
@@ -146,7 +170,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
 	}))
 	defer server.Close()
 
-	tool := NewWebFetchTool(1000) // Limit to 1000 chars
+	tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": server.URL,
@@ -174,6 +202,49 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
 	}
 }
 
+func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
+	// Create a mock HTTP server
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/html")
+		w.WriteHeader(http.StatusOK)
+
+		// Generate a payload intentionally larger than our limit.
+		// Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
+		largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
+
+		w.Write(largeData)
+	}))
+	// Ensure the server is shut down at the end of the test
+	defer ts.Close()
+
+	// Initialize the tool
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
+	// Prepare the arguments pointing to the URL of our local mock server
+	args := map[string]any{
+		"url": ts.URL,
+	}
+
+	// Execute the tool
+	ctx := context.Background()
+	result := tool.Execute(ctx, args)
+
+	// Assuming ErrorResult sets the ForLLM field with the error text.
+	if result == nil {
+		t.Fatal("expected a ToolResult, got nil")
+	}
+
+	// Search for the exact error string we set earlier in the Execute method
+	expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
+
+	if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
+		t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
+	}
+}
+
 // TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
 func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
 	tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
@@ -224,7 +295,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
 	}))
 	defer server.Close()
 
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": server.URL,
@@ -325,7 +400,11 @@ func TestWebFetchTool_extractText(t *testing.T) {
 
 // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
 func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
-	tool := NewWebFetchTool(50000)
+	tool, err := NewWebFetchTool(50000, testFetchLimit)
+	if err != nil {
+		t.Fatalf("Failed to create web fetch tool: %v", err)
+	}
+
 	ctx := context.Background()
 	args := map[string]any{
 		"url": "https://",
@@ -447,7 +526,7 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
 }
 
 func TestNewWebFetchToolWithProxy(t *testing.T) {
-	tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
+	tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890", testFetchLimit)
 	if err != nil {
 		t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
 	}
@@ -458,7 +537,7 @@ func TestNewWebFetchToolWithProxy(t *testing.T) {
 		t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
 	}
 
-	tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
+	tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890", testFetchLimit)
 	if err != nil {
 		t.Fatalf("NewWebFetchToolWithProxy() error: %v", err)
 	}