diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index cef736981..880725660 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -48,6 +48,9 @@ type AgentInstance struct { // LightCandidates holds the resolved provider candidates for the light model. // Pre-computed at agent creation to avoid repeated model_list lookups at runtime. LightCandidates []providers.FallbackCandidate + // LightProvider is the concrete provider instance for the configured light model. + // It is only used when routing selects the light tier for a turn. + LightProvider providers.LLMProvider } // NewAgentInstance creates an agent instance from config. @@ -171,14 +174,28 @@ func NewAgentInstance( // to avoid repeated model_list lookups on every incoming message. var router *routing.Router var lightCandidates []providers.FallbackCandidate + var lightProvider providers.LLMProvider if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" { resolved := resolveModelCandidates(cfg, defaults.Provider, rc.LightModel, nil) if len(resolved) > 0 { - router = routing.New(routing.RouterConfig{ - LightModel: rc.LightModel, - Threshold: rc.Threshold, - }) - lightCandidates = resolved + lightModelCfg, err := resolvedModelConfig(cfg, rc.LightModel, workspace) + if err != nil { + logger.WarnCF("agent", "Routing light model config invalid; routing disabled", + map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()}) + } else { + lp, _, err := providers.CreateProviderFromConfig(lightModelCfg) + if err != nil { + logger.WarnCF("agent", "Routing light model provider init failed; routing disabled", + map[string]any{"light_model": rc.LightModel, "agent_id": agentID, "error": err.Error()}) + } else { + router = routing.New(routing.RouterConfig{ + LightModel: rc.LightModel, + Threshold: rc.Threshold, + }) + lightCandidates = resolved + lightProvider = lp + } + } } else { logger.WarnCF("agent", "Routing light model not found; routing disabled", map[string]any{"light_model": rc.LightModel, "agent_id": agentID}) @@ -207,6 +224,7 @@ func NewAgentInstance( Candidates: candidates, Router: router, LightCandidates: lightCandidates, + LightProvider: lightProvider, } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fb9edda25..48932b10b 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1680,7 +1680,11 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.recordPersistedMessage(rootMsg) } - activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages) + activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages) + activeProvider := ts.agent.Provider + if usedLight && ts.agent.LightProvider != nil { + activeProvider = ts.agent.LightProvider + } pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...) var finalContent string @@ -1902,7 +1906,7 @@ turnLoop: providerCtx, activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) + return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts) }, ) if fbErr != nil { @@ -1918,7 +1922,7 @@ turnLoop: } return fbResult.Response, nil } - return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts) + return activeProvider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts) } var response *providers.LLMResponse @@ -2747,9 +2751,9 @@ func (al *AgentLoop) selectCandidates( agent *AgentInstance, userMsg string, history []providers.Message, -) (candidates []providers.FallbackCandidate, model string) { +) (candidates []providers.FallbackCandidate, model string, usedLight bool) { if agent.Router == nil || len(agent.LightCandidates) == 0 { - return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model) + return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false } _, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model) @@ -2760,7 +2764,7 @@ func (al *AgentLoop) selectCandidates( "score": score, "threshold": agent.Router.Threshold(), }) - return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model) + return agent.Candidates, resolvedCandidateModel(agent.Candidates, agent.Model), false } logger.InfoCF("agent", "Model routing: light model selected", @@ -2770,7 +2774,7 @@ func (al *AgentLoop) selectCandidates( "score": score, "threshold": agent.Router.Threshold(), }) - return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()) + return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true } // maybeSummarize triggers summarization if the session history exceeds thresholds. diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 77c2e3c10..4fd916d33 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -1296,6 +1296,46 @@ func newChatCompletionTestServer( })) } +func newStrictChatCompletionTestServer( + t *testing.T, + label string, + expectedModel string, + response string, + calls *int, +) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Fatalf("%s server path = %q, want /chat/completions", label, r.URL.Path) + } + *calls = *calls + 1 + defer r.Body.Close() + + var req struct { + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode %s request: %v", label, err) + } + if req.Model != expectedModel { + t.Fatalf("%s server model = %q, want %q", label, req.Model, expectedModel) + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": response}, + "finish_reason": "stop", + }, + }, + }); err != nil { + t.Fatalf("encode %s response: %v", label, err) + } + })) +} + func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string { // Use a short timeout to avoid hanging timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) @@ -1697,6 +1737,96 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t } } +func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + heavyCalls := 0 + heavyServer := newStrictChatCompletionTestServer( + t, + "heavy", + "gemini-2.5-flash", + "heavy reply", + &heavyCalls, + ) + defer heavyServer.Close() + + lightCalls := 0 + lightServer := newStrictChatCompletionTestServer( + t, + "light", + "qwen2.5:0.5b", + "light reply", + &lightCalls, + ) + defer lightServer.Close() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "gemini-main", + MaxTokens: 4096, + MaxToolIterations: 10, + Routing: &config.RoutingConfig{ + Enabled: true, + LightModel: "qwen-light", + Threshold: 0.99, + }, + }, + }, + ModelList: []*config.ModelConfig{ + { + ModelName: "gemini-main", + Model: "gemini/gemini-2.5-flash", + APIBase: heavyServer.URL, + }, + { + ModelName: "qwen-light", + Model: "ollama/qwen2.5:0.5b", + APIBase: lightServer.URL, + }, + }, + } + cfg.WithSecurity(&config.SecurityConfig{ + ModelList: map[string]config.ModelSecurityEntry{ + "gemini-main": {APIKeys: []string{"heavy-key"}}, + "qwen-light": {APIKeys: []string{"light-key"}}, + }, + }) + + msgBus := bus.NewMessageBus() + provider, _, err := providers.CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hi", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if resp != "light reply" { + t.Fatalf("response = %q, want %q", resp, "light reply") + } + if heavyCalls != 0 { + t.Fatalf("heavy calls = %d, want 0", heavyCalls) + } + if lightCalls != 1 { + t.Fatalf("light calls = %d, want 1", lightCalls) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*")