Files
picoclaw/pkg/routing/router_test.go
T
xiaoen 09e68cb63b fix(routing): resolve golines, gosmopolitan and misspell lint failures
- classifier.go: s/honour/honor/ (American English per misspell)
- router.go: break SelectModel signature across lines (golines)
- router_test.go: break long Message literal (golines)
- router_test.go: replace CJK string literal with rune slice so
  gosmopolitan does not flag the source file; behaviour is identical
2026-03-02 23:11:45 +08:00

395 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package routing
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ── ExtractFeatures ──────────────────────────────────────────────────────────
// TestExtractFeatures_EmptyMessage checks that an empty message with no
// history yields all-zero numeric features and no attachment flag.
func TestExtractFeatures_EmptyMessage(t *testing.T) {
	feat := ExtractFeatures("", nil)

	if got := feat.TokenEstimate; got != 0 {
		t.Errorf("TokenEstimate: got %d, want 0", got)
	}
	if got := feat.CodeBlockCount; got != 0 {
		t.Errorf("CodeBlockCount: got %d, want 0", got)
	}
	if got := feat.RecentToolCalls; got != 0 {
		t.Errorf("RecentToolCalls: got %d, want 0", got)
	}
	if got := feat.ConversationDepth; got != 0 {
		t.Errorf("ConversationDepth: got %d, want 0", got)
	}
	if feat.HasAttachments {
		t.Error("HasAttachments: got true, want false")
	}
}
// TestExtractFeatures_TokenEstimate checks the ASCII token heuristic of
// roughly one token per three characters.
func TestExtractFeatures_TokenEstimate(t *testing.T) {
	const length = 30 // 30 ASCII chars / 3 = 10 tokens

	feat := ExtractFeatures(strings.Repeat("a", length), nil)
	if got := feat.TokenEstimate; got != length/3 {
		t.Errorf("TokenEstimate: got %d, want 10", got)
	}
}
// TestExtractFeatures_TokenEstimate_CJK checks the token heuristic on
// non-ASCII input: 9 CJK runes / 3 = 3 tokens.
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
	// The four code points U+4F60 U+597D U+4E16 U+754C are cycled out to
	// 9 runes. Building the message from code points keeps CJK string
	// literals out of the source file.
	phrase := []rune{0x4F60, 0x597D, 0x4E16, 0x754C}
	runes := make([]rune, 0, 9)
	for i := 0; i < 9; i++ {
		runes = append(runes, phrase[i%len(phrase)])
	}

	feat := ExtractFeatures(string(runes), nil)
	if got := feat.TokenEstimate; got != 3 {
		t.Errorf("CJK TokenEstimate: got %d, want 3", got)
	}
}
// TestExtractFeatures_CodeBlocks checks that only complete fenced code blocks
// are counted; an unmatched opening fence contributes nothing.
func TestExtractFeatures_CodeBlocks(t *testing.T) {
	tests := []struct {
		input    string
		expected int
	}{
		{input: "no code here", expected: 0},
		{input: "```go\nfmt.Println()\n```", expected: 1},
		{input: "```python\npass\n```\n```js\nconsole.log()\n```", expected: 2},
		{input: "```unclosed", expected: 0}, // odd number of fences = 0 complete blocks
	}

	for _, tt := range tests {
		got := ExtractFeatures(tt.input, nil).CodeBlockCount
		if got != tt.expected {
			t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tt.input, got, tt.expected)
		}
	}
}
// TestExtractFeatures_RecentToolCalls verifies that only tool calls inside
// the trailing lookback window of the history are counted.
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
	// History longer than lookbackWindow: only the last lookbackWindow entries
	// count. This test assumes a window of 6, so with 10 entries the counted
	// positions are 4-9.
	history := make([]providers.Message, 10)

	// Position 3 is just outside the lookback window and must NOT be counted.
	history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}

	// Two messages inside the window carrying 3 tool calls in total:
	// one at position 8 and two at position 9.
	history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
	history[9] = providers.Message{
		Role:      "assistant",
		ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
	}

	f := ExtractFeatures("test", history)
	// 1 (position 8) + 2 (position 9) = 3
	if f.RecentToolCalls != 3 {
		t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
	}
}
func TestExtractFeatures_ConversationDepth(t *testing.T) {
history := make([]providers.Message, 7)
f := ExtractFeatures("msg", history)
if f.ConversationDepth != 7 {
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
}
}
// TestExtractFeatures_HasAttachments_DataURI checks that inline image, audio
// and video data URIs are detected as attachments, while plain text is not.
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
	tests := []struct {
		input    string
		expected bool
	}{
		{input: "plain text", expected: false},
		{input: "here is an image: data:image/png;base64,abc123", expected: true},
		{input: "audio: data:audio/mp3;base64,xyz", expected: true},
		{input: "video: data:video/mp4;base64,xyz", expected: true},
	}

	for _, tt := range tests {
		if got := ExtractFeatures(tt.input, nil).HasAttachments; got != tt.expected {
			t.Errorf("msg=%q: HasAttachments got %v, want %v", tt.input, got, tt.expected)
		}
	}
}
// TestExtractFeatures_HasAttachments_Extension checks attachment detection
// by media file extension. Non-media extensions (.go, .pdf) must not count.
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
	tests := []struct {
		input    string
		expected bool
	}{
		{input: "check out photo.jpg", expected: true},
		{input: "see screenshot.png", expected: true},
		{input: "listen to audio.mp3", expected: true},
		{input: "watch clip.mp4", expected: true},
		{input: "just a .go file", expected: false},
		{input: "document.pdf", expected: false}, // pdf is not in the media list
	}

	for _, tt := range tests {
		if got := ExtractFeatures(tt.input, nil).HasAttachments; got != tt.expected {
			t.Errorf("msg=%q: HasAttachments got %v, want %v", tt.input, got, tt.expected)
		}
	}
}
// ── RuleClassifier ───────────────────────────────────────────────────────────
// TestRuleClassifier_ZeroFeatures checks that features with no signals
// score exactly zero.
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
	classifier := &RuleClassifier{}
	if got := classifier.Score(Features{}); got != 0.0 {
		t.Errorf("zero features: got %f, want 0.0", got)
	}
}
// TestRuleClassifier_AttachmentsHardGate checks that an attachment alone
// forces the maximum score of 1.0.
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
	classifier := &RuleClassifier{}
	if got := classifier.Score(Features{HasAttachments: true}); got != 1.0 {
		t.Errorf("attachments: got %f, want 1.0", got)
	}
}
// TestRuleClassifier_CodeBlockAlone checks that a single code block is
// enough to reach the default routing threshold.
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
	const threshold = 0.35 // default routing threshold

	classifier := &RuleClassifier{}
	// A code block alone contributes 0.40.
	if got := classifier.Score(Features{CodeBlockCount: 1}); got < threshold {
		t.Errorf("code block: score %f is below default threshold 0.35", got)
	}
}
// TestRuleClassifier_LongMessage checks that a message over 200 tokens
// reaches the default routing threshold.
func TestRuleClassifier_LongMessage(t *testing.T) {
	const threshold = 0.35 // default routing threshold

	classifier := &RuleClassifier{}
	// >200 tokens contributes 0.35, exactly the default threshold, so heavy.
	if got := classifier.Score(Features{TokenEstimate: 250}); got < threshold {
		t.Errorf("long message: score %f is below default threshold 0.35", got)
	}
}
// TestRuleClassifier_MediumMessage checks that a medium-length message on
// its own stays below the default routing threshold.
func TestRuleClassifier_MediumMessage(t *testing.T) {
	const threshold = 0.35 // default routing threshold

	classifier := &RuleClassifier{}
	// 50-200 tokens contributes 0.15, below the threshold, so light.
	if got := classifier.Score(Features{TokenEstimate: 100}); got >= threshold {
		t.Errorf("medium message: score %f should be below default threshold 0.35", got)
	}
}
// TestRuleClassifier_ShortMessage checks that a short message with no other
// signals scores exactly zero.
func TestRuleClassifier_ShortMessage(t *testing.T) {
	classifier := &RuleClassifier{}
	// <50 tokens with nothing else set contributes 0.0, so light.
	if got := classifier.Score(Features{TokenEstimate: 10}); got != 0.0 {
		t.Errorf("short message: got %f, want 0.0", got)
	}
}
// TestRuleClassifier_ToolCallDensity checks that the score is zero without
// tool calls and strictly increases as recent tool calls go 0 -> 2 -> 5.
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
	classifier := &RuleClassifier{}

	none := classifier.Score(Features{RecentToolCalls: 0})
	low := classifier.Score(Features{RecentToolCalls: 2})
	high := classifier.Score(Features{RecentToolCalls: 5})

	// Each check reports independently so all violations surface at once.
	if none != 0.0 {
		t.Errorf("no tools: got %f, want 0.0", none)
	}
	if !(low > none) {
		t.Errorf("low tools should score higher than none: %f vs %f", low, none)
	}
	if !(high > low) {
		t.Errorf("high tools should score higher than low: %f vs %f", high, low)
	}
}
// TestRuleClassifier_DeepConversation checks that a deeper conversation
// scores strictly higher than a shallow one.
func TestRuleClassifier_DeepConversation(t *testing.T) {
	classifier := &RuleClassifier{}

	shallowScore := classifier.Score(Features{ConversationDepth: 5})
	deepScore := classifier.Score(Features{ConversationDepth: 15})
	if !(deepScore > shallowScore) {
		t.Errorf("deep conversation should score higher: %f vs %f", deepScore, shallowScore)
	}
}
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
c := &RuleClassifier{}
// Max all signals simultaneously
f := Features{
TokenEstimate: 500,
CodeBlockCount: 3,
RecentToolCalls: 10,
ConversationDepth: 20,
}
score := c.Score(f)
if score > 1.0 {
t.Errorf("score %f exceeds 1.0", score)
}
}
// ── Router ───────────────────────────────────────────────────────────────────
// TestRouter_DefaultThreshold checks that a config with no Threshold set
// yields a router using defaultThreshold.
func TestRouter_DefaultThreshold(t *testing.T) {
	router := New(RouterConfig{LightModel: "gemini-flash"})
	if got := router.Threshold(); got != defaultThreshold {
		t.Errorf("default threshold: got %f, want %f", got, defaultThreshold)
	}
}
// TestRouter_NegativeThresholdFallsBackToDefault checks that an invalid
// negative threshold is replaced by defaultThreshold.
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
	router := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
	if got := router.Threshold(); got != defaultThreshold {
		t.Errorf("negative threshold: got %f, want %f", got, defaultThreshold)
	}
}
// TestRouter_SelectModel_SimpleMessageUsesLight checks that a trivial
// greeting is routed to the light model.
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
	const (
		lightName   = "gemini-flash"
		primaryName = "claude-sonnet-4-6"
	)
	router := New(RouterConfig{LightModel: lightName, Threshold: 0.35})

	model, usedLight := router.SelectModel("hi", nil, primaryName)
	if !usedLight {
		t.Error("simple message: expected light model to be selected")
	}
	if model != lightName {
		t.Errorf("simple message: model got %q, want %q", model, lightName)
	}
}
// TestRouter_SelectModel_CodeBlockUsesPrimary checks that a message with a
// fenced code block is routed to the primary model.
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
	const primaryName = "claude-sonnet-4-6"
	router := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})

	model, usedLight := router.SelectModel("```go\nfmt.Println(\"hello\")\n```", nil, primaryName)
	if usedLight {
		t.Error("code block: expected primary model to be selected")
	}
	if model != primaryName {
		t.Errorf("code block: model got %q, want %q", model, primaryName)
	}
}
// TestRouter_SelectModel_AttachmentUsesPrimary checks that a message with an
// inline image data URI is routed to the primary model.
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
	const primaryName = "claude-sonnet-4-6"
	router := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})

	model, usedLight := router.SelectModel("can you analyze this? data:image/png;base64,abc123", nil, primaryName)
	if usedLight {
		t.Error("attachment: expected primary model to be selected")
	}
	if model != primaryName {
		t.Errorf("attachment: model got %q, want %q", model, primaryName)
	}
}
// TestRouter_SelectModel_LongMessageUsesPrimary verifies that a message whose
// token estimate is well above 200 is routed to the primary model.
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
	// "word " is 5 chars, so 210 repetitions = 1050 chars, about 350 tokens
	// at the chars/3 estimate. That is comfortably inside the >200-token tier.
	msg := strings.Repeat("word ", 210)
	model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
	if usedLight {
		t.Error("long message: expected primary model to be selected")
	}
	if model != "claude-sonnet-4-6" {
		t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
	}
}
// TestRouter_SelectModel_DeepToolChainUsesLight verifies that recent tool
// calls combined with a very short message do not, by themselves, push the
// score over the 0.35 threshold.
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
	// Tool calls alone (0.25) don't cross the 0.35 threshold; this is acceptable.
	// Routing is conservative: only promote to heavy when the signal is unambiguous.
	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
	// Four tool calls spread over two assistant turns.
	history := []providers.Message{
		{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
		{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
	}
	msg := "ok"
	_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
	if !usedLight {
		// The message states only the expected bound. The old text ("score 0.20")
		// contradicted the 0.25 tool-call weight stated above.
		t.Error("short message + moderate tool calls: expected light model (score < 0.35)")
	}
}
// TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy verifies that tool-call
// activity combined with a medium-length message crosses the threshold.
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
	// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35, so heavy.
	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
	history := []providers.Message{
		{Role: "assistant", ToolCalls: []providers.ToolCall{
			{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
		}},
	}
	// "word " is 5 chars, so 55 repetitions = 275 chars, about 91 tokens at the
	// chars/3 estimate. That falls inside the 50-200 token medium tier.
	msg := strings.Repeat("word ", 55)
	_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
	if usedLight {
		t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
	}
}
// TestRouter_SelectModel_CustomThreshold verifies that a caller-supplied low
// threshold is honored: a message that would normally stay light is promoted.
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
	// Very low threshold: even a medium-length message (score 0.15) triggers
	// the primary model.
	r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
	// 55 x "word " = 275 chars, about 91 tokens: medium tier, 0.15 >= 0.05.
	msg := strings.Repeat("word ", 55)
	_, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
	if usedLight {
		t.Error("low threshold: medium message should use primary model")
	}
}
// TestRouter_SelectModel_HighThreshold checks that a very high threshold
// keeps even code-block messages on the light model.
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
	router := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})

	// A code block scores 0.40, far below 0.99.
	if _, usedLight := router.SelectModel("```go\nfmt.Println()\n```", nil, "claude-sonnet-4-6"); !usedLight {
		t.Error("very high threshold: code block (0.40) should route to light model")
	}
}
// TestRouter_LightModel checks that LightModel reports the configured name.
func TestRouter_LightModel(t *testing.T) {
	const want = "my-fast-model"
	router := New(RouterConfig{LightModel: want, Threshold: 0.35})
	if got := router.LightModel(); got != want {
		t.Errorf("LightModel: got %q, want %q", got, want)
	}
}
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
// fixedScoreClassifier is a test stub that ignores its input and always
// reports the same score, so tests can drive the router's threshold
// comparison directly.
type fixedScoreClassifier struct {
	score float64
}

// Score returns the preconfigured score regardless of the features given.
func (c *fixedScoreClassifier) Score(Features) float64 {
	return c.score
}
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.2},
)
_, usedLight := r.SelectModel("anything", nil, "heavy")
if !usedLight {
t.Error("low score with custom classifier: expected light model")
}
}
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.8},
)
_, usedLight := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("high score with custom classifier: expected primary model")
}
}
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
// score == threshold → primary (uses >= comparison)
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.5},
)
_, usedLight := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("score == threshold: expected primary model (>= threshold → primary)")
}
}