mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'upstream-main' into feat/subturn-poc
This commit is contained in:
@@ -180,6 +180,10 @@ func buildParams(
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Skip tool calls with empty names to avoid API errors
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
args := tc.Arguments
|
||||
if args == nil && tc.Function != nil && tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
|
||||
@@ -50,10 +50,18 @@ func (p *ClaudeCliProvider) Chat(
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if stderrStr := stderr.String(); stderrStr != "" {
|
||||
stderrStr := strings.TrimSpace(stderr.String())
|
||||
stdoutStr := strings.TrimSpace(stdout.String())
|
||||
switch {
|
||||
case stderrStr != "" && stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\nstderr: %s\nstdout: %s", err, stderrStr, stdoutStr)
|
||||
case stderrStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
|
||||
case stdoutStr != "":
|
||||
return nil, fmt.Errorf("claude cli error: %w\noutput: %s", err, stdoutStr)
|
||||
default:
|
||||
return nil, fmt.Errorf("claude cli error: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("claude cli error: %w", err)
|
||||
}
|
||||
|
||||
return p.parseClaudeCliResponse(stdout.String())
|
||||
|
||||
@@ -8,6 +8,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CodexHomeEnvVar is the environment variable that overrides the Codex CLI
|
||||
// home directory when resolving the codex auth.json credentials file.
|
||||
// Default: ~/.codex
|
||||
const CodexHomeEnvVar = "CODEX_HOME"
|
||||
|
||||
// CodexCliAuth represents the ~/.codex/auth.json file structure.
|
||||
type CodexCliAuth struct {
|
||||
Tokens struct {
|
||||
@@ -69,7 +74,7 @@ func CreateCodexCliTokenSource() func() (string, string, error) {
|
||||
}
|
||||
|
||||
func resolveCodexAuthPath() (string, error) {
|
||||
codexHome := os.Getenv("CODEX_HOME")
|
||||
codexHome := os.Getenv(CodexHomeEnvVar)
|
||||
if codexHome == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
|
||||
@@ -95,7 +95,10 @@ func (p *CodexProvider) Chat(
|
||||
)
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
|
||||
// Respect tools.web.prefer_native: only inject native search when the agent
|
||||
// loop requested it (options["native_search"]), so prefer_native: false
|
||||
useNativeSearch := p.enableWebSearch && (options["native_search"] == true)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, useNativeSearch)
|
||||
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
@@ -157,6 +160,10 @@ func (p *CodexProvider) GetDefaultModel() string {
|
||||
return codexDefaultModel
|
||||
}
|
||||
|
||||
func (p *CodexProvider) SupportsNativeSearch() bool {
|
||||
return p.enableWebSearch
|
||||
}
|
||||
|
||||
func resolveCodexModel(model string) (string, string) {
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
if m == "" {
|
||||
|
||||
@@ -355,7 +355,9 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024})
|
||||
// Pass native_search so Codex injects built-in web search (mirrors agent loop when prefer_native is true).
|
||||
opts := map[string]any{"max_tokens": 1024, "native_search": true}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", opts)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
@@ -55,8 +55,8 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, litellm, anthropic, anthropic-messages, antigravity,
|
||||
// claude-cli, codex-cli, github-copilot
|
||||
// Supported protocols: openai, litellm, novita, anthropic, anthropic-messages,
|
||||
// antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -116,7 +116,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
"minimax", "longcat", "modelscope":
|
||||
"minimax", "longcat", "modelscope", "novita":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
if cfg.APIKey == "" && cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
|
||||
@@ -219,6 +219,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://openrouter.ai/api/v1"
|
||||
case "litellm":
|
||||
return "http://localhost:4000/v1"
|
||||
case "novita":
|
||||
return "https://api.novita.ai/openai"
|
||||
case "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
case "zhipu":
|
||||
|
||||
@@ -112,6 +112,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
}{
|
||||
{"openai", "openai"},
|
||||
{"groq", "groq"},
|
||||
{"novita", "novita"},
|
||||
{"openrouter", "openrouter"},
|
||||
{"cerebras", "cerebras"},
|
||||
{"vivgrid", "vivgrid"},
|
||||
@@ -222,6 +223,34 @@ func TestGetDefaultAPIBase_ModelScope(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Novita(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-novita",
|
||||
Model: "novita/deepseek/deepseek-v3.2",
|
||||
APIKey: "test-key",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "deepseek/deepseek-v3.2" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "deepseek/deepseek-v3.2")
|
||||
}
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("expected *HTTPProvider, got %T", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_Novita(t *testing.T) {
|
||||
if got := getDefaultAPIBase("novita"); got != "https://api.novita.ai/openai" {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "novita", got, "https://api.novita.ai/openai")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
|
||||
@@ -55,3 +55,7 @@ func (p *HTTPProvider) Chat(
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) SupportsNativeSearch() bool {
|
||||
return p.delegate.SupportsNativeSearch()
|
||||
}
|
||||
|
||||
@@ -103,8 +103,11 @@ func (p *Provider) Chat(
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
// When fallback uses a different provider (e.g. DeepSeek), that provider must not inject web_search_preview.
|
||||
nativeSearch, _ := options["native_search"].(bool)
|
||||
nativeSearch = nativeSearch && isNativeSearchHost(p.apiBase)
|
||||
if len(tools) > 0 || nativeSearch {
|
||||
requestBody["tools"] = buildToolsList(tools, nativeSearch)
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
@@ -188,13 +191,40 @@ func normalizeModel(model, apiBase string) string {
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google",
|
||||
"openrouter", "zhipu", "mistral", "vivgrid", "minimax":
|
||||
"openrouter", "zhipu", "mistral", "vivgrid", "minimax", "novita":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsList(tools []ToolDefinition, nativeSearch bool) []any {
|
||||
result := make([]any, 0, len(tools)+1)
|
||||
for _, t := range tools {
|
||||
if nativeSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
result = append(result, t)
|
||||
}
|
||||
if nativeSearch {
|
||||
result = append(result, map[string]any{"type": "web_search_preview"})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *Provider) SupportsNativeSearch() bool {
|
||||
return isNativeSearchHost(p.apiBase)
|
||||
}
|
||||
|
||||
func isNativeSearchHost(apiBase string) bool {
|
||||
u, err := url.Parse(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host := u.Hostname()
|
||||
return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com")
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
|
||||
@@ -432,7 +432,28 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
func TestProviderChat_StripsGroqOllamaDeepseekVivgridNovitaPrefixes(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -463,31 +484,25 @@ func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) {
|
||||
input: "vivgrid/auto",
|
||||
wantModel: "auto",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix deepseek model",
|
||||
input: "novita/deepseek/deepseek-v3.2",
|
||||
wantModel: "deepseek/deepseek-v3.2",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix zai model",
|
||||
input: "novita/zai-org/glm-5",
|
||||
wantModel: "zai-org/glm-5",
|
||||
},
|
||||
{
|
||||
name: "strips novita prefix minimax model",
|
||||
input: "novita/minimax/minimax-m2.5",
|
||||
wantModel: "minimax/minimax-m2.5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
@@ -573,6 +588,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" {
|
||||
t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto")
|
||||
}
|
||||
if got := normalizeModel(
|
||||
"novita/deepseek/deepseek-v3.2",
|
||||
"https://api.novita.ai/openai",
|
||||
); got != "deepseek/deepseek-v3.2" {
|
||||
t.Fatalf("normalizeModel(novita) = %q, want %q", got, "deepseek/deepseek-v3.2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_RequestTimeoutDefault(t *testing.T) {
|
||||
@@ -824,6 +845,232 @@ func TestSupportsPromptCacheKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NativeSearchAddsWebSearchPreview(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, true)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
wsEntry, ok := result[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("web search entry is %T, want map[string]any", result[1])
|
||||
}
|
||||
if wsEntry["type"] != "web_search_preview" {
|
||||
t.Fatalf("type = %v, want web_search_preview", wsEntry["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NativeSearchFiltersClientWebSearch(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, true)
|
||||
for _, entry := range result {
|
||||
if td, ok := entry.(ToolDefinition); ok && strings.EqualFold(td.Function.Name, "web_search") {
|
||||
t.Fatal("client-side web_search should be filtered out when native search is enabled")
|
||||
}
|
||||
}
|
||||
if len(result) != 2 { // read_file + web_search_preview
|
||||
t.Fatalf("len(result) = %d, want 2 (read_file + web_search_preview)", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolsList_NoNativeSearchPassesThrough(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
result := buildToolsList(tools, false)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("len(result) = %d, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNativeSearchHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
apiBase string
|
||||
want bool
|
||||
}{
|
||||
{"https://api.openai.com/v1", true},
|
||||
{"https://myresource.openai.azure.com/openai/deployments/gpt-4", true},
|
||||
{"https://api.mistral.ai/v1", false},
|
||||
{"https://api.deepseek.com/v1", false},
|
||||
{"https://api.groq.com/openai/v1", false},
|
||||
{"http://localhost:11434/v1", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isNativeSearchHost(tt.apiBase); got != tt.want {
|
||||
t.Errorf("isNativeSearchHost(%q) = %v, want %v", tt.apiBase, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsNativeSearch_OpenAI(t *testing.T) {
|
||||
p := NewProvider("key", "https://api.openai.com/v1", "")
|
||||
if !p.SupportsNativeSearch() {
|
||||
t.Fatal("OpenAI provider should support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsNativeSearch_NonOpenAI(t *testing.T) {
|
||||
p := NewProvider("key", "https://api.deepseek.com/v1", "")
|
||||
if p.SupportsNativeSearch() {
|
||||
t.Fatal("DeepSeek provider should not support native search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_NativeSearchToolInjected(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
p.apiBase = "https://api.openai.com/v1"
|
||||
p.httpClient = &http.Client{
|
||||
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
|
||||
r.URL, _ = url.Parse(server.URL + r.URL.Path)
|
||||
return http.DefaultTransport.RoundTrip(r)
|
||||
}),
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "read_file", Description: "read"}},
|
||||
}
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
tools,
|
||||
"gpt-5.4",
|
||||
map[string]any{"native_search": true},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
toolsRaw, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("tools is %T, want []any", requestBody["tools"])
|
||||
}
|
||||
if len(toolsRaw) != 2 {
|
||||
t.Fatalf("len(tools) = %d, want 2 (read_file + web_search_preview)", len(toolsRaw))
|
||||
}
|
||||
|
||||
lastTool, ok := toolsRaw[1].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("last tool is %T, want map[string]any", toolsRaw[1])
|
||||
}
|
||||
if lastTool["type"] != "web_search_preview" {
|
||||
t.Fatalf("last tool type = %v, want web_search_preview", lastTool["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_NativeSearchNotInjectedWithoutOption(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
tools := []ToolDefinition{
|
||||
{Type: "function", Function: ToolFunctionDefinition{Name: "web_search", Description: "search"}},
|
||||
}
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
tools,
|
||||
"gpt-5.4",
|
||||
map[string]any{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
toolsRaw, ok := requestBody["tools"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("tools is %T, want []any", requestBody["tools"])
|
||||
}
|
||||
if len(toolsRaw) != 1 {
|
||||
t.Fatalf("len(tools) = %d, want 1 (web_search only)", len(toolsRaw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderChat_NativeSearchIgnoredOnNonOpenAI verifies that when native_search
|
||||
// is true in options but the provider's apiBase is not OpenAI (e.g. fallback to DeepSeek),
|
||||
// we do not inject web_search_preview to avoid API errors.
|
||||
func TestProviderChat_NativeSearchIgnoredOnNonOpenAI(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Use server.URL so host is not api.openai.com — simulates DeepSeek/other provider
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"deepseek-chat",
|
||||
map[string]any{"native_search": true},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// Should not have tools at all (no tools passed, and we must not add web_search_preview)
|
||||
if toolsRaw, ok := requestBody["tools"]; ok {
|
||||
t.Fatalf("tools should be omitted for non-OpenAI when only native_search was requested, got %v", toolsRaw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
|
||||
@@ -44,6 +44,15 @@ type ThinkingCapable interface {
|
||||
SupportsThinking() bool
|
||||
}
|
||||
|
||||
// NativeSearchCapable is an optional interface for providers that support
|
||||
// built-in web search during LLM inference (e.g. OpenAI web_search_preview,
|
||||
// xAI Grok search). When the active provider implements this interface and
|
||||
// returns true, the agent loop can hide the client-side web_search tool to
|
||||
// avoid duplicate search surfaces and use the provider's native search instead.
|
||||
type NativeSearchCapable interface {
|
||||
SupportsNativeSearch() bool
|
||||
}
|
||||
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
|
||||
Reference in New Issue
Block a user