From ec86b21d3fc4f2cedff00d0631a02b05ba7723b8 Mon Sep 17 00:00:00 2001 From: yinwm Date: Thu, 19 Feb 2026 09:22:39 +0800 Subject: [PATCH] fix: improve migration logic and reduce code duplication - Preserve user's configured model during config migration (issue #5) - Simplify ExtractProtocol using strings.Cut - Extract NormalizeToolCall to shared utility, removing ~70 lines of duplicate code - Clean up unused fields in providerMigrationConfig struct Co-Authored-By: Claude Opus 4.6 --- pkg/agent/loop.go | 41 +-- pkg/config/migration.go | 487 +++++++++++++++++++----------- pkg/config/migration_test.go | 252 ++++++++++++++-- pkg/providers/factory_provider.go | 10 +- pkg/providers/toolcall_utils.go | 54 ++++ pkg/tools/toolloop.go | 41 +-- 6 files changed, 600 insertions(+), 285 deletions(-) create mode 100644 pkg/providers/toolcall_utils.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index b90c473f1..32e655710 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -607,7 +607,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { - normalizedToolCalls = append(normalizedToolCalls, normalizeProviderToolCall(tc)) + normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) } // Log tool calls @@ -715,45 +715,6 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M return finalContent, iteration, nil } -func normalizeProviderToolCall(tc providers.ToolCall) providers.ToolCall { - normalized := tc - - if normalized.Name == "" && normalized.Function != nil { - normalized.Name = normalized.Function.Name - } - - if normalized.Arguments == nil { - normalized.Arguments = map[string]interface{}{} - } - - if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" { - var parsed map[string]interface{} - if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil { - normalized.Arguments = parsed - } - } - - argsJSON, _ := json.Marshal(normalized.Arguments) - if normalized.Function == nil { - normalized.Function = &providers.FunctionCall{ - Name: normalized.Name, - Arguments: string(argsJSON), - } - } else { - if normalized.Function.Name == "" { - normalized.Function.Name = normalized.Name - } - if normalized.Name == "" { - normalized.Name = normalized.Function.Name - } - if normalized.Function.Arguments == "" { - normalized.Function.Arguments = string(argsJSON) - } - } - - return normalized -} - // updateToolContexts updates the context for tools that need channel/chatID info. func (al *AgentLoop) updateToolContexts(channel, chatID string) { // Use ContextualTool interface instead of type assertions diff --git a/pkg/config/migration.go b/pkg/config/migration.go index d1e165fbb..9b8df07bd 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -5,201 +5,326 @@ package config +import ( + "slices" + "strings" +) + +// providerMigrationConfig defines how to migrate a provider from old config to new format. +type providerMigrationConfig struct { + // providerNames are the possible names used in agents.defaults.provider + providerNames []string + // protocol is the protocol prefix for the model field + protocol string + // buildConfig creates the ModelConfig from ProviderConfig + buildConfig func(p ProvidersConfig) (ModelConfig, bool) +} + // ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. // This enables backward compatibility with existing configurations. +// It preserves the user's configured model from agents.defaults.model when possible. func ConvertProvidersToModelList(cfg *Config) []ModelConfig { if cfg == nil { return nil } + // Get user's configured provider and model + userProvider := strings.ToLower(cfg.Agents.Defaults.Provider) + userModel := cfg.Agents.Defaults.Model + var result []ModelConfig p := cfg.Providers - // OpenAI - if p.OpenAI.APIKey != "" || p.OpenAI.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "openai", - Model: "openai/gpt-4o", - APIKey: p.OpenAI.APIKey, - APIBase: p.OpenAI.APIBase, - Proxy: p.OpenAI.Proxy, - AuthMethod: p.OpenAI.AuthMethod, - }) + // Define migration rules for each provider + migrations := []providerMigrationConfig{ + { + providerNames: []string{"openai", "gpt"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "openai", + Model: "openai/gpt-4o", + APIKey: p.OpenAI.APIKey, + APIBase: p.OpenAI.APIBase, + Proxy: p.OpenAI.Proxy, + AuthMethod: p.OpenAI.AuthMethod, + }, true + }, + }, + { + providerNames: []string{"anthropic", "claude"}, + protocol: "anthropic", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "anthropic", + Model: "anthropic/claude-3-sonnet", + APIKey: p.Anthropic.APIKey, + APIBase: p.Anthropic.APIBase, + Proxy: p.Anthropic.Proxy, + AuthMethod: p.Anthropic.AuthMethod, + }, true + }, + }, + { + providerNames: []string{"openrouter"}, + protocol: "openrouter", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "openrouter", + Model: "openrouter/auto", + APIKey: p.OpenRouter.APIKey, + APIBase: p.OpenRouter.APIBase, + Proxy: p.OpenRouter.Proxy, + }, true + }, + }, + { + providerNames: []string{"groq"}, + protocol: "groq", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Groq.APIKey == "" && p.Groq.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "groq", + Model: "groq/llama-3.1-70b-versatile", + APIKey: p.Groq.APIKey, + APIBase: p.Groq.APIBase, + Proxy: p.Groq.Proxy, + }, true + }, + }, + { + providerNames: []string{"zhipu", "glm"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "zhipu", + Model: "openai/glm-4", + APIKey: p.Zhipu.APIKey, + APIBase: p.Zhipu.APIBase, + Proxy: p.Zhipu.Proxy, + }, true + }, + }, + { + providerNames: []string{"vllm"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.VLLM.APIKey == "" && p.VLLM.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "vllm", + Model: "openai/auto", + APIKey: p.VLLM.APIKey, + APIBase: p.VLLM.APIBase, + Proxy: p.VLLM.Proxy, + }, true + }, + }, + { + providerNames: []string{"gemini", "google"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Gemini.APIKey == "" && p.Gemini.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "gemini", + Model: "openai/gemini-pro", + APIKey: p.Gemini.APIKey, + APIBase: p.Gemini.APIBase, + Proxy: p.Gemini.Proxy, + }, true + }, + }, + { + providerNames: []string{"nvidia"}, + protocol: "nvidia", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Nvidia.APIKey == "" && p.Nvidia.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "nvidia", + Model: "nvidia/meta/llama-3.1-8b-instruct", + APIKey: p.Nvidia.APIKey, + APIBase: p.Nvidia.APIBase, + Proxy: p.Nvidia.Proxy, + }, true + }, + }, + { + providerNames: []string{"ollama"}, + protocol: "ollama", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Ollama.APIKey == "" && p.Ollama.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "ollama", + Model: "ollama/llama3", + APIKey: p.Ollama.APIKey, + APIBase: p.Ollama.APIBase, + Proxy: p.Ollama.Proxy, + }, true + }, + }, + { + providerNames: []string{"moonshot", "kimi"}, + protocol: "moonshot", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Moonshot.APIKey == "" && p.Moonshot.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "moonshot", + Model: "moonshot/kimi", + APIKey: p.Moonshot.APIKey, + APIBase: p.Moonshot.APIBase, + Proxy: p.Moonshot.Proxy, + }, true + }, + }, + { + providerNames: []string{"shengsuanyun"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "shengsuanyun", + Model: "openai/auto", + APIKey: p.ShengSuanYun.APIKey, + APIBase: p.ShengSuanYun.APIBase, + Proxy: p.ShengSuanYun.Proxy, + }, true + }, + }, + { + providerNames: []string{"deepseek"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "deepseek", + Model: "openai/deepseek-chat", + APIKey: p.DeepSeek.APIKey, + APIBase: p.DeepSeek.APIBase, + Proxy: p.DeepSeek.Proxy, + }, true + }, + }, + { + providerNames: []string{"cerebras"}, + protocol: "cerebras", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "cerebras", + Model: "cerebras/llama-3.3-70b", + APIKey: p.Cerebras.APIKey, + APIBase: p.Cerebras.APIBase, + Proxy: p.Cerebras.Proxy, + }, true + }, + }, + { + providerNames: []string{"volcengine", "doubao"}, + protocol: "openai", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "volcengine", + Model: "openai/doubao-pro", + APIKey: p.VolcEngine.APIKey, + APIBase: p.VolcEngine.APIBase, + Proxy: p.VolcEngine.Proxy, + }, true + }, + }, + { + providerNames: []string{"github_copilot", "copilot"}, + protocol: "github-copilot", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.GitHubCopilot.ConnectMode == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "github-copilot", + Model: "github-copilot/gpt-4o", + APIBase: p.GitHubCopilot.APIBase, + ConnectMode: p.GitHubCopilot.ConnectMode, + }, true + }, + }, + { + providerNames: []string{"antigravity"}, + protocol: "antigravity", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Antigravity.APIKey == "" && p.Antigravity.AuthMethod == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "antigravity", + Model: "antigravity/gemini-2.0-flash", + APIKey: p.Antigravity.APIKey, + AuthMethod: p.Antigravity.AuthMethod, + }, true + }, + }, + { + providerNames: []string{"qwen", "tongyi"}, + protocol: "qwen", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Qwen.APIKey == "" && p.Qwen.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "qwen", + Model: "qwen/qwen-max", + APIKey: p.Qwen.APIKey, + APIBase: p.Qwen.APIBase, + Proxy: p.Qwen.Proxy, + }, true + }, + }, } - // Anthropic - if p.Anthropic.APIKey != "" || p.Anthropic.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "anthropic", - Model: "anthropic/claude-3-sonnet", - APIKey: p.Anthropic.APIKey, - APIBase: p.Anthropic.APIBase, - Proxy: p.Anthropic.Proxy, - AuthMethod: p.Anthropic.AuthMethod, - }) - } + // Process each provider migration + for _, m := range migrations { + mc, ok := m.buildConfig(p) + if !ok { + continue + } - // OpenRouter - if p.OpenRouter.APIKey != "" || p.OpenRouter.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "openrouter", - Model: "openrouter/auto", - APIKey: p.OpenRouter.APIKey, - APIBase: p.OpenRouter.APIBase, - Proxy: p.OpenRouter.Proxy, - }) - } + // Check if this is the user's configured provider + if slices.Contains(m.providerNames, userProvider) && userModel != "" { + // Use the user's configured model instead of default + mc.Model = m.protocol + "/" + userModel + } - // Groq - if p.Groq.APIKey != "" || p.Groq.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "groq", - Model: "groq/llama-3.1-70b-versatile", - APIKey: p.Groq.APIKey, - APIBase: p.Groq.APIBase, - Proxy: p.Groq.Proxy, - }) - } - - // Zhipu - if p.Zhipu.APIKey != "" || p.Zhipu.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "zhipu", - Model: "openai/glm-4", - APIKey: p.Zhipu.APIKey, - APIBase: p.Zhipu.APIBase, - Proxy: p.Zhipu.Proxy, - }) - } - - // VLLM - if p.VLLM.APIKey != "" || p.VLLM.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "vllm", - Model: "openai/auto", - APIKey: p.VLLM.APIKey, - APIBase: p.VLLM.APIBase, - Proxy: p.VLLM.Proxy, - }) - } - - // Gemini - if p.Gemini.APIKey != "" || p.Gemini.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "gemini", - Model: "openai/gemini-pro", - APIKey: p.Gemini.APIKey, - APIBase: p.Gemini.APIBase, - Proxy: p.Gemini.Proxy, - }) - } - - // Nvidia - if p.Nvidia.APIKey != "" || p.Nvidia.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "nvidia", - Model: "nvidia/meta/llama-3.1-8b-instruct", - APIKey: p.Nvidia.APIKey, - APIBase: p.Nvidia.APIBase, - Proxy: p.Nvidia.Proxy, - }) - } - - // Ollama - if p.Ollama.APIKey != "" || p.Ollama.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "ollama", - Model: "ollama/llama3", - APIKey: p.Ollama.APIKey, - APIBase: p.Ollama.APIBase, - Proxy: p.Ollama.Proxy, - }) - } - - // Moonshot - if p.Moonshot.APIKey != "" || p.Moonshot.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "moonshot", - Model: "moonshot/kimi", - APIKey: p.Moonshot.APIKey, - APIBase: p.Moonshot.APIBase, - Proxy: p.Moonshot.Proxy, - }) - } - - // ShengSuanYun - if p.ShengSuanYun.APIKey != "" || p.ShengSuanYun.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "shengsuanyun", - Model: "openai/auto", - APIKey: p.ShengSuanYun.APIKey, - APIBase: p.ShengSuanYun.APIBase, - Proxy: p.ShengSuanYun.Proxy, - }) - } - - // DeepSeek - if p.DeepSeek.APIKey != "" || p.DeepSeek.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "deepseek", - Model: "openai/deepseek-chat", - APIKey: p.DeepSeek.APIKey, - APIBase: p.DeepSeek.APIBase, - Proxy: p.DeepSeek.Proxy, - }) - } - - // Cerebras - if p.Cerebras.APIKey != "" || p.Cerebras.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "cerebras", - Model: "cerebras/llama-3.3-70b", - APIKey: p.Cerebras.APIKey, - APIBase: p.Cerebras.APIBase, - Proxy: p.Cerebras.Proxy, - }) - } - - // VolcEngine (Doubao) - if p.VolcEngine.APIKey != "" || p.VolcEngine.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "volcengine", - Model: "openai/doubao-pro", - APIKey: p.VolcEngine.APIKey, - APIBase: p.VolcEngine.APIBase, - Proxy: p.VolcEngine.Proxy, - }) - } - - // GitHub Copilot - if p.GitHubCopilot.APIKey != "" || p.GitHubCopilot.APIBase != "" || p.GitHubCopilot.ConnectMode != "" { - result = append(result, ModelConfig{ - ModelName: "github-copilot", - Model: "github-copilot/gpt-4o", - APIBase: p.GitHubCopilot.APIBase, - ConnectMode: p.GitHubCopilot.ConnectMode, - }) - } - - // Antigravity - if p.Antigravity.APIKey != "" || p.Antigravity.AuthMethod != "" { - result = append(result, ModelConfig{ - ModelName: "antigravity", - Model: "antigravity/gemini-2.0-flash", - APIKey: p.Antigravity.APIKey, - AuthMethod: p.Antigravity.AuthMethod, - }) - } - - // Qwen - if p.Qwen.APIKey != "" || p.Qwen.APIBase != "" { - result = append(result, ModelConfig{ - ModelName: "qwen", - Model: "qwen/qwen-max", - APIKey: p.Qwen.APIKey, - APIBase: p.Qwen.APIBase, - Proxy: p.Qwen.Proxy, - }) + result = append(result, mc) } return result diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index eff16ee7a..5a4f8cc8e 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -6,6 +6,7 @@ package config import ( + "strings" "testing" ) @@ -13,7 +14,7 @@ func TestConvertProvidersToModelList_OpenAI(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ OpenAI: ProviderConfig{ - APIKey: "sk-test-key", + APIKey: "sk-test-key", APIBase: "https://custom.api.com/v1", }, }, @@ -40,7 +41,7 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ Anthropic: ProviderConfig{ - APIKey: "ant-key", + APIKey: "ant-key", APIBase: "https://custom.anthropic.com", }, }, @@ -111,23 +112,23 @@ func TestConvertProvidersToModelList_Nil(t *testing.T) { func TestConvertProvidersToModelList_AllProviders(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ - OpenAI: ProviderConfig{APIKey: "key1"}, - Anthropic: ProviderConfig{APIKey: "key2"}, - OpenRouter: ProviderConfig{APIKey: "key3"}, - Groq: ProviderConfig{APIKey: "key4"}, - Zhipu: ProviderConfig{APIKey: "key5"}, - VLLM: ProviderConfig{APIKey: "key6"}, - Gemini: ProviderConfig{APIKey: "key7"}, - Nvidia: ProviderConfig{APIKey: "key8"}, - Ollama: ProviderConfig{APIKey: "key9"}, - Moonshot: ProviderConfig{APIKey: "key10"}, - ShengSuanYun: ProviderConfig{APIKey: "key11"}, - DeepSeek: ProviderConfig{APIKey: "key12"}, - Cerebras: ProviderConfig{APIKey: "key13"}, - VolcEngine: ProviderConfig{APIKey: "key14"}, + OpenAI: ProviderConfig{APIKey: "key1"}, + Anthropic: ProviderConfig{APIKey: "key2"}, + OpenRouter: ProviderConfig{APIKey: "key3"}, + Groq: ProviderConfig{APIKey: "key4"}, + Zhipu: ProviderConfig{APIKey: "key5"}, + VLLM: ProviderConfig{APIKey: "key6"}, + Gemini: ProviderConfig{APIKey: "key7"}, + Nvidia: ProviderConfig{APIKey: "key8"}, + Ollama: ProviderConfig{APIKey: "key9"}, + Moonshot: ProviderConfig{APIKey: "key10"}, + ShengSuanYun: ProviderConfig{APIKey: "key11"}, + DeepSeek: ProviderConfig{APIKey: "key12"}, + Cerebras: ProviderConfig{APIKey: "key13"}, + VolcEngine: ProviderConfig{APIKey: "key14"}, GitHubCopilot: ProviderConfig{ConnectMode: "grpc"}, - Antigravity: ProviderConfig{AuthMethod: "oauth"}, - Qwen: ProviderConfig{APIKey: "key17"}, + Antigravity: ProviderConfig{AuthMethod: "oauth"}, + Qwen: ProviderConfig{APIKey: "key17"}, }, } @@ -175,3 +176,218 @@ func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result)) } } + +// Tests for preserving user's configured model during migration + +func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "deepseek-reasoner", + }, + }, + Providers: ProvidersConfig{ + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Should use user's model, not default + if result[0].Model != "openai/deepseek-reasoner" { + t.Errorf("Model = %q, want %q (user's configured model)", result[0].Model, "openai/deepseek-reasoner") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "openai", + Model: "gpt-4-turbo", + }, + }, + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{APIKey: "sk-openai"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "openai/gpt-4-turbo" { + t.Errorf("Model = %q, want %q", result[0].Model, "openai/gpt-4-turbo") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_Anthropic(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "claude", // alternative name + Model: "claude-3-opus-20240229", + }, + }, + Providers: ProvidersConfig{ + Anthropic: ProviderConfig{APIKey: "sk-ant"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "anthropic/claude-3-opus-20240229" { + t.Errorf("Model = %q, want %q", result[0].Model, "anthropic/claude-3-opus-20240229") + } +} + +func TestConvertProvidersToModelList_PreservesUserModel_Qwen(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "qwen", + Model: "qwen-plus", + }, + }, + Providers: ProvidersConfig{ + Qwen: ProviderConfig{APIKey: "sk-qwen"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].Model != "qwen/qwen-plus" { + t.Errorf("Model = %q, want %q", result[0].Model, "qwen/qwen-plus") + } +} + +func TestConvertProvidersToModelList_UsesDefaultWhenNoUserModel(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "", // no model specified + }, + }, + Providers: ProvidersConfig{ + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Should use default model + if result[0].Model != "openai/deepseek-chat" { + t.Errorf("Model = %q, want %q (default)", result[0].Model, "openai/deepseek-chat") + } +} + +func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: "deepseek", + Model: "deepseek-reasoner", + }, + }, + Providers: ProvidersConfig{ + OpenAI: ProviderConfig{APIKey: "sk-openai"}, + DeepSeek: ProviderConfig{APIKey: "sk-deepseek"}, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 2 { + t.Fatalf("len(result) = %d, want 2", len(result)) + } + + // Find each provider and verify model + for _, mc := range result { + switch mc.ModelName { + case "openai": + if mc.Model != "openai/gpt-4o" { + t.Errorf("OpenAI Model = %q, want %q (default)", mc.Model, "openai/gpt-4o") + } + case "deepseek": + if mc.Model != "openai/deepseek-reasoner" { + t.Errorf("DeepSeek Model = %q, want %q (user's)", mc.Model, "openai/deepseek-reasoner") + } + } + } +} + +func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) { + tests := []struct { + providerAlias string + expectedModel string + provider ProviderConfig + }{ + {"gpt", "openai/gpt-4-custom", ProviderConfig{APIKey: "key"}}, + {"claude", "anthropic/claude-custom", ProviderConfig{APIKey: "key"}}, + {"doubao", "openai/doubao-custom", ProviderConfig{APIKey: "key"}}, + {"tongyi", "qwen/qwen-custom", ProviderConfig{APIKey: "key"}}, + {"kimi", "moonshot/kimi-custom", ProviderConfig{APIKey: "key"}}, + } + + for _, tt := range tests { + t.Run(tt.providerAlias, func(t *testing.T) { + cfg := &Config{ + Agents: AgentsConfig{ + Defaults: AgentDefaults{ + Provider: tt.providerAlias, + Model: strings.TrimPrefix(tt.expectedModel, tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1]), + }, + }, + Providers: ProvidersConfig{}, + } + + // Set the appropriate provider config + switch tt.providerAlias { + case "gpt": + cfg.Providers.OpenAI = tt.provider + case "claude": + cfg.Providers.Anthropic = tt.provider + case "doubao": + cfg.Providers.VolcEngine = tt.provider + case "tongyi": + cfg.Providers.Qwen = tt.provider + case "kimi": + cfg.Providers.Moonshot = tt.provider + } + + // Need to fix the model name in config + cfg.Agents.Defaults.Model = strings.TrimPrefix(tt.expectedModel, tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1]) + + result := ConvertProvidersToModelList(cfg) + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + // Extract just the model ID part (after the first /) + expectedModelID := tt.expectedModel + if result[0].Model != expectedModelID { + t.Errorf("Model = %q, want %q", result[0].Model, expectedModelID) + } + }) + } +} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 7851c7c5d..2097fbbff 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -45,13 +45,11 @@ func createCodexAuthProvider() (LLMProvider, error) { // - "gpt-4o" -> ("openai", "gpt-4o") // default protocol func ExtractProtocol(model string) (protocol, modelID string) { model = strings.TrimSpace(model) - for i := 0; i < len(model); i++ { - if model[i] == '/' { - return model[:i], model[i+1:] - } + protocol, modelID, found := strings.Cut(model, "/") + if !found { + return "openai", model } - // No prefix found, default to openai - return "openai", model + return protocol, modelID } // CreateProviderFromConfig creates a provider based on the ModelConfig. diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go new file mode 100644 index 000000000..c7c35ef42 --- /dev/null +++ b/pkg/providers/toolcall_utils.go @@ -0,0 +1,54 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package providers + +import "encoding/json" + +// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated. +// It handles cases where Name/Arguments might be in different locations (top-level vs Function) +// and ensures both are populated consistently. +func NormalizeToolCall(tc ToolCall) ToolCall { + normalized := tc + + // Ensure Name is populated from Function if not set + if normalized.Name == "" && normalized.Function != nil { + normalized.Name = normalized.Function.Name + } + + // Ensure Arguments is not nil + if normalized.Arguments == nil { + normalized.Arguments = map[string]interface{}{} + } + + // Parse Arguments from Function.Arguments if not already set + if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" { + var parsed map[string]interface{} + if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil { + normalized.Arguments = parsed + } + } + + // Ensure Function is populated with consistent values + argsJSON, _ := json.Marshal(normalized.Arguments) + if normalized.Function == nil { + normalized.Function = &FunctionCall{ + Name: normalized.Name, + Arguments: string(argsJSON), + } + } else { + if normalized.Function.Name == "" { + normalized.Function.Name = normalized.Name + } + if normalized.Name == "" { + normalized.Name = normalized.Function.Name + } + if normalized.Function.Arguments == "" { + normalized.Function.Arguments = string(argsJSON) + } + } + + return normalized +} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index a95710816..0109c3447 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -85,7 +85,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider normalizedToolCalls := make([]providers.ToolCall, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { - normalizedToolCalls = append(normalizedToolCalls, normalizeProviderToolCall(tc)) + normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc)) } // 5. Log tool calls @@ -159,42 +159,3 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider Iterations: iteration, }, nil } - -func normalizeProviderToolCall(tc providers.ToolCall) providers.ToolCall { - normalized := tc - - if normalized.Name == "" && normalized.Function != nil { - normalized.Name = normalized.Function.Name - } - - if normalized.Arguments == nil { - normalized.Arguments = map[string]interface{}{} - } - - if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" { - var parsed map[string]interface{} - if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil { - normalized.Arguments = parsed - } - } - - argsJSON, _ := json.Marshal(normalized.Arguments) - if normalized.Function == nil { - normalized.Function = &providers.FunctionCall{ - Name: normalized.Name, - Arguments: string(argsJSON), - } - } else { - if normalized.Function.Name == "" { - normalized.Function.Name = normalized.Name - } - if normalized.Name == "" { - normalized.Name = normalized.Function.Name - } - if normalized.Function.Arguments == "" { - normalized.Function.Arguments = string(argsJSON) - } - } - - return normalized -}