mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): use light provider for routed model calls (#2038)
This commit is contained in:
+23
-5
@@ -48,6 +48,9 @@ type AgentInstance struct {
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
// LightProvider is the concrete provider instance for the configured light model.
|
||||
// It is only used when routing selects the light tier for a turn.
|
||||
LightProvider providers.LLMProvider
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -171,14 +174,28 @@ func NewAgentInstance(
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
var lightProvider providers.LLMProvider
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightModelCfg, err := resolvedModelConfig(cfg, rc.LightModel, workspace)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Routing light model config invalid; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
|
||||
} else {
|
||||
lp, _, err := providers.CreateProviderFromConfig(lightModelCfg)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Routing light model provider init failed; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
|
||||
} else {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightProvider = lp
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.WarnCF("agent", "Routing light model not found; routing disabled",
|
||||
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
|
||||
@@ -207,6 +224,7 @@ func NewAgentInstance(
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
LightProvider: lightProvider,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+11
-7
@@ -1680,7 +1680,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.recordPersistedMessage(rootMsg)
|
||||
}
|
||||
|
||||
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
|
||||
activeProvider := ts.agent.Provider
|
||||
if usedLight && ts.agent.LightProvider != nil {
|
||||
activeProvider = ts.agent.LightProvider
|
||||
}
|
||||
pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
|
||||
var finalContent string
|
||||
|
||||
@@ -1902,7 +1906,7 @@ turnLoop:
|
||||
providerCtx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -1918,7 +1922,7 @@ turnLoop:
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
return activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
}
|
||||
|
||||
var response *providers.LLMResponse
|
||||
@@ -2747,9 +2751,9 @@ func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
) (candidates []providers.FallbackCandidate, model string, usedLight bool) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
@@ -2760,7 +2764,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
|
||||
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
@@ -2770,7 +2774,7 @@ func (al *AgentLoop) selectCandidates(
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
|
||||
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
|
||||
@@ -1296,6 +1296,46 @@ func newChatCompletionTestServer(
|
||||
}))
|
||||
}
|
||||
|
||||
// newStrictChatCompletionTestServer starts an httptest server that emulates a
// minimal chat-completions endpoint and validates every request it receives.
//
// For each request the handler checks that the URL path is /chat/completions
// and that the JSON body's "model" field equals expectedModel. Every request
// that arrives on the expected path increments *calls, whether or not its body
// is valid. A valid request is answered with a JSON document containing a
// single choice whose message content is response and whose finish_reason is
// "stop".
//
// Violations are reported with t.Errorf and answered with an HTTP error status
// rather than t.Fatalf: the handler runs on a goroutine owned by the server,
// and FailNow (which Fatalf invokes) must only be called from the goroutine
// running the test function.
//
// *calls is written from the handler goroutine without synchronization, so
// callers should read it only after the requests they issued have completed.
// The caller is responsible for closing the returned server.
func newStrictChatCompletionTestServer(
	t *testing.T,
	label string,
	expectedModel string,
	response string,
	calls *int,
) *httptest.Server {
	t.Helper()

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			t.Errorf("%s server path = %q, want /chat/completions", label, r.URL.Path)
			http.Error(w, "unexpected path", http.StatusNotFound)
			return
		}
		*calls++

		// The server closes the request body itself, so no explicit Close is needed.
		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode %s request: %v", label, err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		if req.Model != expectedModel {
			t.Errorf("%s server model = %q, want %q", label, req.Model, expectedModel)
			http.Error(w, "unexpected model", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": response},
					"finish_reason": "stop",
				},
			},
		}); err != nil {
			t.Errorf("encode %s response: %v", label, err)
		}
	}))
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
@@ -1697,6 +1737,96 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
heavyCalls := 0
|
||||
heavyServer := newStrictChatCompletionTestServer(
|
||||
t,
|
||||
"heavy",
|
||||
"gemini-2.5-flash",
|
||||
"heavy reply",
|
||||
&heavyCalls,
|
||||
)
|
||||
defer heavyServer.Close()
|
||||
|
||||
lightCalls := 0
|
||||
lightServer := newStrictChatCompletionTestServer(
|
||||
t,
|
||||
"light",
|
||||
"qwen2.5:0.5b",
|
||||
"light reply",
|
||||
&lightCalls,
|
||||
)
|
||||
defer lightServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
ModelName: "gemini-main",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
Routing: &config.RoutingConfig{
|
||||
Enabled: true,
|
||||
LightModel: "qwen-light",
|
||||
Threshold: 0.99,
|
||||
},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "gemini-main",
|
||||
Model: "gemini/gemini-2.5-flash",
|
||||
APIBase: heavyServer.URL,
|
||||
},
|
||||
{
|
||||
ModelName: "qwen-light",
|
||||
Model: "ollama/qwen2.5:0.5b",
|
||||
APIBase: lightServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg.WithSecurity(&config.SecurityConfig{
|
||||
ModelList: map[string]config.ModelSecurityEntry{
|
||||
"gemini-main": {APIKeys: []string{"heavy-key"}},
|
||||
"qwen-light": {APIKeys: []string{"light-key"}},
|
||||
},
|
||||
})
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if resp != "light reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "light reply")
|
||||
}
|
||||
if heavyCalls != 0 {
|
||||
t.Fatalf("heavy calls = %d, want 0", heavyCalls)
|
||||
}
|
||||
if lightCalls != 1 {
|
||||
t.Fatalf("light calls = %d, want 1", lightCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
Reference in New Issue
Block a user