fix(agent): rebind provider after /switch model to (#1769)

* fix(agent): rebind provider after model switch

* test(agent): deduplicate switch model mock servers

---------

Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com>
This commit is contained in:
Alix-007
2026-03-19 21:44:01 +08:00
committed by GitHub
parent 05c65d2fe7
commit 276a0cb92c
4 changed files with 371 additions and 55 deletions
+2 -47
View File
@@ -152,59 +152,14 @@ func NewAgentInstance(
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
Fallbacks: fallbacks,
}
resolveFromModelList := func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
raw = strings.TrimSpace(raw)
if raw == "" {
return "", false
}
if cfg != nil {
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
}
return "", false
}
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
+30 -4
View File
@@ -1477,7 +1477,7 @@ func (al *AgentLoop) selectCandidates(
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
@@ -1488,7 +1488,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
}
logger.InfoCF("agent", "Model routing: light model selected",
@@ -1498,7 +1498,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
@@ -1961,11 +1961,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, cfg.Agents.Defaults.Provider
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
}
rt.SwitchModel = func(value string) (string, error) {
value = strings.TrimSpace(value)
modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace)
if err != nil {
return "", err
}
nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg)
if err != nil {
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
oldModel := agent.Model
oldProvider := agent.Provider
agent.Model = value
agent.Provider = nextProvider
agent.Candidates = nextCandidates
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
if oldProvider != nil && oldProvider != nextProvider {
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
stateful.Close()
}
}
return oldModel, nil
}
+242 -4
View File
@@ -2,7 +2,10 @@ package agent
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
@@ -444,6 +447,46 @@ type testHelper struct {
al *AgentLoop
}
func newChatCompletionTestServer(
t *testing.T,
label string,
response string,
calls *int,
model *string,
) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
}
*calls = *calls + 1
defer r.Body.Close()
var req struct {
Model string `json:"model"`
}
decodeErr := json.NewDecoder(r.Body).Decode(&req)
if decodeErr != nil {
t.Fatalf("decode %s request: %v", label, decodeErr)
}
*model = req.Model
w.Header().Set("Content-Type", "application/json")
encodeErr := json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": response},
"finish_reason": "stop",
},
},
})
if encodeErr != nil {
t.Fatalf("encode %s response: %v", label, encodeErr)
}
}))
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
@@ -605,11 +648,25 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "before-switch",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/local-model",
APIKey: "test-key",
APIBase: "https://local.example.invalid/v1",
},
{
ModelName: "deepseek",
Model: "openrouter/deepseek/deepseek-v3.2",
APIKey: "test-key",
APIBase: "https://openrouter.ai/api/v1",
},
},
}
msgBus := bus.NewMessageBus()
@@ -621,13 +678,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to after-switch",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
@@ -641,7 +698,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
}
@@ -650,6 +707,187 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
}
}
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/local-model",
APIKey: "test-key",
APIBase: "https://local.example.invalid/v1",
},
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to missing",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if switchResp != `model "missing" not found in model_list or providers` {
t.Fatalf("unexpected /switch error reply: %q", switchResp)
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
}
}
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
localCalls := 0
localModel := ""
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
defer localServer.Close()
remoteCalls := 0
remoteModel := ""
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
defer remoteServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "local",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
ModelList: []config.ModelConfig{
{
ModelName: "local",
Model: "openai/Qwen3.5-35B-A3B",
APIKey: "local-key",
APIBase: localServer.URL,
},
{
ModelName: "deepseek",
Model: "openrouter/deepseek/deepseek-v3.2",
APIKey: "remote-key",
APIBase: remoteServer.URL,
},
},
}
msgBus := bus.NewMessageBus()
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello before switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if firstResp != "local reply" {
t.Fatalf("unexpected response before switch: %q", firstResp)
}
if localCalls != 1 {
t.Fatalf("local calls before switch = %d, want 1", localCalls)
}
if remoteCalls != 0 {
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
}
if localModel != "Qwen3.5-35B-A3B" {
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to deepseek",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello after switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if secondResp != "remote reply" {
t.Fatalf("unexpected response after switch: %q", secondResp)
}
if localCalls != 1 {
t.Fatalf("local calls after switch = %d, want 1", localCalls)
}
if remoteCalls != 1 {
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
}
if remoteModel != "deepseek-v3.2" {
t.Fatalf(
"remote model after switch = %q, want %q",
remoteModel,
"deepseek-v3.2",
)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+97
View File
@@ -0,0 +1,97 @@
package agent
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
ensureProtocol := func(model string) string {
model = strings.TrimSpace(model)
if model == "" {
return ""
}
if strings.Contains(model, "/") {
return model
}
return "openai/" + model
}
return func(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" || cfg == nil {
return "", false
}
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
return ensureProtocol(mc.Model), true
}
for i := range cfg.ModelList {
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
if fullModel == "" {
continue
}
if fullModel == raw {
return ensureProtocol(fullModel), true
}
_, modelID := providers.ExtractProtocol(fullModel)
if modelID == raw {
return ensureProtocol(fullModel), true
}
}
return "", false
}
}
func resolveModelCandidates(
cfg *config.Config,
defaultProvider string,
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
return providers.ResolveCandidatesWithLookup(
providers.ModelConfig{
Primary: primary,
Fallbacks: fallbacks,
},
defaultProvider,
buildModelListResolver(cfg),
)
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
return candidates[0].Model
}
return fallback
}
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
return candidates[0].Provider
}
return fallback
}
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
if cfg == nil {
return nil, fmt.Errorf("config is nil")
}
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
if err != nil {
return nil, err
}
clone := *modelCfg
if clone.Workspace == "" {
clone.Workspace = workspace
}
return &clone, nil
}