mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Feature/websearch OpenAI (#118)
* feature: add web search for codex models * fix: use more elegant way to solve the issue.
This commit is contained in:
@@ -79,7 +79,8 @@
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
"api_base": "",
|
||||
"web_search": true
|
||||
},
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx",
|
||||
@@ -144,4 +145,4 @@
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+19
-14
@@ -167,19 +167,19 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI OpenAIProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -190,6 +190,11 @@ type ProviderConfig struct {
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type OpenAIProviderConfig struct {
|
||||
ProviderConfig
|
||||
WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
@@ -308,7 +313,7 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenAI: OpenAIProviderConfig{WebSearch: true},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
|
||||
@@ -204,3 +204,42 @@ func TestConfig_Complete(t *testing.T) {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("OpenAI codex web search should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("OpenAI codex web search should be false when disabled in config file")
|
||||
}
|
||||
}
|
||||
|
||||
+11
-1
@@ -108,7 +108,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
|
||||
case "anthropic":
|
||||
cfg.Providers.Anthropic = pc
|
||||
case "openai":
|
||||
cfg.Providers.OpenAI = pc
|
||||
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
|
||||
ProviderConfig: pc,
|
||||
WebSearch: getBoolOrDefault(pMap, "web_search", true),
|
||||
}
|
||||
case "openrouter":
|
||||
cfg.Providers.OpenRouter = pc
|
||||
case "groq":
|
||||
@@ -363,6 +366,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) {
|
||||
return b, ok
|
||||
}
|
||||
|
||||
func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool {
|
||||
if v, ok := getBool(data, key); ok {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func getStringSlice(data map[string]interface{}, key string) []string {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
|
||||
@@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
@@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
}
|
||||
client := openai.NewClient(opts...)
|
||||
return &CodexProvider{
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
enableWebSearch: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
|
||||
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
@@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) {
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
|
||||
var inputItems responses.ResponseInputParam
|
||||
var instructions string
|
||||
|
||||
@@ -266,8 +268,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
if len(tools) > 0 || enableWebSearch {
|
||||
params.Tools = translateToolsForCodex(tools, enableWebSearch)
|
||||
}
|
||||
|
||||
return params
|
||||
@@ -297,9 +299,19 @@ func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool)
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
|
||||
capHint := len(tools)
|
||||
if enableWebSearch {
|
||||
capHint++
|
||||
}
|
||||
result := make([]responses.ToolUnionParam, 0, capHint)
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
@@ -310,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
if enableWebSearch {
|
||||
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}, true)
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
|
||||
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
@@ -136,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||
t.Error("Store should be explicitly set to false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfWebSearch == nil {
|
||||
t.Fatal("Tool should include built-in web_search")
|
||||
}
|
||||
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
|
||||
t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "local web search",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: "read file",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 2 {
|
||||
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" {
|
||||
t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0])
|
||||
}
|
||||
if params.Tools[1].OfWebSearch == nil {
|
||||
t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||
respJSON := `{
|
||||
"id": "resp_test",
|
||||
@@ -260,6 +309,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolsAny, ok := reqBody["tools"].([]interface{})
|
||||
if !ok || len(toolsAny) != 1 {
|
||||
http.Error(w, "missing default web search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolObj, ok := toolsAny[0].(map[string]interface{})
|
||||
if !ok || toolObj["type"] != "web_search" {
|
||||
http.Error(w, "expected web_search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
@@ -307,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["tools"]; ok {
|
||||
http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 4,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 7,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.enableWebSearch = false
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
|
||||
@@ -208,7 +208,7 @@ func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
@@ -216,7 +216,9 @@ func createCodexAuthProvider() (LLMProvider, error) {
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
|
||||
p.enableWebSearch = enableWebSearch
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
@@ -241,10 +243,12 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
|
||||
c.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
return c, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch)
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
@@ -369,7 +373,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch)
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
|
||||
Reference in New Issue
Block a user