Merge pull request #994 from is-Xiaoen/feat/model-routing

feat(routing): intelligent model routing based on structural complexity scoring
This commit is contained in:
Meng Zhuo
2026-03-06 15:47:53 +08:00
committed by GitHub
7 changed files with 809 additions and 20 deletions
+29
View File
@@ -37,6 +37,14 @@ type AgentInstance struct {
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
// Router is non-nil when model routing is configured and the light model
// was successfully resolved. It scores each incoming message and decides
// whether to route to LightCandidates or stay with Candidates.
Router *routing.Router
// LightCandidates holds the resolved provider candidates for the light model.
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
LightCandidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -180,6 +188,25 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
} else {
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
rc.LightModel, agentID)
}
}
return &AgentInstance{
ID: agentID,
Name: agentName,
@@ -200,6 +227,8 @@ func NewAgentInstance(
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
}
}
+49 -5
View File
@@ -824,6 +824,12 @@ func (al *AgentLoop) runLLMIteration(
iteration := 0
var finalContent string
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
// all tool-follow-up iterations within the same turn so that a multi-step
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
iteration++
@@ -842,7 +848,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": agent.Model,
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": agent.MaxTokens,
@@ -858,7 +864,7 @@ func (al *AgentLoop) runLLMIteration(
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM with fallback chain if candidates are configured.
// Call LLM with fallback chain if multiple candidates are configured.
var response *providers.LLMResponse
var err error
@@ -879,10 +885,10 @@ func (al *AgentLoop) runLLMIteration(
}
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
},
@@ -900,7 +906,7 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
}
// Retry loop for context/token errors
@@ -1169,6 +1175,44 @@ func (al *AgentLoop) runLLMIteration(
return finalContent, iteration, nil
}
// selectCandidates returns the model candidates and resolved model name to use
// for a conversation turn. When model routing is configured and the incoming
// message scores below the complexity threshold, it returns the light model
// candidates instead of the primary ones.
//
// The returned (candidates, model) pair is used for all LLM calls within one
// turn — tool follow-up iterations use the same tier as the initial call so
// that a multi-step tool chain doesn't switch models mid-way.
func (al *AgentLoop) selectCandidates(
agent *AgentInstance,
userMsg string,
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
if !usedLight {
logger.DebugCF("agent", "Model routing: primary model selected",
map[string]any{
"agent_id": agent.ID,
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
}
logger.InfoCF("agent", "Model routing: light model selected",
map[string]any{
"agent_id": agent.ID,
"light_model": agent.Router.LightModel(),
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
+28 -15
View File
@@ -167,22 +167,35 @@ type SessionConfig struct {
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
// RoutingConfig controls the intelligent model routing feature.
// When enabled, each incoming message is scored against structural features
// (message length, code blocks, tool call history, conversation depth, attachments).
// Messages scoring below Threshold are sent to LightModel; all others use the
// agent's primary model. This reduces cost and latency for simple tasks without
// requiring any keyword matching — all scoring is language-agnostic.
type RoutingConfig struct {
Enabled bool `json:"enabled"`
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
+80
View File
@@ -0,0 +1,80 @@
package routing
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
// A higher score indicates a more complex task that benefits from a heavy model.
// The score is compared against the configured threshold: score >= threshold selects
// the primary (heavy) model; score < threshold selects the light model.
//
// Classifier is an interface so that future implementations (ML-based, embedding-based,
// or any other approach) can be swapped in without changing routing infrastructure.
type Classifier interface {
Score(f Features) float64
}
// RuleClassifier is the v1 implementation.
// It uses a weighted sum of structural signals with no external dependencies,
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
// that the returned score always falls within the [0, 1] contract.
//
// Individual weights (multiple signals can fire simultaneously):
//
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
// token 50-200: 0.15 — medium length; may or may not be complex
// code block present: 0.40 — coding tasks need the heavy model
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
// tool calls 1-3 (recent): 0.10 — some tool activity
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
//
// Default threshold is 0.35, so:
// - Pure greetings / trivial Q&A: 0.00 → light ✓
// - Medium prose message (50200 tokens): 0.15 → light ✓
// - Message with code block: 0.40 → heavy ✓
// - Long message (>200 tokens): 0.35 → heavy ✓
// - Active tool session + medium message: 0.25 → light (acceptable)
// - Any message with an image/audio attachment: 1.00 → heavy ✓
type RuleClassifier struct{}
// Score computes the complexity score for the given feature set.
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
func (c *RuleClassifier) Score(f Features) float64 {
// Hard gate: multi-modal inputs always require the heavy model.
if f.HasAttachments {
return 1.0
}
var score float64
// Token estimate — primary verbosity signal
switch {
case f.TokenEstimate > 200:
score += 0.35
case f.TokenEstimate > 50:
score += 0.15
}
// Fenced code blocks — strongest indicator of a coding/technical task
if f.CodeBlockCount > 0 {
score += 0.40
}
// Recent tool call density — indicates an ongoing agentic workflow
switch {
case f.RecentToolCalls > 3:
score += 0.25
case f.RecentToolCalls > 0:
score += 0.10
}
// Conversation depth — accumulated context implies compound task
if f.ConversationDepth > 10 {
score += 0.10
}
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
if score > 1.0 {
score = 1.0
}
return score
}
+127
View File
@@ -0,0 +1,127 @@
package routing
import (
"strings"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// lookbackWindow is the number of recent history entries scanned for tool calls.
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
const lookbackWindow = 6
// Features holds the structural signals extracted from a message and its session context.
// Every dimension is language-agnostic by construction — no keyword or pattern matching
// against natural-language content. This ensures consistent routing for all locales.
type Features struct {
// TokenEstimate is a proxy for token count.
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
// This avoids API calls while giving accurate estimates for all scripts.
TokenEstimate int
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
// Coding tasks almost always require the heavy model.
CodeBlockCount int
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
// history entries. A high density indicates an active agentic workflow.
RecentToolCalls int
// ConversationDepth is the total number of messages in the session history.
// Deep sessions tend to carry implicit complexity built up over many turns.
ConversationDepth int
// HasAttachments is true when the message appears to contain media (images,
// audio, video). Multi-modal inputs require vision-capable heavy models.
HasAttachments bool
}
// ExtractFeatures computes the structural feature vector for a message.
// It is a pure function with no side effects and zero allocations beyond
// the returned struct.
func ExtractFeatures(msg string, history []providers.Message) Features {
return Features{
TokenEstimate: estimateTokens(msg),
CodeBlockCount: countCodeBlocks(msg),
RecentToolCalls: countRecentToolCalls(history),
ConversationDepth: len(history),
HasAttachments: hasAttachments(msg),
}
}
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
// CJK runes (U+2E80U+9FFF, U+F900U+FAFF, U+AC00U+D7AF) map to roughly one
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
// for English). Splitting the count this way avoids the 3x underestimation that a
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
func estimateTokens(msg string) int {
total := utf8.RuneCountInString(msg)
if total == 0 {
return 0
}
cjk := 0
for _, r := range msg {
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
cjk++
}
}
return cjk + (total-cjk)/4
}
// countCodeBlocks counts the number of complete fenced code blocks.
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
// An unclosed opening fence (odd count) is treated as zero complete blocks
// since it may just be an inline code span or a typo.
func countCodeBlocks(msg string) int {
n := strings.Count(msg, "```")
return n / 2
}
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
// entries of history. It examines the ToolCalls field rather than parsing
// the content string, so it is robust to any message format.
func countRecentToolCalls(history []providers.Message) int {
start := len(history) - lookbackWindow
if start < 0 {
start = 0
}
count := 0
for _, msg := range history[start:] {
if len(msg.ToolCalls) > 0 {
count += len(msg.ToolCalls)
}
}
return count
}
// hasAttachments returns true when the message content contains embedded media.
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
// common image/audio URL extensions. This is intentionally conservative —
// false negatives (missing an attachment) just mean the routing falls back to
// the primary model anyway.
func hasAttachments(msg string) bool {
lower := strings.ToLower(msg)
// Base64 data URIs embedded directly in the message
if strings.Contains(lower, "data:image/") ||
strings.Contains(lower, "data:audio/") ||
strings.Contains(lower, "data:video/") {
return true
}
// Common image/audio extensions in URLs or file references
mediaExts := []string{
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
".mp3", ".wav", ".ogg", ".m4a", ".flac",
".mp4", ".avi", ".mov", ".webm",
}
for _, ext := range mediaExts {
if strings.Contains(lower, ext) {
return true
}
}
return false
}
+82
View File
@@ -0,0 +1,82 @@
package routing
import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// defaultThreshold is used when the config threshold is zero or negative.
// At 0.35 a message needs at least one strong signal (code block, long text,
// or an attachment) before the heavy model is chosen.
const defaultThreshold = 0.35
// RouterConfig holds the validated model routing settings.
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
type RouterConfig struct {
// LightModel is the model_name (from model_list) used for simple tasks.
LightModel string
// Threshold is the complexity score cutoff in [0, 1].
// score >= Threshold → primary (heavy) model.
// score < Threshold → light model.
Threshold float64
}
// Router selects the appropriate model tier for each incoming message.
// It is safe for concurrent use from multiple goroutines.
type Router struct {
cfg RouterConfig
classifier Classifier
}
// New creates a Router with the given config and the default RuleClassifier.
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
func New(cfg RouterConfig) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{
cfg: cfg,
classifier: &RuleClassifier{},
}
}
// newWithClassifier creates a Router with a custom Classifier.
// Intended for unit tests that need to inject a deterministic scorer.
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{cfg: cfg, classifier: c}
}
// SelectModel returns the model to use for this conversation turn along with
// the computed complexity score (for logging and debugging).
//
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
// - Otherwise: returns (primaryModel, false, score)
//
// The caller is responsible for resolving the returned model name into
// provider candidates (see AgentInstance.LightCandidates).
func (r *Router) SelectModel(
msg string,
history []providers.Message,
primaryModel string,
) (model string, usedLight bool, score float64) {
features := ExtractFeatures(msg, history)
score = r.classifier.Score(features)
if score < r.cfg.Threshold {
return r.cfg.LightModel, true, score
}
return primaryModel, false, score
}
// LightModel returns the configured light model name.
func (r *Router) LightModel() string {
return r.cfg.LightModel
}
// Threshold returns the complexity threshold in use.
func (r *Router) Threshold() float64 {
return r.cfg.Threshold
}
+414
View File
@@ -0,0 +1,414 @@
package routing
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ── ExtractFeatures ──────────────────────────────────────────────────────────
func TestExtractFeatures_EmptyMessage(t *testing.T) {
f := ExtractFeatures("", nil)
if f.TokenEstimate != 0 {
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
}
if f.CodeBlockCount != 0 {
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
}
if f.RecentToolCalls != 0 {
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
}
if f.ConversationDepth != 0 {
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
}
if f.HasAttachments {
t.Error("HasAttachments: got true, want false")
}
}
func TestExtractFeatures_TokenEstimate(t *testing.T) {
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
msg := strings.Repeat("a", 30)
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 7 {
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
// Using a rune slice literal avoids CJK string literals in source.
msg := string([]rune{
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60,
})
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 9 {
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 6 {
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
}
}
func TestExtractFeatures_CodeBlocks(t *testing.T) {
cases := []struct {
msg string
want int
}{
{"no code here", 0},
{"```go\nfmt.Println()\n```", 1},
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.CodeBlockCount != tc.want {
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
}
}
}
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
// History longer than lookbackWindow — only last lookbackWindow entries count.
history := make([]providers.Message, 10)
// Put 2 tool calls at positions 8 and 9 (within the last 6)
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
history[9] = providers.Message{
Role: "assistant",
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
}
// Position 3 is outside the lookback window and must NOT be counted
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
f := ExtractFeatures("test", history)
// 1 (position 8) + 2 (position 9) = 3
if f.RecentToolCalls != 3 {
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
}
}
func TestExtractFeatures_ConversationDepth(t *testing.T) {
history := make([]providers.Message, 7)
f := ExtractFeatures("msg", history)
if f.ConversationDepth != 7 {
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
}
}
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"plain text", false},
{"here is an image: data:image/png;base64,abc123", true},
{"audio: data:audio/mp3;base64,xyz", true},
{"video: data:video/mp4;base64,xyz", true},
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"check out photo.jpg", true},
{"see screenshot.png", true},
{"listen to audio.mp3", true},
{"watch clip.mp4", true},
{"just a .go file", false},
{"document.pdf", false}, // pdf is not in the media list
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
// ── RuleClassifier ───────────────────────────────────────────────────────────
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{})
if score != 0.0 {
t.Errorf("zero features: got %f, want 0.0", score)
}
}
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{HasAttachments: true})
if score != 1.0 {
t.Errorf("attachments: got %f, want 1.0", score)
}
}
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
c := &RuleClassifier{}
// Code block alone = 0.40, above default threshold 0.35
score := c.Score(Features{CodeBlockCount: 1})
if score < 0.35 {
t.Errorf("code block: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_LongMessage(t *testing.T) {
c := &RuleClassifier{}
// >200 tokens = 0.35, exactly at default threshold → heavy
score := c.Score(Features{TokenEstimate: 250})
if score < 0.35 {
t.Errorf("long message: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_MediumMessage(t *testing.T) {
c := &RuleClassifier{}
// 50-200 tokens = 0.15, below threshold → light
score := c.Score(Features{TokenEstimate: 100})
if score >= 0.35 {
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
}
}
func TestRuleClassifier_ShortMessage(t *testing.T) {
c := &RuleClassifier{}
// <50 tokens, no other signals = 0.0 → light
score := c.Score(Features{TokenEstimate: 10})
if score != 0.0 {
t.Errorf("short message: got %f, want 0.0", score)
}
}
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
c := &RuleClassifier{}
scoreNone := c.Score(Features{RecentToolCalls: 0})
scoreLow := c.Score(Features{RecentToolCalls: 2})
scoreHigh := c.Score(Features{RecentToolCalls: 5})
if scoreNone != 0.0 {
t.Errorf("no tools: got %f, want 0.0", scoreNone)
}
if scoreLow <= scoreNone {
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
}
if scoreHigh <= scoreLow {
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
}
}
func TestRuleClassifier_DeepConversation(t *testing.T) {
c := &RuleClassifier{}
shallow := c.Score(Features{ConversationDepth: 5})
deep := c.Score(Features{ConversationDepth: 15})
if deep <= shallow {
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
}
}
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
c := &RuleClassifier{}
// Max all signals simultaneously
f := Features{
TokenEstimate: 500,
CodeBlockCount: 3,
RecentToolCalls: 10,
ConversationDepth: 20,
}
score := c.Score(f)
if score > 1.0 {
t.Errorf("score %f exceeds 1.0", score)
}
}
// ── Router ───────────────────────────────────────────────────────────────────
func TestRouter_DefaultThreshold(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash"})
if r.Threshold() != defaultThreshold {
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
if r.Threshold() != defaultThreshold {
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "hi"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("simple message: expected light model to be selected")
}
if model != "gemini-flash" {
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
}
}
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "```go\nfmt.Println(\"hello\")\n```"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("code block: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "can you analyze this? data:image/png;base64,abc123"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("attachment: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
// >200 token estimate: 210 * 3 = 630 chars
msg := strings.Repeat("word ", 210)
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("long message: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
// Routing is conservative: only promote to heavy when the signal is unambiguous.
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
}
msg := "ok"
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if !usedLight {
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
}
}
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
}},
}
// ~55 tokens * 3 = 165 chars
msg := strings.Repeat("word ", 55)
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if usedLight {
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
}
}
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
// Very low threshold: even a short message triggers heavy model
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("low threshold: medium message should use primary model")
}
}
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
// Very high threshold: even code blocks route to light
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
msg := "```go\nfmt.Println()\n```"
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("very high threshold: code block (0.40) should route to light model")
}
}
func TestRouter_LightModel(t *testing.T) {
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
if r.LightModel() != "my-fast-model" {
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
}
}
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
type fixedScoreClassifier struct{ score float64 }
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.2},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if !usedLight {
t.Error("low score with custom classifier: expected light model")
}
}
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.8},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("high score with custom classifier: expected primary model")
}
}
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
// score == threshold → primary (uses >= comparison)
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.5},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("score == threshold: expected primary model (>= threshold → primary)")
}
}
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.42},
)
_, _, score := r.SelectModel("anything", nil, "heavy")
if score != 0.42 {
t.Errorf("score: got %f, want 0.42", score)
}
}