add model command to set default model (#1250)

* add model command to set default model

* fix for ci

* fix test for model

* fix active agent not recognized

* implement test for model command

* fix local-model can not set as default issue

* fix review comment

* fix for comment
This commit is contained in:
Cytown
2026-03-13 14:10:11 +08:00
committed by GitHub
parent 9fed4ec136
commit dfa36f39cb
7 changed files with 514 additions and 6 deletions
+138
View File
@@ -0,0 +1,138 @@
package model
import (
"fmt"
"github.com/spf13/cobra"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/config"
)
// LocalModel is a special model name that indicates that the model is local and with or without api_key.
const LocalModel = "local-model"
func NewModelCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "model [model_name]",
Short: "Show or change the default model",
Long: `Show or change the default model configuration.
If no argument is provided, shows the current default model.
If a model name is provided, sets it as the default model.
Examples:
picoclaw model # Show current default model
picoclaw model gpt-5.2 # Set gpt-5.2 as default
picoclaw model claude-sonnet-4.6 # Set claude-sonnet-4.6 as default
picoclaw model local-model # Set local VLLM server as default
Note: 'local-model' is a special value for using a local VLLM server
(running at localhost:8000 by default) which does not require an API key.`,
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
configPath := internal.GetConfigPath()
// Load current config
cfg, err := config.LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if len(args) == 0 {
// Show current default model
showCurrentModel(cfg)
return nil
}
// Set new default model
modelName := args[0]
return setDefaultModel(configPath, cfg, modelName)
},
}
return cmd
}
func showCurrentModel(cfg *config.Config) {
defaultModel := cfg.Agents.Defaults.ModelName
if defaultModel == "" {
defaultModel = cfg.Agents.Defaults.Model
}
if defaultModel == "" {
fmt.Println("No default model is currently set.")
fmt.Println("\nAvailable models in your config:")
listAvailableModels(cfg)
} else {
fmt.Printf("Current default model: %s\n", defaultModel)
fmt.Println("\nAvailable models in your config:")
listAvailableModels(cfg)
}
}
func listAvailableModels(cfg *config.Config) {
if len(cfg.ModelList) == 0 {
fmt.Println(" No models configured in model_list")
return
}
defaultModel := cfg.Agents.Defaults.ModelName
if defaultModel == "" {
defaultModel = cfg.Agents.Defaults.Model
}
for _, model := range cfg.ModelList {
marker := " "
if model.ModelName == defaultModel {
marker = "> "
}
if model.APIKey == "" {
continue
}
fmt.Printf("%s- %s (%s)\n", marker, model.ModelName, model.Model)
}
}
func setDefaultModel(configPath string, cfg *config.Config, modelName string) error {
// Validate that the model exists in model_list
modelFound := false
for _, model := range cfg.ModelList {
if model.APIKey != "" && model.ModelName == modelName {
modelFound = true
break
}
}
if !modelFound && modelName != LocalModel {
return fmt.Errorf("cannot found model '%s' in config", modelName)
}
// Update the default model
// Clear old model field and set new model_name
oldModel := cfg.Agents.Defaults.ModelName
if oldModel == "" {
oldModel = cfg.Agents.Defaults.Model
}
cfg.Agents.Defaults.ModelName = modelName
cfg.Agents.Defaults.Model = "" // Clear deprecated field
// Save config back to file
if err := config.SaveConfig(configPath, cfg); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
fmt.Printf("✓ Default model changed from '%s' to '%s'\n",
formatModelName(oldModel), modelName)
fmt.Println("\nThe new default model will be used for all agent interactions.")
return nil
}
func formatModelName(name string) string {
if name == "" {
return "(none)"
}
return name
}
+369
View File
@@ -0,0 +1,369 @@
package model
import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/config"
)
var configPath = ""
func initTest(t *testing.T) {
tmpDir := t.TempDir()
configPath = filepath.Join(tmpDir, "config.json")
_ = os.Setenv("PICOCLAW_CONFIG", configPath)
}
// captureStdout captures stdout during the execution of fn and returns the captured output
func captureStdout(fn func()) string {
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
fn()
w.Close()
os.Stdout = oldStdout
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func TestNewModelCommand(t *testing.T) {
cmd := NewModelCommand()
require.NotNil(t, cmd)
assert.Equal(t, "model [model_name]", cmd.Use)
assert.Equal(t, "Show or change the default model", cmd.Short)
assert.Len(t, cmd.Aliases, 0)
assert.False(t, cmd.HasFlags())
assert.Nil(t, cmd.Run)
assert.NotNil(t, cmd.RunE)
assert.Nil(t, cmd.PersistentPreRunE)
assert.Nil(t, cmd.PersistentPreRun)
assert.Nil(t, cmd.PersistentPostRun)
}
func TestShowCurrentModel_WithDefaultModel(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "gpt-4",
},
},
ModelList: []config.ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"},
{ModelName: "claude-3", Model: "anthropic/claude-3", APIKey: "test"},
},
}
output := captureStdout(func() {
showCurrentModel(cfg)
})
assert.Contains(t, output, "Current default model: gpt-4")
assert.Contains(t, output, "Available models in your config:")
assert.Contains(t, output, "gpt-4")
assert.Contains(t, output, "claude-3")
}
func TestShowCurrentModel_NoDefaultModel(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "",
Model: "",
},
},
ModelList: []config.ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"},
},
}
output := captureStdout(func() {
showCurrentModel(cfg)
})
assert.Contains(t, output, "No default model is currently set.")
assert.Contains(t, output, "Available models in your config:")
}
func TestShowCurrentModel_BackwardCompatibility(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "legacy-model",
},
},
ModelList: []config.ModelConfig{},
}
output := captureStdout(func() {
showCurrentModel(cfg)
})
assert.Contains(t, output, "Current default model: legacy-model")
}
func TestListAvailableModels_Empty(t *testing.T) {
cfg := &config.Config{
ModelList: []config.ModelConfig{},
}
output := captureStdout(func() {
listAvailableModels(cfg)
})
assert.Contains(t, output, "No models configured in model_list")
}
func TestListAvailableModels_WithModels(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "gpt-4",
},
},
ModelList: []config.ModelConfig{
{ModelName: "gpt-4", Model: "openai/gpt-4", APIKey: "test"},
{ModelName: "claude-3", Model: "anthropic/claude-3", APIKey: "test"},
{ModelName: "no-key-model", Model: "openai/test", APIKey: ""},
},
}
output := captureStdout(func() {
listAvailableModels(cfg)
})
assert.NotEmpty(t, output)
assert.Contains(t, output, "> - gpt-4 (openai/gpt-4)")
assert.Contains(t, output, "claude-3 (anthropic/claude-3)")
assert.NotContains(t, output, "no-key-model")
}
func TestSetDefaultModel_ValidModel(t *testing.T) {
initTest(t)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "old-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "new-model", Model: "openai/new-model", APIKey: "test"},
{ModelName: "old-model", Model: "openai/old-model", APIKey: "test"},
},
}
output := captureStdout(func() {
err := setDefaultModel(configPath, cfg, "new-model")
assert.NoError(t, err)
})
assert.Contains(t, output, "Default model changed from 'old-model' to 'new-model'")
// Verify config was updated
updatedCfg, err := config.LoadConfig(configPath)
require.NoError(t, err)
assert.Equal(t, "new-model", updatedCfg.Agents.Defaults.ModelName)
assert.Empty(t, updatedCfg.Agents.Defaults.Model)
}
func TestSetDefaultModel_LegacyModelField(t *testing.T) {
initTest(t)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "legacy-old",
},
},
ModelList: []config.ModelConfig{
{ModelName: "new-model", Model: "openai/new-model", APIKey: "test"},
},
}
output := captureStdout(func() {
err := setDefaultModel(configPath, cfg, "new-model")
assert.NoError(t, err)
})
assert.Contains(t, output, "Default model changed from 'legacy-old' to 'new-model'")
}
func TestSetDefaultModel_InvalidModel(t *testing.T) {
initTest(t)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "existing-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "existing-model", Model: "openai/existing", APIKey: "test"},
},
}
assert.Error(t, setDefaultModel(configPath, cfg, "nonexistent-model"))
}
func TestSetDefaultModel_ModelWithoutAPIKey(t *testing.T) {
initTest(t)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "existing-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "existing-model", Model: "openai/existing", APIKey: "test"},
{ModelName: "no-key-model", Model: "openai/nokey", APIKey: ""},
},
}
assert.Error(t, setDefaultModel(configPath, cfg, "no-key-model"))
}
func TestSetDefaultModel_SaveConfigError(t *testing.T) {
// Use an invalid path to trigger save error
invalidPath := "/nonexistent/directory/config.json"
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "old-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "new-model", Model: "openai/new-model", APIKey: "test"},
},
}
err := setDefaultModel(invalidPath, cfg, "new-model")
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to save config")
}
func TestFormatModelName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", "(none)"},
{"simple model", "gpt-4", "gpt-4"},
{"model with version", "claude-sonnet-4.6", "claude-sonnet-4.6"},
{"model with spaces", "my model", "my model"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatModelName(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestModelCommandExecution_Show(t *testing.T) {
initTest(t)
// Create a test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "test-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "test-model", Model: "openai/test", APIKey: "test"},
},
}
err := config.SaveConfig(configPath, cfg)
require.NoError(t, err)
cmd := NewModelCommand()
output := captureStdout(func() {
err = cmd.RunE(cmd, []string{})
assert.NoError(t, err)
})
assert.Contains(t, output, "Current default model: test-model")
}
func TestModelCommandExecution_Set(t *testing.T) {
initTest(t)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "old-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "old-model", Model: "openai/old", APIKey: "test"},
{ModelName: "new-model", Model: "openai/new", APIKey: "test"},
},
}
err := config.SaveConfig(configPath, cfg)
require.NoError(t, err)
cmd := NewModelCommand()
output := captureStdout(func() {
err = cmd.RunE(cmd, []string{"new-model"})
assert.NoError(t, err)
})
assert.Contains(t, output, "Default model changed from 'old-model' to 'new-model'")
}
func TestModelCommandExecution_TooManyArgs(t *testing.T) {
cmd := NewModelCommand()
err := cmd.RunE(cmd, []string{"model1", "model2"})
assert.Error(t, err)
}
func TestListAvailableModels_MarkerLogic(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
ModelName: "middle-model",
},
},
ModelList: []config.ModelConfig{
{ModelName: "first-model", Model: "openai/first", APIKey: "test"},
{ModelName: "middle-model", Model: "openai/middle", APIKey: "test"},
{ModelName: "last-model", Model: "openai/last", APIKey: "test"},
},
}
output := captureStdout(func() {
listAvailableModels(cfg)
})
assert.Contains(t, output, " - first-model (openai/first)")
assert.Contains(t, output, "> - middle-model (openai/middle)")
assert.Contains(t, output, " - last-model (openai/last)")
}
+2
View File
@@ -18,6 +18,7 @@ import (
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/cron"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/gateway"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/migrate"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/model"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/onboard"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/skills"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/status"
@@ -43,6 +44,7 @@ func NewPicoclawCommand() *cobra.Command {
cron.NewCronCommand(),
migrate.NewMigrateCommand(),
skills.NewSkillsCommand(),
model.NewModelCommand(),
version.NewVersionCommand(),
)
+1
View File
@@ -39,6 +39,7 @@ func TestNewPicoclawCommand(t *testing.T) {
"cron",
"gateway",
"migrate",
"model",
"onboard",
"skills",
"status",
-2
View File
@@ -271,8 +271,6 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
+2 -2
View File
@@ -222,8 +222,8 @@ type AgentDefaults struct {
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
+2 -2
View File
@@ -342,8 +342,8 @@ func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) {
t.Fatalf("ReadFile failed: %v", err)
}
if !strings.Contains(string(data), `"model": ""`) {
t.Fatalf("saved config should include empty legacy model field, got: %s", string(data))
if !strings.Contains(string(data), `"model_name": ""`) {
t.Fatalf("saved config should include empty legacy model_name field, got: %s", string(data))
}
}