diff --git a/config/config.example.json b/config/config.example.json index 9a92ff0c2..350f085d0 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -313,6 +313,7 @@ "allow_write_paths": null, "web": { "enabled": true, + "prefer_native": true, "fetch_limit_bytes": 10485760, "format": "plaintext", "brave": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 98ef47a99..86994c360 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1037,6 +1037,19 @@ func (al *AgentLoop) runLLMIteration( // Build tool definitions providerToolDefs := agent.Tools.ToProviderDefs() + // Determine whether the provider's native web search should replace + // the client-side web_search tool for this request. Only enable when web + // search is actually enabled and registered (so users who disabled web + // access do not get provider-side search or billing). + _, hasWebSearch := agent.Tools.Get("web_search") + useNativeSearch := al.cfg.Tools.Web.PreferNative && + isNativeSearchProvider(agent.Provider) && + hasWebSearch + + if useNativeSearch { + providerToolDefs = filterClientWebSearch(providerToolDefs) + } + // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]any{ @@ -1045,6 +1058,7 @@ func (al *AgentLoop) runLLMIteration( "model": activeModel, "messages_count": len(messages), "tools_count": len(providerToolDefs), + "native_search": useNativeSearch, "max_tokens": agent.MaxTokens, "temperature": agent.Temperature, "system_prompt_len": len(messages[0].Content), @@ -1067,6 +1081,9 @@ func (al *AgentLoop) runLLMIteration( "temperature": agent.Temperature, "prompt_cache_key": agent.ID, } + if useNativeSearch { + llmOpts["native_search"] = true + } // parseThinkingLevel guarantees ThinkingOff for empty/unknown values, // so checking != ThinkingOff is sufficient. 
if agent.ThinkingLevel != ThinkingOff { @@ -1976,6 +1993,28 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { return &routing.RoutePeer{Kind: parentKind, ID: parentID} } +// isNativeSearchProvider reports whether the given LLM provider implements +// NativeSearchCapable and returns true for SupportsNativeSearch. +func isNativeSearchProvider(p providers.LLMProvider) bool { + if ns, ok := p.(providers.NativeSearchCapable); ok { + return ns.SupportsNativeSearch() + } + return false +} + +// filterClientWebSearch returns a copy of tools with the client-side +// web_search tool removed. Used when native provider search is preferred. +func filterClientWebSearch(tools []providers.ToolDefinition) []providers.ToolDefinition { + result := make([]providers.ToolDefinition, 0, len(tools)) + for _, t := range tools { + if strings.EqualFold(t.Function.Name, "web_search") { + continue + } + result = append(result, t) + } + return result +} + // Helper to extract provider from registry for cleanup func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) { if registry == nil { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 25ee6ab4d..8432ccac4 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1426,3 +1426,84 @@ func TestResolveMediaRefs_MixedImageAndFile(t *testing.T) { t.Fatalf("expected content %q, got %q", expectedContent, result[0].Content) } } + +// --- Native search helper tests --- + +type nativeSearchProvider struct { + supported bool +} + +func (p *nativeSearchProvider) Chat( + ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition, + model string, opts map[string]any, +) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "ok"}, nil +} + +func (p *nativeSearchProvider) GetDefaultModel() string { return "test-model" } + +func (p *nativeSearchProvider) SupportsNativeSearch() bool { return p.supported } + +type plainProvider struct{} + +func (p 
*plainProvider) Chat( + ctx context.Context, msgs []providers.Message, tools []providers.ToolDefinition, + model string, opts map[string]any, +) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "ok"}, nil +} + +func (p *plainProvider) GetDefaultModel() string { return "test-model" } + +func TestIsNativeSearchProvider_Supported(t *testing.T) { + if !isNativeSearchProvider(&nativeSearchProvider{supported: true}) { + t.Fatal("expected true for provider that supports native search") + } +} + +func TestIsNativeSearchProvider_NotSupported(t *testing.T) { + if isNativeSearchProvider(&nativeSearchProvider{supported: false}) { + t.Fatal("expected false for provider that does not support native search") + } +} + +func TestIsNativeSearchProvider_NoInterface(t *testing.T) { + if isNativeSearchProvider(&plainProvider{}) { + t.Fatal("expected false for provider that does not implement NativeSearchCapable") + } +} + +func TestFilterClientWebSearch_RemovesWebSearch(t *testing.T) { + defs := []providers.ToolDefinition{ + {Type: "function", Function: providers.ToolFunctionDefinition{Name: "web_search"}}, + {Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}}, + {Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}}, + } + result := filterClientWebSearch(defs) + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + for _, td := range result { + if td.Function.Name == "web_search" { + t.Fatal("web_search should be filtered out") + } + } +} + +func TestFilterClientWebSearch_NoWebSearch(t *testing.T) { + defs := []providers.ToolDefinition{ + {Type: "function", Function: providers.ToolFunctionDefinition{Name: "read_file"}}, + {Type: "function", Function: providers.ToolFunctionDefinition{Name: "exec"}}, + } + result := filterClientWebSearch(defs) + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } +} + +func TestFilterClientWebSearch_EmptyInput(t *testing.T) { 
+ result := filterClientWebSearch(nil) + if len(result) != 0 { + t.Fatalf("len(result) = %d, want 0", len(result)) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 7a47fccae..49fb3679f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -693,6 +693,12 @@ type WebToolsConfig struct { Perplexity PerplexityConfig ` json:"perplexity"` SearXNG SearXNGConfig ` json:"searxng"` GLMSearch GLMSearchConfig ` json:"glm_search"` + // PreferNative controls whether to use provider-native web search when + // the active LLM supports it (e.g. OpenAI web_search_preview). When true, + // the client-side web_search tool is hidden to avoid duplicate search surfaces, + // and the provider's built-in search is used instead. Falls back to client-side + // search when the provider does not support native search. + PreferNative bool `json:"prefer_native" env:"PICOCLAW_TOOLS_WEB_PREFER_NATIVE"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. 
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index f4f8979e1..82a845471 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -401,6 +401,45 @@ func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) { } } +func TestDefaultConfig_WebPreferNativeEnabled(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Tools.Web.PreferNative { + t.Fatal("DefaultConfig().Tools.Web.PreferNative should be true") + } +} + +func TestLoadConfig_WebPreferNativeDefaultsTrueWhenUnset(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"tools":{"web":{"enabled":true}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if !cfg.Tools.Web.PreferNative { + t.Fatal("PreferNative should remain true when unset in config file") + } +} + +func TestLoadConfig_WebPreferNativeCanBeDisabled(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"tools":{"web":{"prefer_native":false}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Tools.Web.PreferNative { + t.Fatal("PreferNative should be false when disabled in config file") + } +} + func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) { cfg := DefaultConfig() if !cfg.Tools.Exec.AllowRemote { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index eebb1dce3..9e8668779 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -411,6 +411,7 @@ func DefaultConfig() *Config { ToolConfig: ToolConfig{ Enabled: true, }, + PreferNative: true, Proxy: "", FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default 
Format: "plaintext", diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index cf5c2d876..4a6d61a4b 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -95,7 +95,10 @@ func (p *CodexProvider) Chat( ) } - params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch) + // Respect tools.web.prefer_native: only inject native search when the agent + // loop requested it (options["native_search"]), so prefer_native: false is honored. + useNativeSearch := p.enableWebSearch && (options["native_search"] == true) + params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch) stream := p.client.Responses.NewStreaming(ctx, params, opts...) defer stream.Close() @@ -157,6 +160,10 @@ func (p *CodexProvider) GetDefaultModel() string { return codexDefaultModel } +func (p *CodexProvider) SupportsNativeSearch() bool { + return p.enableWebSearch +} + func resolveCodexModel(model string) (string, string) { m := strings.ToLower(strings.TrimSpace(model)) if m == "" { diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index dd5ad2637..3a0da5e3b 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -355,7 +355,9 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") messages := []Message{{Role: "user", Content: "Hello"}} - resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024}) + // Pass native_search so Codex injects built-in web search (mirrors agent loop when prefer_native is true). 
+ opts := map[string]any{"max_tokens": 1024, "native_search": true} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", opts) if err != nil { t.Fatalf("Chat() error: %v", err) } diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 5c328f418..4d823630e 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -55,3 +55,7 @@ func (p *HTTPProvider) Chat( func (p *HTTPProvider) GetDefaultModel() string { return "" } + +func (p *HTTPProvider) SupportsNativeSearch() bool { + return p.delegate.SupportsNativeSearch() +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index fb2abaa5c..261f2d482 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -103,8 +103,11 @@ func (p *Provider) Chat( "messages": common.SerializeMessages(messages), } - if len(tools) > 0 { - requestBody["tools"] = tools + // When fallback uses a different provider (e.g. DeepSeek), that provider must not inject web_search_preview. 
+ nativeSearch, _ := options["native_search"].(bool) + nativeSearch = nativeSearch && isNativeSearchHost(p.apiBase) + if len(tools) > 0 || nativeSearch { + requestBody["tools"] = buildToolsList(tools, nativeSearch) requestBody["tool_choice"] = "auto" } @@ -195,6 +198,33 @@ func normalizeModel(model, apiBase string) string { } } +func buildToolsList(tools []ToolDefinition, nativeSearch bool) []any { + result := make([]any, 0, len(tools)+1) + for _, t := range tools { + if nativeSearch && strings.EqualFold(t.Function.Name, "web_search") { + continue + } + result = append(result, t) + } + if nativeSearch { + result = append(result, map[string]any{"type": "web_search_preview"}) + } + return result +} + +func (p *Provider) SupportsNativeSearch() bool { + return isNativeSearchHost(p.apiBase) +} + +func isNativeSearchHost(apiBase string) bool { + u, err := url.Parse(apiBase) + if err != nil { + return false + } + host := u.Hostname() + return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") +} + // supportsPromptCacheKey reports whether the given API base is known to // support the prompt_cache_key request field. Currently only OpenAI's own // API and Azure OpenAI support this. 
All other OpenAI-compatible providers diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index ed9747f9d..a3288a023 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -824,6 +824,232 @@ func TestSupportsPromptCacheKey(t *testing.T) { } } +func TestBuildToolsList_NativeSearchAddsWebSearchPreview(t *testing.T) { + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}}, + } + result := buildToolsList(tools, true) + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + wsEntry, ok := result[1].(map[string]any) + if !ok { + t.Fatalf("web search entry is %T, want map[string]any", result[1]) + } + if wsEntry["type"] != "web_search_preview" { + t.Fatalf("type = %v, want web_search_preview", wsEntry["type"]) + } +} + +func TestBuildToolsList_NativeSearchFiltersClientWebSearch(t *testing.T) { + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}}, + {Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}}, + } + result := buildToolsList(tools, true) + for _, entry := range result { + if td, ok := entry.(ToolDefinition); ok && strings.EqualFold(td.Function.Name, "web_search") { + t.Fatal("client-side web_search should be filtered out when native search is enabled") + } + } + if len(result) != 2 { // read_file + web_search_preview + t.Fatalf("len(result) = %d, want 2 (read_file + web_search_preview)", len(result)) + } +} + +func TestBuildToolsList_NoNativeSearchPassesThrough(t *testing.T) { + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}}, + {Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}}, + } + result := buildToolsList(tools, false) + if 
len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } +} + +func TestIsNativeSearchHost(t *testing.T) { + tests := []struct { + apiBase string + want bool + }{ + {"https://api.openai.com/v1", true}, + {"https://myresource.openai.azure.com/openai/deployments/gpt-4", true}, + {"https://api.mistral.ai/v1", false}, + {"https://api.deepseek.com/v1", false}, + {"https://api.groq.com/openai/v1", false}, + {"http://localhost:11434/v1", false}, + {"", false}, + } + for _, tt := range tests { + if got := isNativeSearchHost(tt.apiBase); got != tt.want { + t.Errorf("isNativeSearchHost(%q) = %v, want %v", tt.apiBase, got, tt.want) + } + } +} + +func TestSupportsNativeSearch_OpenAI(t *testing.T) { + p := NewProvider("key", "https://api.openai.com/v1", "") + if !p.SupportsNativeSearch() { + t.Fatal("OpenAI provider should support native search") + } +} + +func TestSupportsNativeSearch_NonOpenAI(t *testing.T) { + p := NewProvider("key", "https://api.deepseek.com/v1", "") + if p.SupportsNativeSearch() { + t.Fatal("DeepSeek provider should not support native search") + } +} + +func TestProviderChat_NativeSearchToolInjected(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + p.apiBase = "https://api.openai.com/v1" + p.httpClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + r.URL, _ = url.Parse(server.URL + r.URL.Path) + return http.DefaultTransport.RoundTrip(r) + }), + } + tools := 
[]ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}}, + } + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + tools, + "gpt-5.4", + map[string]any{"native_search": true}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + toolsRaw, ok := requestBody["tools"].([]any) + if !ok { + t.Fatalf("tools is %T, want []any", requestBody["tools"]) + } + if len(toolsRaw) != 2 { + t.Fatalf("len(tools) = %d, want 2 (read_file + web_search_preview)", len(toolsRaw)) + } + + lastTool, ok := toolsRaw[1].(map[string]any) + if !ok { + t.Fatalf("last tool is %T, want map[string]any", toolsRaw[1]) + } + if lastTool["type"] != "web_search_preview" { + t.Fatalf("last tool type = %v, want web_search_preview", lastTool["type"]) + } +} + +func TestProviderChat_NativeSearchNotInjectedWithoutOption(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}}, + } + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + tools, + "gpt-5.4", + map[string]any{}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + toolsRaw, ok := requestBody["tools"].([]any) + if !ok { + t.Fatalf("tools is %T, want []any", requestBody["tools"]) + } + if len(toolsRaw) != 1 { + t.Fatalf("len(tools) = %d, want 1 (web_search only)", 
len(toolsRaw)) + } +} + +// TestProviderChat_NativeSearchIgnoredOnNonOpenAI verifies that when native_search +// is true in options but the provider's apiBase is not OpenAI (e.g. fallback to DeepSeek), +// we do not inject web_search_preview to avoid API errors. +func TestProviderChat_NativeSearchIgnoredOnNonOpenAI(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Use server.URL so host is not api.openai.com — simulates DeepSeek/other provider + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "deepseek-chat", + map[string]any{"native_search": true}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // Should not have tools at all (no tools passed, and we must not add web_search_preview) + if toolsRaw, ok := requestBody["tools"]; ok { + t.Fatalf("tools should be omitted for non-OpenAI when only native_search was requested, got %v", toolsRaw) + } +} + func TestSerializeMessages_StripsSystemParts(t *testing.T) { messages := []protocoltypes.Message{ { diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 68bbd1e65..1f28bc4ad 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -44,6 +44,15 @@ type ThinkingCapable interface { SupportsThinking() bool } +// NativeSearchCapable is an optional interface for providers that support +// built-in web search during LLM inference (e.g. OpenAI web_search_preview, +// xAI Grok search). 
When the active provider implements this interface and +// returns true, the agent loop can hide the client-side web_search tool to +// avoid duplicate search surfaces and use the provider's native search instead. +type NativeSearchCapable interface { + SupportsNativeSearch() bool +} + // FailoverReason classifies why an LLM request failed for fallback decisions. type FailoverReason string