mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): rebind provider after /switch model to (#1769)
* fix(agent): rebind provider after model switch * test(agent): deduplicate switch model mock servers --------- Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com>
This commit is contained in:
+2
-47
@@ -152,59 +152,14 @@ func NewAgentInstance(
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
resolveFromModelList := func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
|
||||
+30
-4
@@ -1477,7 +1477,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
@@ -1488,7 +1488,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
@@ -1498,7 +1498,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
@@ -1961,11 +1961,37 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, cfg.Agents.Defaults.Provider
|
||||
return agent.Model, resolvedCandidateProvider(agent.Candidates, cfg.Agents.Defaults.Provider)
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
modelCfg, err := resolvedModelConfig(cfg, value, agent.Workspace)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nextProvider, _, err := providers.CreateProviderFromConfig(modelCfg)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
|
||||
}
|
||||
|
||||
nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
|
||||
if len(nextCandidates) == 0 {
|
||||
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
|
||||
}
|
||||
|
||||
oldModel := agent.Model
|
||||
oldProvider := agent.Provider
|
||||
agent.Model = value
|
||||
agent.Provider = nextProvider
|
||||
agent.Candidates = nextCandidates
|
||||
agent.ThinkingLevel = parseThinkingLevel(modelCfg.ThinkingLevel)
|
||||
|
||||
if oldProvider != nil && oldProvider != nextProvider {
|
||||
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
return oldModel, nil
|
||||
}
|
||||
|
||||
|
||||
+242
-4
@@ -2,7 +2,10 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -444,6 +447,46 @@ type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func newChatCompletionTestServer(
|
||||
t *testing.T,
|
||||
label string,
|
||||
response string,
|
||||
calls *int,
|
||||
model *string,
|
||||
) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path)
|
||||
}
|
||||
*calls = *calls + 1
|
||||
defer r.Body.Close()
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(r.Body).Decode(&req)
|
||||
if decodeErr != nil {
|
||||
t.Fatalf("decode %s request: %v", label, decodeErr)
|
||||
}
|
||||
*model = req.Model
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
encodeErr := json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": response},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
})
|
||||
if encodeErr != nil {
|
||||
t.Fatalf("encode %s response: %v", label, encodeErr)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -605,11 +648,25 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "before-switch",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
@@ -621,13 +678,13 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
@@ -641,7 +698,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
@@ -650,6 +707,187 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/local-model",
|
||||
APIKey: "test-key",
|
||||
APIBase: "https://local.example.invalid/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for rejected /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
localCalls := 0
|
||||
localModel := ""
|
||||
localServer := newChatCompletionTestServer(t, "local", "local reply", &localCalls, &localModel)
|
||||
defer localServer.Close()
|
||||
|
||||
remoteCalls := 0
|
||||
remoteModel := ""
|
||||
remoteServer := newChatCompletionTestServer(t, "remote", "remote reply", &remoteCalls, &remoteModel)
|
||||
defer remoteServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "local",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
ModelList: []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local",
|
||||
Model: "openai/Qwen3.5-35B-A3B",
|
||||
APIKey: "local-key",
|
||||
APIBase: localServer.URL,
|
||||
},
|
||||
{
|
||||
ModelName: "deepseek",
|
||||
Model: "openrouter/deepseek/deepseek-v3.2",
|
||||
APIKey: "remote-key",
|
||||
APIBase: remoteServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
firstResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls before switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 0 {
|
||||
t.Fatalf("remote calls before switch = %d, want 0", remoteCalls)
|
||||
}
|
||||
if localModel != "Qwen3.5-35B-A3B" {
|
||||
t.Fatalf("local model before switch = %q, want %q", localModel, "Qwen3.5-35B-A3B")
|
||||
}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
secondResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
}
|
||||
if localCalls != 1 {
|
||||
t.Fatalf("local calls after switch = %d, want 1", localCalls)
|
||||
}
|
||||
if remoteCalls != 1 {
|
||||
t.Fatalf("remote calls after switch = %d, want 1", remoteCalls)
|
||||
}
|
||||
if remoteModel != "deepseek-v3.2" {
|
||||
t.Fatalf(
|
||||
"remote model after switch = %q, want %q",
|
||||
remoteModel,
|
||||
"deepseek-v3.2",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
|
||||
ensureProtocol := func(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(model, "/") {
|
||||
return model
|
||||
}
|
||||
return "openai/" + model
|
||||
}
|
||||
|
||||
return func(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || cfg == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
|
||||
return ensureProtocol(mc.Model), true
|
||||
}
|
||||
|
||||
for i := range cfg.ModelList {
|
||||
fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
|
||||
if fullModel == "" {
|
||||
continue
|
||||
}
|
||||
if fullModel == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
_, modelID := providers.ExtractProtocol(fullModel)
|
||||
if modelID == raw {
|
||||
return ensureProtocol(fullModel), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func resolveModelCandidates(
|
||||
cfg *config.Config,
|
||||
defaultProvider string,
|
||||
primary string,
|
||||
fallbacks []string,
|
||||
) []providers.FallbackCandidate {
|
||||
return providers.ResolveCandidatesWithLookup(
|
||||
providers.ModelConfig{
|
||||
Primary: primary,
|
||||
Fallbacks: fallbacks,
|
||||
},
|
||||
defaultProvider,
|
||||
buildModelListResolver(cfg),
|
||||
)
|
||||
}
|
||||
|
||||
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Model) != "" {
|
||||
return candidates[0].Model
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedCandidateProvider(candidates []providers.FallbackCandidate, fallback string) string {
|
||||
if len(candidates) > 0 && strings.TrimSpace(candidates[0].Provider) != "" {
|
||||
return candidates[0].Provider
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolvedModelConfig(cfg *config.Config, modelName, workspace string) (*config.ModelConfig, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is nil")
|
||||
}
|
||||
|
||||
modelCfg, err := cfg.GetModelConfig(strings.TrimSpace(modelName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clone := *modelCfg
|
||||
if clone.Workspace == "" {
|
||||
clone.Workspace = workspace
|
||||
}
|
||||
|
||||
return &clone, nil
|
||||
}
|
||||
Reference in New Issue
Block a user