From 276a0cb92cfaa886ac0332b533659365262472ae Mon Sep 17 00:00:00 2001 From: Alix-007 Date: Thu, 19 Mar 2026 21:44:01 +0800 Subject: [PATCH] fix(agent): rebind provider after /switch model to (#1769) * fix(agent): rebind provider after model switch * test(agent): deduplicate switch model mock servers --------- Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com> --- pkg/agent/instance.go | 49 +------ pkg/agent/loop.go | 34 ++++- pkg/agent/loop_test.go | 246 +++++++++++++++++++++++++++++++++- pkg/agent/model_resolution.go | 97 ++++++++++++++ 4 files changed, 371 insertions(+), 55 deletions(-) create mode 100644 pkg/agent/model_resolution.go diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index d2a4f81a4..355e78a33 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -152,59 +152,14 @@ func NewAgentInstance( } // Resolve fallback candidates - modelCfg := providers.ModelConfig{ - Primary: model, - Fallbacks: fallbacks, - } - resolveFromModelList := func(raw string) (string, bool) { - ensureProtocol := func(model string) string { - model = strings.TrimSpace(model) - if model == "" { - return "" - } - if strings.Contains(model, "/") { - return model - } - return "openai/" + model - } - - raw = strings.TrimSpace(raw) - if raw == "" { - return "", false - } - - if cfg != nil { - if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" { - return ensureProtocol(mc.Model), true - } - - for i := range cfg.ModelList { - fullModel := strings.TrimSpace(cfg.ModelList[i].Model) - if fullModel == "" { - continue - } - if fullModel == raw { - return ensureProtocol(fullModel), true - } - _, modelID := providers.ExtractProtocol(fullModel) - if modelID == raw { - return ensureProtocol(fullModel), true - } - } - } - - return "", false - } - - candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) + candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks) // Model routing setup: pre-resolve light model candidates at creation time // to avoid repeated model_list lookups on every incoming message. var router *routing.Router var lightCandidates []providers.FallbackCandidate if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" { - lightModelCfg := providers.ModelConfig{Primary: rc.LightModel} - resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList) + resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil) if len(resolved) > 0 { router = routing.New(routing.RouterConfig{ LightModel: rc.LightModel, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index edb0994c2..aade18014 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1477,7 +1477,7 @@ func (al *AgentLoop) selectCandidates( history []providers.Message, ) (candidates []providers.FallbackCandidate, model string) { if agent.Router == nil || len(agent.LightCandidates) == 0 { - return agent.Candidates, agent.Model + return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model) } _, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model) @@ -1488,7 +1488,7 @@ func (al *AgentLoop) selectCandidates( "score": score, "threshold": agent.Router.Threshold(), }) - return agent.Candidates, agent.Model + return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model) } logger.InfoCF("agent", "Model routing: light model selected", @@ -1498,7 +1498,7 @@ func (al *AgentLoop) selectCandidates( "score": score, "threshold": agent.Router.Threshold(), }) - return agent.LightCandidates, agent.Router.LightModel() + return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()) } // maybeSummarize triggers summarization if the session history exceeds thresholds. @@ -1961,11 +1961,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt } if agent != nil { rt.GetModelInfo = func() (string, string) { - return agent.Model, cfg.Agents.Defaults.Provider + return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider) } rt.SwitchModel = func(value string) (string, error) { + value = strings.TrimSpace(value) + modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace) + if err != nil { + return "", err + } + + nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg) + if err != nil { + return "", fmt.Errorf("failed to initialize model %q: %w", value, err) + } + + nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks) + if len(nextCandidates) == 0 { + return "", fmt.Errorf("model %q did not resolve to any provider candidates", value) + } + oldModel := agent.Model + oldProvider := agent.Provider agent.Model = value + agent.Provider = nextProvider + agent.Candidates = nextCandidates + agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel) + + if oldProvider != nil && oldProvider != nextProvider { + if stateful, ok := oldProvider.(providers.StatefulProvider); ok { + stateful.Close() + } + } return oldModel, nil } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 8432ccac4..b6b6c2c6c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2,7 +2,10 @@ package agent import ( "context" + "encoding/json" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "slices" @@ -444,6 +447,46 @@ type testHelper struct { al *AgentLoop } +func newChatCompletionTestServer( + t *testing.T, + label string, + response string, + calls *int, + model *string, +) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path) + } + *calls = *calls + 1 + defer r.Body.Close() + + var req struct { + Model string `json:"model"` + } + decodeErr := json.NewDecoder(r.Body).Decode(&req) + if decodeErr != nil { + t.Fatalf("decode %s request: %v", label, decodeErr) + } + *model = req.Model + + w.Header().Set("Content-Type", "application/json") + encodeErr := json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": response}, + "finish_reason": "stop", + }, + }, + }) + if encodeErr != nil { + t.Fatalf("encode %s response: %v", label, encodeErr) + } + })) +} + func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string { // Use a short timeout to avoid hanging timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) @@ -605,11 +648,25 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { Defaults: config.AgentDefaults{ Workspace: tmpDir, Provider: "openai", - Model: "before-switch", + Model: "local", MaxTokens: 4096, MaxToolIterations: 10, }, }, + ModelList: []config.ModelConfig{ + { + ModelName: "local", + Model: "openai/local-model", + APIKey: "test-key", + APIBase: "https://local.example.invalid/v1", + }, + { + ModelName: "deepseek", + Model: "openrouter/deepseek/deepseek-v3.2", + APIKey: "test-key", + APIBase: "https://openrouter.ai/api/v1", + }, + }, } msgBus := bus.NewMessageBus() @@ -621,13 +678,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { Channel: "telegram", SenderID: "user1", ChatID: "chat1", - Content: "/switch model to after-switch", + Content: "/switch model to deepseek", Peer: bus.Peer{ Kind: "direct", ID: "user1", }, }) - if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") { + if !strings.Contains(switchResp, "Switched model from local to deepseek") { t.Fatalf("unexpected /switch reply: %q", switchResp) } @@ -641,7 +698,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { ID: "user1", }, }) - if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") { + if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") { t.Fatalf("unexpected /show model reply after switch: %q", showResp) } @@ -650,6 +707,187 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { } } +func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Provider: "openai", + Model: "local", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: "local", + Model: "openai/local-model", + APIKey: "test-key", + APIBase: "https://local.example.invalid/v1", + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/switch model to missing", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if switchResp != `model "missing" not found in model_list or providers` { + t.Fatalf("unexpected /switch error reply: %q", switchResp) + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/show model", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(showResp, "Current Model: local (Provider: openai)") { + t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp) + } + + if provider.calls != 0 { + t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls) + } +} + +func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + localCalls := 0 + localModel := "" + localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel) + defer localServer.Close() + + remoteCalls := 0 + remoteModel := "" + remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel) + defer remoteServer.Close() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Provider: "openai", + Model: "local", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: "local", + Model: "openai/Qwen3.5-35B-A3B", + APIKey: "local-key", + APIBase: localServer.URL, + }, + { + ModelName: "deepseek", + Model: "openrouter/deepseek/deepseek-v3.2", + APIKey: "remote-key", + APIBase: remoteServer.URL, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hello before switch", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if firstResp != "local reply" { + t.Fatalf("unexpected response before switch: %q", firstResp) + } + if localCalls != 1 { + t.Fatalf("local calls before switch = %d, want 1", localCalls) + } + if remoteCalls != 0 { + t.Fatalf("remote calls before switch = %d, want 0", remoteCalls) + } + if localModel != "Qwen3.5-35B-A3B" { + t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B") + } + + switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/switch model to deepseek", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(switchResp, "Switched model from local to deepseek") { + t.Fatalf("unexpected /switch reply: %q", switchResp) + } + + secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hello after switch", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if secondResp != "remote reply" { + t.Fatalf("unexpected response after switch: %q", secondResp) + } + if localCalls != 1 { + t.Fatalf("local calls after switch = %d, want 1", localCalls) + } + if remoteCalls != 1 { + t.Fatalf("remote calls after switch = %d, want 1", remoteCalls) + } + if remoteModel != "deepseek-v3.2" { + t.Fatalf( + "remote model after switch = %q, want %q", + remoteModel, + "deepseek-v3.2", + ) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/agent/model_resolution.go b/pkg/agent/model_resolution.go new file mode 100644 index 000000000..140cff718 --- /dev/null +++ b/pkg/agent/model_resolution.go @@ -0,0 +1,97 @@ +package agent + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) { + ensureProtocol := func(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + if strings.Contains(model, "/") { + return model + } + return "openai/" + model + } + + return func(raw string) (string, bool) { + raw = strings.TrimSpace(raw) + if raw == "" || cfg == nil { + return "", false + } + + if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" { + return ensureProtocol(mc.Model), true + } + + for i := range cfg.ModelList { + fullModel := strings.TrimSpace(cfg.ModelList[i].Model) + if fullModel == "" { + continue + } + if fullModel == raw { + return ensureProtocol(fullModel), true + } + _, modelID := providers.ExtractProtocol(fullModel) + if modelID == raw { + return ensureProtocol(fullModel), true + } + } + + return "", false + } +} + +func resolveModelCandidates( + cfg *config.Config, + defaultProvider string, + primary string, + fallbacks []string, +) []providers.FallbackCandidate { + return providers.ResolveCandidatesWithLookup( + providers.ModelConfig{ + Primary: primary, + Fallbacks: fallbacks, + }, + defaultProvider, + buildModelListResolver(cfg), + ) +} + +func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string { + if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" { + return candidates[0].Model + } + return fallback +} + +func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string { + if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" { + return candidates[0].Provider + } + return fallback +} + +func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) { + if cfg == nil { + return nil, fmt.Errorf("config is nil") + } + + modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName)) + if err != nil { + return nil, err + } + + clone := *modelCfg + if clone.Workspace == "" { + clone.Workspace = workspace + } + + return &clone, nil +}