mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(routing): add language-agnostic model complexity scorer
Add three new files to pkg/routing/:
features.go — ExtractFeatures(msg, history) → Features
Computes five structural dimensions with zero keyword matching:
- TokenEstimate: rune_count/3 (CJK-safe token proxy)
- CodeBlockCount: ``` pairs in the message
- RecentToolCalls: tool call count in the last 6 history entries
- ConversationDepth: total messages in session
- HasAttachments: data URIs or media file extensions
classifier.go — Classifier interface + RuleClassifier
RuleClassifier uses a weighted sum that is capped at 1.0:
code block → +0.40 (triggers heavy model alone at 0.35 threshold)
token > 200 → +0.35 (triggers heavy model alone)
tool calls > 3 → +0.25
token 50-200 → +0.15
conversation depth > 10 → +0.10
attachment → 1.00 (hard gate, always heavy)
router.go — Router wraps config + Classifier
Router.SelectModel(msg, history, primaryModel) returns either the
configured light_model or the primary model depending on whether
the complexity score clears the threshold. Threshold defaults to
0.35 when zero/negative to prevent misconfiguration.
router_test.go — 34 tests covering all branches and edge cases
This commit is contained in:
@@ -0,0 +1,80 @@
|
||||
package routing
|
||||
|
||||
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
|
||||
// A higher score indicates a more complex task that benefits from a heavy model.
|
||||
// The score is compared against the configured threshold: score >= threshold selects
|
||||
// the primary (heavy) model; score < threshold selects the light model.
|
||||
//
|
||||
// Classifier is an interface so that future implementations (ML-based, embedding-based,
|
||||
// or any other approach) can be swapped in without changing routing infrastructure.
|
||||
type Classifier interface {
|
||||
Score(f Features) float64
|
||||
}
|
||||
|
||||
// RuleClassifier is the v1 implementation.
|
||||
// It uses a weighted sum of structural signals with no external dependencies,
|
||||
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
|
||||
// that the returned score always falls within the [0, 1] contract.
|
||||
//
|
||||
// Individual weights (multiple signals can fire simultaneously):
|
||||
//
|
||||
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
|
||||
// token 50-200: 0.15 — medium length; may or may not be complex
|
||||
// code block present: 0.40 — coding tasks need the heavy model
|
||||
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
|
||||
// tool calls 1-3 (recent): 0.10 — some tool activity
|
||||
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
|
||||
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
|
||||
//
|
||||
// Default threshold is 0.35, so:
|
||||
// - Pure greetings / trivial Q&A: 0.00 → light ✓
|
||||
// - Medium prose message (50–200 tokens): 0.15 → light ✓
|
||||
// - Message with code block: 0.40 → heavy ✓
|
||||
// - Long message (>200 tokens): 0.35 → heavy ✓
|
||||
// - Active tool session + medium message: 0.25 → light (acceptable)
|
||||
// - Any message with an image/audio attachment: 1.00 → heavy ✓
|
||||
type RuleClassifier struct{}
|
||||
|
||||
// Score computes the complexity score for the given feature set.
|
||||
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
|
||||
func (c *RuleClassifier) Score(f Features) float64 {
|
||||
// Hard gate: multi-modal inputs always require the heavy model.
|
||||
if f.HasAttachments {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
var score float64
|
||||
|
||||
// Token estimate — primary verbosity signal
|
||||
switch {
|
||||
case f.TokenEstimate > 200:
|
||||
score += 0.35
|
||||
case f.TokenEstimate > 50:
|
||||
score += 0.15
|
||||
}
|
||||
|
||||
// Fenced code blocks — strongest indicator of a coding/technical task
|
||||
if f.CodeBlockCount > 0 {
|
||||
score += 0.40
|
||||
}
|
||||
|
||||
// Recent tool call density — indicates an ongoing agentic workflow
|
||||
switch {
|
||||
case f.RecentToolCalls > 3:
|
||||
score += 0.25
|
||||
case f.RecentToolCalls > 0:
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Conversation depth — accumulated context implies compound task
|
||||
if f.ConversationDepth > 10 {
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Cap at 1.0 to honour the [0, 1] contract even when multiple signals fire
|
||||
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// lookbackWindow is the number of recent history entries scanned for tool calls.
|
||||
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
|
||||
const lookbackWindow = 6
|
||||
|
||||
// Features holds the structural signals extracted from a message and its session context.
|
||||
// Every dimension is language-agnostic by construction — no keyword or pattern matching
|
||||
// against natural-language content. This ensures consistent routing for all locales.
|
||||
type Features struct {
|
||||
// TokenEstimate is a conservative proxy for token count.
|
||||
// Computed as utf8.RuneCountInString(msg) / 3, which handles CJK characters
|
||||
// (each rune ≈ 1 token for CJK, ≈ 0.25 tokens for ASCII) without any API call.
|
||||
TokenEstimate int
|
||||
|
||||
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
|
||||
// Coding tasks almost always require the heavy model.
|
||||
CodeBlockCount int
|
||||
|
||||
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
|
||||
// history entries. A high density indicates an active agentic workflow.
|
||||
RecentToolCalls int
|
||||
|
||||
// ConversationDepth is the total number of messages in the session history.
|
||||
// Deep sessions tend to carry implicit complexity built up over many turns.
|
||||
ConversationDepth int
|
||||
|
||||
// HasAttachments is true when the message appears to contain media (images,
|
||||
// audio, video). Multi-modal inputs require vision-capable heavy models.
|
||||
HasAttachments bool
|
||||
}
|
||||
|
||||
// ExtractFeatures computes the structural feature vector for a message.
|
||||
// It is a pure function with no side effects and zero allocations beyond
|
||||
// the returned struct.
|
||||
func ExtractFeatures(msg string, history []providers.Message) Features {
|
||||
return Features{
|
||||
TokenEstimate: estimateTokens(msg),
|
||||
CodeBlockCount: countCodeBlocks(msg),
|
||||
RecentToolCalls: countRecentToolCalls(history),
|
||||
ConversationDepth: len(history),
|
||||
HasAttachments: hasAttachments(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a conservative token count proxy.
|
||||
// Using rune count / 3 rather than / 4 because CJK characters each map to
|
||||
// roughly one token, while ASCII words average ~1.3 chars/token. Dividing
|
||||
// by 3 is a safe middle ground that slightly over-estimates for Latin text
|
||||
// (errs toward routing to the heavy model) and is accurate for CJK.
|
||||
func estimateTokens(msg string) int {
|
||||
rc := utf8.RuneCountInString(msg)
|
||||
return rc / 3
|
||||
}
|
||||
|
||||
// countCodeBlocks counts the number of complete fenced code blocks.
|
||||
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
|
||||
// An unclosed opening fence (odd count) is treated as zero complete blocks
|
||||
// since it may just be an inline code span or a typo.
|
||||
func countCodeBlocks(msg string) int {
|
||||
n := strings.Count(msg, "```")
|
||||
return n / 2
|
||||
}
|
||||
|
||||
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
|
||||
// entries of history. It examines the ToolCalls field rather than parsing
|
||||
// the content string, so it is robust to any message format.
|
||||
func countRecentToolCalls(history []providers.Message) int {
|
||||
start := len(history) - lookbackWindow
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, msg := range history[start:] {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
count += len(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// hasAttachments returns true when the message content contains embedded media.
|
||||
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
|
||||
// common image/audio URL extensions. This is intentionally conservative —
|
||||
// false negatives (missing an attachment) just mean the routing falls back to
|
||||
// the primary model anyway.
|
||||
func hasAttachments(msg string) bool {
|
||||
lower := strings.ToLower(msg)
|
||||
|
||||
// Base64 data URIs embedded directly in the message
|
||||
if strings.Contains(lower, "data:image/") ||
|
||||
strings.Contains(lower, "data:audio/") ||
|
||||
strings.Contains(lower, "data:video/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common image/audio extensions in URLs or file references
|
||||
mediaExts := []string{
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac",
|
||||
".mp4", ".avi", ".mov", ".webm",
|
||||
}
|
||||
for _, ext := range mediaExts {
|
||||
if strings.Contains(lower, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// defaultThreshold is used when the config threshold is zero or negative.
|
||||
// At 0.35 a message needs at least one strong signal (code block, long text,
|
||||
// or an attachment) before the heavy model is chosen.
|
||||
const defaultThreshold = 0.35
|
||||
|
||||
// RouterConfig holds the validated model routing settings.
|
||||
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
|
||||
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
|
||||
type RouterConfig struct {
|
||||
// LightModel is the model_name (from model_list) used for simple tasks.
|
||||
LightModel string
|
||||
|
||||
// Threshold is the complexity score cutoff in [0, 1].
|
||||
// score >= Threshold → primary (heavy) model.
|
||||
// score < Threshold → light model.
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// Router selects the appropriate model tier for each incoming message.
|
||||
// It is safe for concurrent use from multiple goroutines.
|
||||
type Router struct {
|
||||
cfg RouterConfig
|
||||
classifier Classifier
|
||||
}
|
||||
|
||||
// New creates a Router with the given config and the default RuleClassifier.
|
||||
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
|
||||
func New(cfg RouterConfig) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{
|
||||
cfg: cfg,
|
||||
classifier: &RuleClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
// newWithClassifier creates a Router with a custom Classifier.
|
||||
// Intended for unit tests that need to inject a deterministic scorer.
|
||||
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{cfg: cfg, classifier: c}
|
||||
}
|
||||
|
||||
// SelectModel returns the model to use for this conversation turn.
|
||||
//
|
||||
// - If score < cfg.Threshold: returns (cfg.LightModel, true)
|
||||
// - Otherwise: returns (primaryModel, false)
|
||||
//
|
||||
// The caller is responsible for resolving the returned model name into
|
||||
// provider candidates (see AgentInstance.LightCandidates).
|
||||
func (r *Router) SelectModel(msg string, history []providers.Message, primaryModel string) (model string, usedLight bool) {
|
||||
features := ExtractFeatures(msg, history)
|
||||
score := r.classifier.Score(features)
|
||||
if score < r.cfg.Threshold {
|
||||
return r.cfg.LightModel, true
|
||||
}
|
||||
return primaryModel, false
|
||||
}
|
||||
|
||||
// LightModel returns the configured light model name.
|
||||
func (r *Router) LightModel() string {
|
||||
return r.cfg.LightModel
|
||||
}
|
||||
|
||||
// Threshold returns the complexity threshold in use.
|
||||
func (r *Router) Threshold() float64 {
|
||||
return r.cfg.Threshold
|
||||
}
|
||||
@@ -0,0 +1,386 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||||
f := ExtractFeatures("", nil)
|
||||
if f.TokenEstimate != 0 {
|
||||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||||
}
|
||||
if f.CodeBlockCount != 0 {
|
||||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||||
}
|
||||
if f.RecentToolCalls != 0 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||||
}
|
||||
if f.ConversationDepth != 0 {
|
||||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||||
}
|
||||
if f.HasAttachments {
|
||||
t.Error("HasAttachments: got true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||||
// 30 ASCII chars / 3 = 10 tokens
|
||||
msg := strings.Repeat("a", 30)
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 10 {
|
||||
t.Errorf("TokenEstimate: got %d, want 10", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||||
// 9 CJK runes / 3 = 3 tokens
|
||||
msg := "你好世界你好世界你" // 9 runes
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 3 {
|
||||
t.Errorf("CJK TokenEstimate: got %d, want 3", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"no code here", 0},
|
||||
{"```go\nfmt.Println()\n```", 1},
|
||||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.CodeBlockCount != tc.want {
|
||||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||||
history := make([]providers.Message, 10)
|
||||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||||
history[9] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}}
|
||||
// Position 3 is outside the lookback window and must NOT be counted
|
||||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||||
|
||||
f := ExtractFeatures("test", history)
|
||||
// 1 (position 8) + 2 (position 9) = 3
|
||||
if f.RecentToolCalls != 3 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||||
history := make([]providers.Message, 7)
|
||||
f := ExtractFeatures("msg", history)
|
||||
if f.ConversationDepth != 7 {
|
||||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"plain text", false},
|
||||
{"here is an image: data:image/png;base64,abc123", true},
|
||||
{"audio: data:audio/mp3;base64,xyz", true},
|
||||
{"video: data:video/mp4;base64,xyz", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"check out photo.jpg", true},
|
||||
{"see screenshot.png", true},
|
||||
{"listen to audio.mp3", true},
|
||||
{"watch clip.mp4", true},
|
||||
{"just a .go file", false},
|
||||
{"document.pdf", false}, // pdf is not in the media list
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{})
|
||||
if score != 0.0 {
|
||||
t.Errorf("zero features: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{HasAttachments: true})
|
||||
if score != 1.0 {
|
||||
t.Errorf("attachments: got %f, want 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Code block alone = 0.40, above default threshold 0.35
|
||||
score := c.Score(Features{CodeBlockCount: 1})
|
||||
if score < 0.35 {
|
||||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||||
score := c.Score(Features{TokenEstimate: 250})
|
||||
if score < 0.35 {
|
||||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// 50-200 tokens = 0.15, below threshold → light
|
||||
score := c.Score(Features{TokenEstimate: 100})
|
||||
if score >= 0.35 {
|
||||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// <50 tokens, no other signals = 0.0 → light
|
||||
score := c.Score(Features{TokenEstimate: 10})
|
||||
if score != 0.0 {
|
||||
t.Errorf("short message: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
|
||||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||||
|
||||
if scoreNone != 0.0 {
|
||||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||||
}
|
||||
if scoreLow <= scoreNone {
|
||||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||||
}
|
||||
if scoreHigh <= scoreLow {
|
||||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
shallow := c.Score(Features{ConversationDepth: 5})
|
||||
deep := c.Score(Features{ConversationDepth: 15})
|
||||
if deep <= shallow {
|
||||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Max all signals simultaneously
|
||||
f := Features{
|
||||
TokenEstimate: 500,
|
||||
CodeBlockCount: 3,
|
||||
RecentToolCalls: 10,
|
||||
ConversationDepth: 20,
|
||||
}
|
||||
score := c.Score(f)
|
||||
if score > 1.0 {
|
||||
t.Errorf("score %f exceeds 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "hi"
|
||||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("simple message: expected light model to be selected")
|
||||
}
|
||||
if model != "gemini-flash" {
|
||||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("code block: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("attachment: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
// >200 token estimate: 210 * 3 = 630 chars
|
||||
msg := strings.Repeat("word ", 210)
|
||||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("long message: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||||
}
|
||||
msg := "ok"
|
||||
_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||||
}},
|
||||
}
|
||||
// ~55 tokens * 3 = 165 chars
|
||||
msg := strings.Repeat("word ", 55)
|
||||
_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||||
// Very low threshold: even a short message triggers heavy model
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||||
_, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("low threshold: medium message should use primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||||
// Very high threshold: even code blocks route to light
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||||
msg := "```go\nfmt.Println()\n```"
|
||||
_, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LightModel(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||||
if r.LightModel() != "my-fast-model" {
|
||||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||||
}
|
||||
}
|
||||
|
||||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||||
|
||||
type fixedScoreClassifier struct{ score float64 }
|
||||
|
||||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||||
|
||||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.2},
|
||||
)
|
||||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||||
if !usedLight {
|
||||
t.Error("low score with custom classifier: expected light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.8},
|
||||
)
|
||||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("high score with custom classifier: expected primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||||
// score == threshold → primary (uses >= comparison)
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.5},
|
||||
)
|
||||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user