fix(agent): use light provider for routed model calls (#2038)

This commit is contained in:
xiwuqi
2026-03-28 02:25:23 -05:00
committed by GitHub
parent c6061dd0d7
commit e011284d8f
3 changed files with 164 additions and 12 deletions
+23 -5
View File
@@ -48,6 +48,9 @@ type AgentInstance struct {
// LightCandidates holds the resolved provider candidates for the light model.
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
LightCandidates []providers.FallbackCandidate
// LightProvider is the concrete provider instance for the configured light model.
// It is only used when routing selects the light tier for a turn.
LightProvider providers.LLMProvider
}
// NewAgentInstance creates an agent instance from config.
@@ -171,14 +174,28 @@ func NewAgentInstance(
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
var lightProvider providers.LLMProvider
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
lightModelCfg, err := resolvedModelConfig(cfg, rc.LightModel, workspace)
if err != nil {
logger.WarnCF("agent", "Routing light model config invalid; routing disabled",
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
} else {
lp, _, err := providers.CreateProviderFromConfig(lightModelCfg)
if err != nil {
logger.WarnCF("agent", "Routing light model provider init failed; routing disabled",
map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()})
} else {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
lightProvider = lp
}
}
} else {
logger.WarnCF("agent", "Routing light model not found; routing disabled",
map[string]any{"light_model": rc.LightModel, "agent_id": agentID})
@@ -207,6 +224,7 @@ func NewAgentInstance(
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
LightProvider: lightProvider,
}
}
+11 -7
View File
@@ -1680,7 +1680,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.recordPersistedMessage(rootMsg)
}
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
activeProvider := ts.agent.Provider
if usedLight && ts.agent.LightProvider != nil {
activeProvider = ts.agent.LightProvider
}
pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
var finalContent string
@@ -1902,7 +1906,7 @@ turnLoop:
providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
@@ -1918,7 +1922,7 @@ turnLoop:
}
return fbResult.Response, nil
}
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
return activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
}
var response *providers.LLMResponse
@@ -2747,9 +2751,9 @@ func (al *AgentLoop) selectCandidates(
agent *AgentInstance,
userMsg string,
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
) (candidates []providers.FallbackCandidate, model string, usedLight bool) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
@@ -2760,7 +2764,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model)
return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false
}
logger.InfoCF("agent", "Model routing: light model selected",
@@ -2770,7 +2774,7 @@ func (al *AgentLoop) selectCandidates(
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel())
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
+130
View File
@@ -1296,6 +1296,46 @@ func newChatCompletionTestServer(
}))
}
// newStrictChatCompletionTestServer starts an httptest server that emulates a
// chat-completions endpoint and validates every incoming request.
//
// Each request must target /chat/completions and carry a JSON body whose
// "model" field equals expectedModel. Valid requests are answered with a single
// choice whose message content is response and whose finish_reason is "stop".
// *calls is incremented for every request that hits the expected path, before
// the body is decoded. label prefixes the failure messages so several servers
// can be told apart in test output.
//
// The handler runs on the server's goroutine rather than the test goroutine, so
// violations are reported with t.Errorf (t.Fatalf may only be called from the
// goroutine running the test) and the request is rejected with an HTTP error
// status. The caller is responsible for closing the returned server.
func newStrictChatCompletionTestServer(
	t *testing.T,
	label string,
	expectedModel string,
	response string,
	calls *int,
) *httptest.Server {
	t.Helper()

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/chat/completions" {
			t.Errorf("%s server path = %q, want /chat/completions", label, r.URL.Path)
			http.NotFound(w, r)
			return
		}

		// NOTE(review): this counter is written from the handler goroutine without
		// synchronization; callers should read it only after their request has
		// completed, and should not issue concurrent requests against one server.
		*calls++

		// The server closes the request body itself, so no explicit Close is needed.
		var req struct {
			Model string `json:"model"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Errorf("decode %s request: %v", label, err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		if req.Model != expectedModel {
			t.Errorf("%s server model = %q, want %q", label, req.Model, expectedModel)
			http.Error(w, "unexpected model", http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		payload := map[string]any{
			"choices": []map[string]any{
				{
					"message":       map[string]any{"content": response},
					"finish_reason": "stop",
				},
			},
		}
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			t.Errorf("encode %s response: %v", label, err)
		}
	}))
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
@@ -1697,6 +1737,96 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
}
}
// TestProcessMessage_ModelRoutingUsesLightProvider sets up two strict
// chat-completion servers, one per configured model, enables routing with
// "qwen-light" as the light model, and sends the message "hi" through the agent
// loop. It asserts that the reply comes from the light server, that the light
// server was called exactly once, and that the heavy server was never called.
func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	// Each server rejects any request whose model differs from the one it expects.
	var heavyHits, lightHits int
	heavySrv := newStrictChatCompletionTestServer(t, "heavy", "gemini-2.5-flash", "heavy reply", &heavyHits)
	defer heavySrv.Close()
	lightSrv := newStrictChatCompletionTestServer(t, "light", "qwen2.5:0.5b", "light reply", &lightHits)
	defer lightSrv.Close()

	// NOTE(review): the 0.99 threshold is presumably high enough that a short
	// message always scores below it and is routed to the light tier — confirm
	// against the router's scoring.
	routingCfg := &config.RoutingConfig{
		Enabled:    true,
		LightModel: "qwen-light",
		Threshold:  0.99,
	}
	models := []*config.ModelConfig{
		{
			ModelName: "gemini-main",
			Model:     "gemini/gemini-2.5-flash",
			APIBase:   heavySrv.URL,
		},
		{
			ModelName: "qwen-light",
			Model:     "ollama/qwen2.5:0.5b",
			APIBase:   lightSrv.URL,
		},
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         workspace,
				ModelName:         "gemini-main",
				MaxTokens:         4096,
				MaxToolIterations: 10,
				Routing:           routingCfg,
			},
		},
		ModelList: models,
	}
	cfg.WithSecurity(&config.SecurityConfig{
		ModelList: map[string]config.ModelSecurityEntry{
			"gemini-main": {APIKeys: []string{"heavy-key"}},
			"qwen-light":  {APIKeys: []string{"light-key"}},
		},
	})

	msgBus := bus.NewMessageBus()
	mainProvider, _, err := providers.CreateProvider(cfg)
	if err != nil {
		t.Fatalf("CreateProvider() error = %v", err)
	}
	helper := testHelper{al: NewAgentLoop(cfg, msgBus, mainProvider)}

	inbound := bus.InboundMessage{
		Channel:  "telegram",
		SenderID: "user1",
		ChatID:   "chat1",
		Content:  "hi",
		Peer: bus.Peer{
			Kind: "direct",
			ID:   "user1",
		},
	}
	resp := helper.executeAndGetResponse(t, context.Background(), inbound)

	// Checked in order; the first failing expectation aborts the test.
	switch {
	case resp != "light reply":
		t.Fatalf("response = %q, want %q", resp, "light reply")
	case heavyHits != 0:
		t.Fatalf("heavy calls = %d, want 0", heavyHits)
	case lightHits != 1:
		t.Fatalf("light calls = %d, want 1", lightHits)
	}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")