From 3390576eeacb97bdd3da95b156e226ab72ee0929 Mon Sep 17 00:00:00 2001 From: Zenix Date: Wed, 18 Feb 2026 17:30:30 +0900 Subject: [PATCH] Feature/websearch OpenAI (#118) * feature: add web search for codex models * fix: use more elegant way to solve the issue. --- config/config.example.json | 5 +- pkg/config/config.go | 33 ++++--- pkg/config/config_test.go | 39 ++++++++ pkg/migrate/config.go | 12 ++- pkg/providers/codex_provider.go | 37 +++++--- pkg/providers/codex_provider_test.go | 129 +++++++++++++++++++++++++-- pkg/providers/http_provider.go | 14 +-- 7 files changed, 230 insertions(+), 39 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 7cd0ab8c6..37c2bcd81 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -79,7 +79,8 @@ }, "openai": { "api_key": "", - "api_base": "" + "api_base": "", + "web_search": true }, "openrouter": { "api_key": "sk-or-v1-xxx", @@ -144,4 +145,4 @@ "host": "0.0.0.0", "port": 18790 } -} \ No newline at end of file +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 1d34f56f3..92a4a5862 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -167,19 +167,19 @@ type DevicesConfig struct { } type ProvidersConfig struct { - Anthropic ProviderConfig `json:"anthropic"` - OpenAI ProviderConfig `json:"openai"` - OpenRouter ProviderConfig `json:"openrouter"` - Groq ProviderConfig `json:"groq"` - Zhipu ProviderConfig `json:"zhipu"` - VLLM ProviderConfig `json:"vllm"` - Gemini ProviderConfig `json:"gemini"` - Nvidia ProviderConfig `json:"nvidia"` - Ollama ProviderConfig `json:"ollama"` - Moonshot ProviderConfig `json:"moonshot"` - ShengSuanYun ProviderConfig `json:"shengsuanyun"` - DeepSeek ProviderConfig `json:"deepseek"` - GitHubCopilot ProviderConfig `json:"github_copilot"` + Anthropic ProviderConfig `json:"anthropic"` + OpenAI OpenAIProviderConfig `json:"openai"` + OpenRouter ProviderConfig `json:"openrouter"` + Groq ProviderConfig `json:"groq"` + Zhipu ProviderConfig `json:"zhipu"` + VLLM ProviderConfig `json:"vllm"` + Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` + Ollama ProviderConfig `json:"ollama"` + Moonshot ProviderConfig `json:"moonshot"` + ShengSuanYun ProviderConfig `json:"shengsuanyun"` + DeepSeek ProviderConfig `json:"deepseek"` + GitHubCopilot ProviderConfig `json:"github_copilot"` } type ProviderConfig struct { @@ -190,6 +190,11 @@ type ProviderConfig struct { ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc` } +type OpenAIProviderConfig struct { + ProviderConfig + WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"` +} + type GatewayConfig struct { Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` @@ -308,7 +313,7 @@ func DefaultConfig() *Config { }, Providers: ProvidersConfig{ Anthropic: ProviderConfig{}, - OpenAI: ProviderConfig{}, + OpenAI: OpenAIProviderConfig{WebSearch: true}, OpenRouter: ProviderConfig{}, Groq: ProviderConfig{}, Zhipu: ProviderConfig{}, diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index febfd0456..a1f73f0b3 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -204,3 +204,42 @@ func TestConfig_Complete(t *testing.T) { t.Error("Heartbeat should be enabled by default") } } + +func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Providers.OpenAI.WebSearch { + t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true") + } +} + +func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if !cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should remain true when unset in config file") + } +} + +func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should be false when disabled in config file") + } +} diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index 9c1e36359..57032e566 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -108,7 +108,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error case "anthropic": cfg.Providers.Anthropic = pc case "openai": - cfg.Providers.OpenAI = pc + cfg.Providers.OpenAI = config.OpenAIProviderConfig{ + ProviderConfig: pc, + WebSearch: getBoolOrDefault(pMap, "web_search", true), + } case "openrouter": cfg.Providers.OpenRouter = pc case "groq": @@ -363,6 +366,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) { return b, ok } +func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool { + if v, ok := getBool(data, key); ok { + return v + } + return defaultVal +} + func getStringSlice(data map[string]interface{}, key string) []string { v, ok := data[key] if !ok { diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 7617bf716..e3526cfb5 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2" const codexDefaultInstructions = "You are Codex, a coding assistant." type CodexProvider struct { - client *openai.Client - accountID string - tokenSource func() (string, string, error) + client *openai.Client + accountID string + tokenSource func() (string, string, error) + enableWebSearch bool } const defaultCodexInstructions = "You are Codex, a coding assistant." @@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider { } client := openai.NewClient(opts...) return &CodexProvider{ - client: &client, - accountID: accountID, + client: &client, + accountID: accountID, + enableWebSearch: true, } } @@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To }) } - params := buildCodexParams(messages, tools, resolvedModel, options) + params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch) stream := p.client.Responses.NewStreaming(ctx, params, opts...) defer stream.Close() @@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) { return codexDefaultModel, "unsupported model family" } -func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams { var inputItems responses.ResponseInputParam var instructions string @@ -266,8 +268,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, params.Instructions = openai.Opt(defaultCodexInstructions) } - if len(tools) > 0 { - params.Tools = translateToolsForCodex(tools) + if len(tools) > 0 || enableWebSearch { + params.Tools = translateToolsForCodex(tools, enableWebSearch) } return params @@ -297,9 +299,19 @@ func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) return name, "{}", true } -func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { - result := make([]responses.ToolUnionParam, 0, len(tools)) +func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam { + capHint := len(tools) + if enableWebSearch { + capHint++ + } + result := make([]responses.ToolUnionParam, 0, capHint) for _, t := range tools { + if t.Type != "function" { + continue + } + if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") { + continue + } ft := responses.FunctionToolParam{ Name: t.Function.Name, Parameters: t.Function.Parameters, @@ -310,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { } result = append(result, responses.ToolUnionParam{OfFunction: &ft}) } + if enableWebSearch { + result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)) + } return result } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 8406760c4..92e276165 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ "max_tokens": 2048, "temperature": 0.7, - }) + }, true) if params.Model != "gpt-4o" { t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") } @@ -39,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { {Role: "system", Content: "You are helpful"}, {Role: "user", Content: "Hi"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true) if !params.Instructions.Valid() { t.Fatal("Instructions should be set") } @@ -59,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) { }, {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false) if params.Input.OfInputItemList == nil { t.Fatal("Input.OfInputItemList should not be nil") } @@ -87,7 +87,7 @@ func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) { {Role: "tool", Content: "ok", ToolCallID: "call_1"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false) if params.Input.OfInputItemList == nil { t.Fatal("Input.OfInputItemList should not be nil") } @@ -123,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) { }, }, } - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false) if len(params.Tools) != 1 { t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) } @@ -136,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) { } func TestBuildCodexParams_StoreIsFalse(t *testing.T) { - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false) if !params.Store.Valid() || params.Store.Or(true) != false { t.Error("Store should be explicitly set to false") } } +func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfWebSearch == nil { + t.Fatal("Tool should include built-in web_search") + } + if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch { + t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch) + } +} + +func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "web_search", + Description: "local web search", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }, + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "read_file", + Description: "read file", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }, + } + + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true) + if len(params.Tools) != 2 { + t.Fatalf("len(Tools) = %d, want 2", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" { + t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0]) + } + if params.Tools[1].OfWebSearch == nil { + t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1]) + } +} + func TestParseCodexResponse_TextOutput(t *testing.T) { respJSON := `{ "id": "resp_test", @@ -260,6 +309,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest) return } + toolsAny, ok := reqBody["tools"].([]interface{}) + if !ok || len(toolsAny) != 1 { + http.Error(w, "missing default web search tool", http.StatusBadRequest) + return + } + toolObj, ok := toolsAny[0].(map[string]interface{}) + if !ok || toolObj["type"] != "web_search" { + http.Error(w, "expected web_search tool", http.StatusBadRequest) + return + } resp := map[string]interface{}{ "id": "resp_test", @@ -307,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["tools"]; ok { + http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 4, + "output_tokens": 3, + "total_tokens": 7, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.enableWebSearch = false + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/responses" { diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 4cf2c6db2..946aa29d2 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -208,7 +208,7 @@ func createClaudeAuthProvider() (LLMProvider, error) { return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil } -func createCodexAuthProvider() (LLMProvider, error) { +func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) { cred, err := auth.GetCredential("openai") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) @@ -216,7 +216,9 @@ func createCodexAuthProvider() (LLMProvider, error) { if cred == nil { return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") } - return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil + p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()) + p.enableWebSearch = enableWebSearch + return p, nil } func CreateProvider(cfg *config.Config) (LLMProvider, error) { @@ -241,10 +243,12 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case "openai", "gpt": if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil + c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()) + c.enableWebSearch = cfg.Providers.OpenAI.WebSearch + return c, nil } if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() + return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch) } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase @@ -369,7 +373,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() + return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch) } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase