diff --git a/README.fr.md b/README.fr.md index d09276c27..f59807739 100644 --- a/README.fr.md +++ b/README.fr.md @@ -226,7 +226,7 @@ picoclaw onboard ], "agents": { "defaults": { - "model": "gpt4" + "model_name": "gpt4" } }, "channels": { diff --git a/README.ja.md b/README.ja.md index 67eccddc2..5a7bb8542 100644 --- a/README.ja.md +++ b/README.ja.md @@ -188,7 +188,7 @@ picoclaw onboard ], "agents": { "defaults": { - "model": "gpt4" + "model_name": "gpt4" } }, "channels": { diff --git a/README.md b/README.md index 778530db5..2b770f215 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,7 @@ picoclaw onboard "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", - "model": "gpt4", + "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20 diff --git a/README.pt-br.md b/README.pt-br.md index 8d87333bc..0115b7f89 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -227,7 +227,7 @@ picoclaw onboard ], "agents": { "defaults": { - "model": "gpt4" + "model_name": "gpt4" } }, "tools": { diff --git a/README.vi.md b/README.vi.md index 1be58d9f6..015bc264e 100644 --- a/README.vi.md +++ b/README.vi.md @@ -207,7 +207,7 @@ picoclaw onboard ], "agents": { "defaults": { - "model": "gpt4" + "model_name": "gpt4" } }, "channels": { diff --git a/README.zh.md b/README.zh.md index 74760b3b1..4f4bde46a 100644 --- a/README.zh.md +++ b/README.zh.md @@ -224,7 +224,7 @@ picoclaw onboard "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", - "model": "gpt4", + "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20 diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/cmd_agent.go index 8658c9d32..98ea51103 100644 --- a/cmd/picoclaw/cmd_agent.go +++ b/cmd/picoclaw/cmd_agent.go @@ -56,7 +56,7 @@ func agentCmd() { } if modelOverride != "" { - cfg.Agents.Defaults.Model = modelOverride + cfg.Agents.Defaults.ModelName = modelOverride } provider, modelID, err := providers.CreateProvider(cfg) @@ -66,7 +66,7 @@ func agentCmd() { } // Use the resolved model ID from provider creation if modelID != "" { - cfg.Agents.Defaults.Model = modelID + cfg.Agents.Defaults.ModelName = modelID } msgBus := bus.NewMessageBus() diff --git a/cmd/picoclaw/cmd_auth.go b/cmd/picoclaw/cmd_auth.go index 729c56177..55eb3cec3 100644 --- a/cmd/picoclaw/cmd_auth.go +++ b/cmd/picoclaw/cmd_auth.go @@ -144,7 +144,7 @@ func authLoginOpenAI(useDeviceCode bool) { } // Update default model to use OpenAI - appCfg.Agents.Defaults.Model = "gpt-5.2" + appCfg.Agents.Defaults.ModelName = "gpt-5.2" if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { fmt.Printf("Warning: could not update config: %v\n", err) @@ -218,7 +218,7 @@ func authLoginGoogleAntigravity() { } // Update default model - appCfg.Agents.Defaults.Model = "gemini-flash" + appCfg.Agents.Defaults.ModelName = "gemini-flash" if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { fmt.Printf("Warning: could not update config: %v\n", err) @@ -292,7 +292,7 @@ func authLoginPasteToken(provider string) { }) } // Update default model - appCfg.Agents.Defaults.Model = "claude-sonnet-4.6" + appCfg.Agents.Defaults.ModelName = "claude-sonnet-4.6" case "openai": appCfg.Providers.OpenAI.AuthMethod = "token" // Update ModelList @@ -312,7 +312,7 @@ func authLoginPasteToken(provider string) { }) } // Update default model - appCfg.Agents.Defaults.Model = "gpt-5.2" + appCfg.Agents.Defaults.ModelName = "gpt-5.2" } if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { fmt.Printf("Warning: could not update config: %v\n", err) @@ -320,7 +320,7 @@ func authLoginPasteToken(provider string) { } fmt.Printf("Token saved for %s!\n", provider) - fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.Model) + fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.GetModelName()) } func authLogoutCmd() { diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 28ef76ad3..cf7f3563a 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -52,7 +52,7 @@ func gatewayCmd() { } // Use the resolved model ID from provider creation if modelID != "" { - cfg.Agents.Defaults.Model = modelID + cfg.Agents.Defaults.ModelName = modelID } msgBus := bus.NewMessageBus() diff --git a/cmd/picoclaw/cmd_status.go b/cmd/picoclaw/cmd_status.go index 07296784e..6a117bd17 100644 --- a/cmd/picoclaw/cmd_status.go +++ b/cmd/picoclaw/cmd_status.go @@ -41,7 +41,7 @@ func statusCmd() { } if _, err := os.Stat(configPath); err == nil { - fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) + fmt.Printf("Model: %s\n", cfg.Agents.Defaults.GetModelName()) hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" hasAnthropic := cfg.Providers.Anthropic.APIKey != "" diff --git a/config/config.example.json b/config/config.example.json index 555509732..e8c6b3d3f 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -3,7 +3,7 @@ "defaults": { "workspace": "~/.picoclaw/workspace", "restrict_to_workspace": true, - "model": "gpt4", + "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20 diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index dfbef9fbc..c6a54c7d2 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -133,7 +133,7 @@ func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefau if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" { return strings.TrimSpace(agentCfg.Model.Primary) } - return defaults.Model + return defaults.GetModelName() } // resolveAgentFallbacks resolves the fallback models for an agent. diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram_commands.go index a084b641b..f28434f46 100644 --- a/pkg/channels/telegram_commands.go +++ b/pkg/channels/telegram_commands.go @@ -81,7 +81,7 @@ func (c *cmd) Show(ctx context.Context, message telego.Message) error { switch args { case "model": response = fmt.Sprintf("Current Model: %s (Provider: %s)", - c.config.Agents.Defaults.Model, + c.config.Agents.Defaults.GetModelName(), c.config.Agents.Defaults.Provider) case "channel": response = "Current Channel: telegram" @@ -120,7 +120,7 @@ func (c *cmd) List(ctx context.Context, message telego.Message) error { provider = "configured default" } response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", - c.config.Agents.Defaults.Model, provider) + c.config.Agents.Defaults.GetModelName(), provider) case "channels": var enabled []string diff --git a/pkg/config/config.go b/pkg/config/config.go index 2595398c7..33e7d30e7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -170,7 +170,8 @@ type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead ModelFallbacks []string `json:"model_fallbacks,omitempty"` ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` @@ -179,6 +180,15 @@ type AgentDefaults struct { MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } +// GetModelName returns the effective model name for the agent defaults. +// It prefers the new "model_name" field but falls back to "model" for backward compatibility. +func (d *AgentDefaults) GetModelName() string { + if d.ModelName != "" { + return d.ModelName + } + return d.Model +} + type ChannelsConfig struct { WhatsApp WhatsAppConfig `json:"whatsapp"` Telegram TelegramConfig `json:"telegram"` diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 30eaa7474..70e1de438 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -41,7 +41,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { // Get user's configured provider and model userProvider := strings.ToLower(cfg.Agents.Defaults.Provider) - userModel := cfg.Agents.Defaults.Model + userModel := cfg.Agents.Defaults.GetModelName() p := cfg.Providers diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index 3c411dc0f..99eea2782 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -6,6 +6,7 @@ package config import ( + "encoding/json" "strings" "sync" "testing" @@ -114,6 +115,137 @@ func TestGetModelConfig_Concurrent(t *testing.T) { } } +func TestAgentDefaults_GetModelName_BackwardCompat(t *testing.T) { + tests := []struct { + name string + defaults AgentDefaults + wantName string + }{ + { + name: "new model_name field only", + defaults: AgentDefaults{ModelName: "new-model"}, + wantName: "new-model", + }, + { + name: "old model field only", + defaults: AgentDefaults{Model: "legacy-model"}, + wantName: "legacy-model", + }, + { + name: "both fields - model_name takes precedence", + defaults: AgentDefaults{ModelName: "new-model", Model: "old-model"}, + wantName: "new-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.defaults.GetModelName(); got != tt.wantName { + t.Errorf("GetModelName() = %q, want %q", got, tt.wantName) + } + }) + } +} + +func TestAgentDefaults_JSON_BackwardCompat(t *testing.T) { + tests := []struct { + name string + json string + wantName string + }{ + { + name: "new model_name field", + json: `{"model_name": "gpt4"}`, + wantName: "gpt4", + }, + { + name: "old model field", + json: `{"model": "gpt4"}`, + wantName: "gpt4", + }, + { + name: "both fields - model_name wins", + json: `{"model_name": "new", "model": "old"}`, + wantName: "new", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var defaults AgentDefaults + if err := json.Unmarshal([]byte(tt.json), &defaults); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if got := defaults.GetModelName(); got != tt.wantName { + t.Errorf("GetModelName() = %q, want %q", got, tt.wantName) + } + }) + } +} + +func TestFullConfig_JSON_BackwardCompat(t *testing.T) { + // Test complete config with both old and new formats + oldFormat := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "gpt4", + "max_tokens": 4096 + } + }, + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-4o", + "api_key": "test-key" + } + ] + }` + + newFormat := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model_name": "gpt4", + "max_tokens": 4096 + } + }, + "model_list": [ + { + "model_name": "gpt4", + "model": "openai/gpt-4o", + "api_key": "test-key" + } + ] + }` + + for name, jsonStr := range map[string]string{ + "old format (model)": oldFormat, + "new format (model_name)": newFormat, + } { + t.Run(name, func(t *testing.T) { + cfg := &Config{} + if err := json.Unmarshal([]byte(jsonStr), cfg); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + // Check that GetModelName returns correct value + if got := cfg.Agents.Defaults.GetModelName(); got != "gpt4" { + t.Errorf("GetModelName() = %q, want %q", got, "gpt4") + } + + // Check that GetModelConfig works + modelCfg, err := cfg.GetModelConfig("gpt4") + if err != nil { + t.Fatalf("GetModelConfig error: %v", err) + } + if modelCfg.Model != "openai/gpt-4o" { + t.Errorf("Model = %q, want %q", modelCfg.Model, "openai/gpt-4o") + } + }) + } +} + func TestModelConfig_Validate(t *testing.T) { tests := []struct { name string diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index 24ce33e94..869b39827 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -73,7 +73,10 @@ func ConvertConfig(data map[string]any) (*config.Config, []string, error) { if agents, ok := getMap(data, "agents"); ok { if defaults, ok := getMap(agents, "defaults"); ok { - if v, ok := getString(defaults, "model"); ok { + // Prefer model_name, fallback to model for backward compatibility + if v, ok := getString(defaults, "model_name"); ok { + cfg.Agents.Defaults.ModelName = v + } else if v, ok := getString(defaults, "model"); ok { cfg.Agents.Defaults.Model = v } if v, ok := getFloat(defaults, "max_tokens"); ok { diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index cda4753ea..11af14da4 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -36,7 +36,7 @@ type providerSelection struct { } func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { - model := cfg.Agents.Defaults.Model + model := cfg.Agents.Defaults.GetModelName() providerName := strings.ToLower(cfg.Agents.Defaults.Provider) lowerModel := strings.ToLower(model) diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go index eb13cec65..23f137538 100644 --- a/pkg/providers/legacy_provider.go +++ b/pkg/providers/legacy_provider.go @@ -16,7 +16,7 @@ import ( // The old providers config is automatically converted to model_list during config loading. // Returns the provider, the model ID to use, and any error. func CreateProvider(cfg *config.Config) (LLMProvider, string, error) { - model := cfg.Agents.Defaults.Model + model := cfg.Agents.Defaults.GetModelName() // Ensure model_list is populated (should be done by LoadConfig, but handle edge cases) if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {