From dfa36f39cb90559beee0aef9a8c135328b5e0e87 Mon Sep 17 00:00:00 2001 From: Cytown Date: Fri, 13 Mar 2026 14:10:11 +0800 Subject: [PATCH] add model command to set default model (#1250) * add model command to set default model * fix for ci * fix test for model * fix active agent not recognized * implement test for model command * fix local-model can not set as default issue * fix review comment * fix for comment --- cmd/picoclaw/internal/model/command.go | 138 ++++++++ cmd/picoclaw/internal/model/command_test.go | 369 ++++++++++++++++++++ cmd/picoclaw/main.go | 2 + cmd/picoclaw/main_test.go | 1 + go.sum | 2 - pkg/config/config.go | 4 +- pkg/config/config_test.go | 4 +- 7 files changed, 514 insertions(+), 6 deletions(-) create mode 100644 cmd/picoclaw/internal/model/command.go create mode 100644 cmd/picoclaw/internal/model/command_test.go diff --git a/cmd/picoclaw/internal/model/command.go b/cmd/picoclaw/internal/model/command.go new file mode 100644 index 000000000..cad106fd5 --- /dev/null +++ b/cmd/picoclaw/internal/model/command.go @@ -0,0 +1,138 @@ +package model + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/sipeed/picoclaw/cmd/picoclaw/internal" + "github.com/sipeed/picoclaw/pkg/config" +) + +// LocalModel is a special model name that indicates that the model is local and with or without api_key. +const LocalModel = "local-model" + +func NewModelCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "model [model_name]", + Short: "Show or change the default model", + Long: `Show or change the default model configuration. + +If no argument is provided, shows the current default model. +If a model name is provided, sets it as the default model. + +Examples: + picoclaw model # Show current default model + picoclaw model gpt-5.2 # Set gpt-5.2 as default + picoclaw model claude-sonnet-4.6 # Set claude-sonnet-4.6 as default + picoclaw model local-model # Set local VLLM server as default + +Note: 'local-model' is a special value for using a local VLLM server +(running at localhost:8000 by default) which does not require an API key.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + configPath := internal.GetConfigPath() + + // Load current config + cfg, err := config.LoadConfig(configPath) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + if len(args) == 0 { + // Show current default model + showCurrentModel(cfg) + return nil + } + + // Set new default model + modelName := args[0] + return setDefaultModel(configPath, cfg, modelName) + }, + } + + return cmd +} + +func showCurrentModel(cfg *config.Config) { + defaultModel := cfg.Agents.Defaults.ModelName + if defaultModel == "" { + defaultModel = cfg.Agents.Defaults.Model + } + + if defaultModel == "" { + fmt.Println("No default model is currently set.") + fmt.Println("\nAvailable models in your config:") + listAvailableModels(cfg) + } else { + fmt.Printf("Current default model: %s\n", defaultModel) + fmt.Println("\nAvailable models in your config:") + listAvailableModels(cfg) + } +} + +func listAvailableModels(cfg *config.Config) { + if len(cfg.ModelList) == 0 { + fmt.Println(" No models configured in model_list") + return + } + + defaultModel := cfg.Agents.Defaults.ModelName + if defaultModel == "" { + defaultModel = cfg.Agents.Defaults.Model + } + + for _, model := range cfg.ModelList { + marker := " " + if model.ModelName == defaultModel { + marker = "> " + } + if model.APIKey == "" { + continue + } + fmt.Printf("%s- %s (%s)\n", marker, model.ModelName, model.Model) + } +} + +func setDefaultModel(configPath string, cfg *config.Config, modelName string) error { + // Validate that the model exists in model_list + modelFound := false + for _, model := range cfg.ModelList { + if model.APIKey != "" && model.ModelName == modelName { + modelFound = true + break + } + } + + if !modelFound && modelName != LocalModel { + return fmt.Errorf("cannot found model '%s' in config", modelName) + } + + // Update the default model + // Clear old model field and set new model_name + oldModel := cfg.Agents.Defaults.ModelName + if oldModel == "" { + oldModel = cfg.Agents.Defaults.Model + } + + cfg.Agents.Defaults.ModelName = modelName + cfg.Agents.Defaults.Model = "" // Clear deprecated field + + // Save config back to file + if err := config.SaveConfig(configPath, cfg); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + fmt.Printf("✓ Default model changed from '%s' to '%s'\n", + formatModelName(oldModel), modelName) + fmt.Println("\nThe new default model will be used for all agent interactions.") + + return nil +} + +func formatModelName(name string) string { + if name == "" { + return "(none)" + } + return name +} diff --git a/cmd/picoclaw/internal/model/command_test.go b/cmd/picoclaw/internal/model/command_test.go new file mode 100644 index 000000000..82943e4a6 --- /dev/null +++ b/cmd/picoclaw/internal/model/command_test.go @@ -0,0 +1,369 @@ +package model + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg/config" +) + +var configPath = "" + +func initTest(t *testing.T) { + tmpDir := t.TempDir() + configPath = filepath.Join(tmpDir, "config.json") + _ = os.Setenv("PICOCLAW_CONFIG", configPath) +} + +// captureStdout captures stdout during the execution of fn and returns the captured output +func captureStdout(fn func()) string { + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + fn() + + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() +} + +func TestNewModelCommand(t *testing.T) { + cmd := NewModelCommand() + + require.NotNil(t, cmd) + + assert.Equal(t, "model [model_name]", cmd.Use) + assert.Equal(t, "Show or change the default model", cmd.Short) + + assert.Len(t, cmd.Aliases, 0) + + assert.False(t, cmd.HasFlags()) + + assert.Nil(t, cmd.Run) + assert.NotNil(t, cmd.RunE) + + assert.Nil(t, cmd.PersistentPreRunE) + assert.Nil(t, cmd.PersistentPreRun) + assert.Nil(t, cmd.PersistentPostRun) +} + +func TestShowCurrentModel_WithDefaultModel(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "gpt-4", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"}, + {ModelName: "claude-3", Model: "anthropic/claude-3", APIKey: "test"}, + }, + } + + output := captureStdout(func() { + showCurrentModel(cfg) + }) + + assert.Contains(t, output, "Current default model: gpt-4") + assert.Contains(t, output, "Available models in your config:") + assert.Contains(t, output, "gpt-4") + assert.Contains(t, output, "claude-3") +} + +func TestShowCurrentModel_NoDefaultModel(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "", + Model: "", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"}, + }, + } + + output := captureStdout(func() { + showCurrentModel(cfg) + }) + + assert.Contains(t, output, "No default model is currently set.") + assert.Contains(t, output, "Available models in your config:") +} + +func TestShowCurrentModel_BackwardCompatibility(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Model: "legacy-model", + }, + }, + ModelList: []config.ModelConfig{}, + } + + output := captureStdout(func() { + showCurrentModel(cfg) + }) + + assert.Contains(t, output, "Current default model: legacy-model") +} + +func TestListAvailableModels_Empty(t *testing.T) { + cfg := &config.Config{ + ModelList: []config.ModelConfig{}, + } + + output := captureStdout(func() { + listAvailableModels(cfg) + }) + + assert.Contains(t, output, "No models configured in model_list") +} + +func TestListAvailableModels_WithModels(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "gpt-4", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"}, + {ModelName: "claude-3", Model: "anthropic/claude-3", APIKey: "test"}, + {ModelName: "no-key-model", Model: "openai/test", APIKey: ""}, + }, + } + + output := captureStdout(func() { + listAvailableModels(cfg) + }) + + assert.NotEmpty(t, output) + assert.Contains(t, output, "> - gpt-4 (openai/gpt-4)") + assert.Contains(t, output, "claude-3 (anthropic/claude-3)") + assert.NotContains(t, output, "no-key-model") +} + +func TestSetDefaultModel_ValidModel(t *testing.T) { + initTest(t) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "old-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "new-model", Model: "openai/new-model", APIKey: "test"}, + {ModelName: "old-model", Model: "openai/old-model", APIKey: "test"}, + }, + } + + output := captureStdout(func() { + err := setDefaultModel(configPath, cfg, "new-model") + assert.NoError(t, err) + }) + + assert.Contains(t, output, "Default model changed from 'old-model' to 'new-model'") + + // Verify config was updated + updatedCfg, err := config.LoadConfig(configPath) + require.NoError(t, err) + assert.Equal(t, "new-model", updatedCfg.Agents.Defaults.ModelName) + assert.Empty(t, updatedCfg.Agents.Defaults.Model) +} + +func TestSetDefaultModel_LegacyModelField(t *testing.T) { + initTest(t) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Model: "legacy-old", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "new-model", Model: "openai/new-model", APIKey: "test"}, + }, + } + + output := captureStdout(func() { + err := setDefaultModel(configPath, cfg, "new-model") + assert.NoError(t, err) + }) + + assert.Contains(t, output, "Default model changed from 'legacy-old' to 'new-model'") +} + +func TestSetDefaultModel_InvalidModel(t *testing.T) { + initTest(t) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "existing-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "existing-model", Model: "openai/existing", APIKey: "test"}, + }, + } + + assert.Error(t, setDefaultModel(configPath, cfg, "nonexistent-model")) +} + +func TestSetDefaultModel_ModelWithoutAPIKey(t *testing.T) { + initTest(t) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "existing-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "existing-model", Model: "openai/existing", APIKey: "test"}, + {ModelName: "no-key-model", Model: "openai/nokey", APIKey: ""}, + }, + } + + assert.Error(t, setDefaultModel(configPath, cfg, "no-key-model")) +} + +func TestSetDefaultModel_SaveConfigError(t *testing.T) { + // Use an invalid path to trigger save error + invalidPath := "/nonexistent/directory/config.json" + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "old-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "new-model", Model: "openai/new-model", APIKey: "test"}, + }, + } + + err := setDefaultModel(invalidPath, cfg, "new-model") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to save config") +} + +func TestFormatModelName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"empty string", "", "(none)"}, + {"simple model", "gpt-4", "gpt-4"}, + {"model with version", "claude-sonnet-4.6", "claude-sonnet-4.6"}, + {"model with spaces", "my model", "my model"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatModelName(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestModelCommandExecution_Show(t *testing.T) { + initTest(t) + + // Create a test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "test-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "test-model", Model: "openai/test", APIKey: "test"}, + }, + } + + err := config.SaveConfig(configPath, cfg) + require.NoError(t, err) + + cmd := NewModelCommand() + + output := captureStdout(func() { + err = cmd.RunE(cmd, []string{}) + assert.NoError(t, err) + }) + + assert.Contains(t, output, "Current default model: test-model") +} + +func TestModelCommandExecution_Set(t *testing.T) { + initTest(t) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "old-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "old-model", Model: "openai/old", APIKey: "test"}, + {ModelName: "new-model", Model: "openai/new", APIKey: "test"}, + }, + } + + err := config.SaveConfig(configPath, cfg) + require.NoError(t, err) + + cmd := NewModelCommand() + + output := captureStdout(func() { + err = cmd.RunE(cmd, []string{"new-model"}) + assert.NoError(t, err) + }) + + assert.Contains(t, output, "Default model changed from 'old-model' to 'new-model'") +} + +func TestModelCommandExecution_TooManyArgs(t *testing.T) { + cmd := NewModelCommand() + + err := cmd.RunE(cmd, []string{"model1", "model2"}) + + assert.Error(t, err) +} + +func TestListAvailableModels_MarkerLogic(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + ModelName: "middle-model", + }, + }, + ModelList: []config.ModelConfig{ + {ModelName: "first-model", Model: "openai/first", APIKey: "test"}, + {ModelName: "middle-model", Model: "openai/middle", APIKey: "test"}, + {ModelName: "last-model", Model: "openai/last", APIKey: "test"}, + }, + } + + output := captureStdout(func() { + listAvailableModels(cfg) + }) + + assert.Contains(t, output, " - first-model (openai/first)") + assert.Contains(t, output, "> - middle-model (openai/middle)") + assert.Contains(t, output, " - last-model (openai/last)") +} diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index b82475905..bf9c0389f 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -18,6 +18,7 @@ import ( "github.com/sipeed/picoclaw/cmd/picoclaw/internal/cron" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/gateway" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/migrate" + "github.com/sipeed/picoclaw/cmd/picoclaw/internal/model" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/onboard" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/skills" "github.com/sipeed/picoclaw/cmd/picoclaw/internal/status" @@ -43,6 +44,7 @@ func NewPicoclawCommand() *cobra.Command { cron.NewCronCommand(), migrate.NewMigrateCommand(), skills.NewSkillsCommand(), + model.NewModelCommand(), version.NewVersionCommand(), ) diff --git a/cmd/picoclaw/main_test.go b/cmd/picoclaw/main_test.go index e622675ee..ad18cb330 100644 --- a/cmd/picoclaw/main_test.go +++ b/cmd/picoclaw/main_test.go @@ -39,6 +39,7 @@ func TestNewPicoclawCommand(t *testing.T) { "cron", "gateway", "migrate", + "model", "onboard", "skills", "status", diff --git a/go.sum b/go.sum index 2e2b1a1ec..cdca4fc12 100644 --- a/go.sum +++ b/go.sum @@ -271,8 +271,6 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= -golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= -golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= diff --git a/pkg/config/config.go b/pkg/config/config.go index 93e2acfe2..190341224 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -222,8 +222,8 @@ type AgentDefaults struct { RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead ModelFallbacks []string `json:"model_fallbacks,omitempty"` ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 1c93028c7..c5bdbf3c3 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -342,8 +342,8 @@ func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) { t.Fatalf("ReadFile failed: %v", err) } - if !strings.Contains(string(data), `"model": ""`) { - t.Fatalf("saved config should include empty legacy model field, got: %s", string(data)) + if !strings.Contains(string(data), `"model_name": ""`) { + t.Fatalf("saved config should include empty legacy model_name field, got: %s", string(data)) } }