mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
09e68cb63b
- classifier.go: s/honour/honor/ (American English per misspell) - router.go: break SelectModel signature across lines (golines) - router_test.go: break long Message literal (golines) - router_test.go: replace CJK string literal with rune slice so gosmopolitan does not flag the source file; behaviour is identical
395 lines
13 KiB
Go
395 lines
13 KiB
Go
package routing
|
||
|
||
import (
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/sipeed/picoclaw/pkg/providers"
|
||
)
|
||
|
||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||
|
||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||
f := ExtractFeatures("", nil)
|
||
if f.TokenEstimate != 0 {
|
||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||
}
|
||
if f.CodeBlockCount != 0 {
|
||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||
}
|
||
if f.RecentToolCalls != 0 {
|
||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||
}
|
||
if f.ConversationDepth != 0 {
|
||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||
}
|
||
if f.HasAttachments {
|
||
t.Error("HasAttachments: got true, want false")
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||
// 30 ASCII chars / 3 = 10 tokens
|
||
msg := strings.Repeat("a", 30)
|
||
f := ExtractFeatures(msg, nil)
|
||
if f.TokenEstimate != 10 {
|
||
t.Errorf("TokenEstimate: got %d, want 10", f.TokenEstimate)
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||
// 9 CJK runes (U+4F60 U+597D U+4E16 U+754C × 2 + U+4F60) / 3 = 3 tokens.
|
||
// Using a rune slice literal avoids CJK string literals in source.
|
||
msg := string([]rune{
|
||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||
0x4F60,
|
||
})
|
||
f := ExtractFeatures(msg, nil)
|
||
if f.TokenEstimate != 3 {
|
||
t.Errorf("CJK TokenEstimate: got %d, want 3", f.TokenEstimate)
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||
cases := []struct {
|
||
msg string
|
||
want int
|
||
}{
|
||
{"no code here", 0},
|
||
{"```go\nfmt.Println()\n```", 1},
|
||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||
}
|
||
for _, tc := range cases {
|
||
f := ExtractFeatures(tc.msg, nil)
|
||
if f.CodeBlockCount != tc.want {
|
||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||
history := make([]providers.Message, 10)
|
||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||
history[9] = providers.Message{
|
||
Role: "assistant",
|
||
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
||
}
|
||
// Position 3 is outside the lookback window and must NOT be counted
|
||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||
|
||
f := ExtractFeatures("test", history)
|
||
// 1 (position 8) + 2 (position 9) = 3
|
||
if f.RecentToolCalls != 3 {
|
||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||
history := make([]providers.Message, 7)
|
||
f := ExtractFeatures("msg", history)
|
||
if f.ConversationDepth != 7 {
|
||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||
cases := []struct {
|
||
msg string
|
||
want bool
|
||
}{
|
||
{"plain text", false},
|
||
{"here is an image: data:image/png;base64,abc123", true},
|
||
{"audio: data:audio/mp3;base64,xyz", true},
|
||
{"video: data:video/mp4;base64,xyz", true},
|
||
}
|
||
for _, tc := range cases {
|
||
f := ExtractFeatures(tc.msg, nil)
|
||
if f.HasAttachments != tc.want {
|
||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||
cases := []struct {
|
||
msg string
|
||
want bool
|
||
}{
|
||
{"check out photo.jpg", true},
|
||
{"see screenshot.png", true},
|
||
{"listen to audio.mp3", true},
|
||
{"watch clip.mp4", true},
|
||
{"just a .go file", false},
|
||
{"document.pdf", false}, // pdf is not in the media list
|
||
}
|
||
for _, tc := range cases {
|
||
f := ExtractFeatures(tc.msg, nil)
|
||
if f.HasAttachments != tc.want {
|
||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||
|
||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
score := c.Score(Features{})
|
||
if score != 0.0 {
|
||
t.Errorf("zero features: got %f, want 0.0", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
score := c.Score(Features{HasAttachments: true})
|
||
if score != 1.0 {
|
||
t.Errorf("attachments: got %f, want 1.0", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
// Code block alone = 0.40, above default threshold 0.35
|
||
score := c.Score(Features{CodeBlockCount: 1})
|
||
if score < 0.35 {
|
||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||
score := c.Score(Features{TokenEstimate: 250})
|
||
if score < 0.35 {
|
||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
// 50-200 tokens = 0.15, below threshold → light
|
||
score := c.Score(Features{TokenEstimate: 100})
|
||
if score >= 0.35 {
|
||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
// <50 tokens, no other signals = 0.0 → light
|
||
score := c.Score(Features{TokenEstimate: 10})
|
||
if score != 0.0 {
|
||
t.Errorf("short message: got %f, want 0.0", score)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
|
||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||
|
||
if scoreNone != 0.0 {
|
||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||
}
|
||
if scoreLow <= scoreNone {
|
||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||
}
|
||
if scoreHigh <= scoreLow {
|
||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
shallow := c.Score(Features{ConversationDepth: 5})
|
||
deep := c.Score(Features{ConversationDepth: 15})
|
||
if deep <= shallow {
|
||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||
}
|
||
}
|
||
|
||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||
c := &RuleClassifier{}
|
||
// Max all signals simultaneously
|
||
f := Features{
|
||
TokenEstimate: 500,
|
||
CodeBlockCount: 3,
|
||
RecentToolCalls: 10,
|
||
ConversationDepth: 20,
|
||
}
|
||
score := c.Score(f)
|
||
if score > 1.0 {
|
||
t.Errorf("score %f exceeds 1.0", score)
|
||
}
|
||
}
|
||
|
||
// ── Router ───────────────────────────────────────────────────────────────────
|
||
|
||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||
if r.Threshold() != defaultThreshold {
|
||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||
}
|
||
}
|
||
|
||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||
if r.Threshold() != defaultThreshold {
|
||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
msg := "hi"
|
||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if !usedLight {
|
||
t.Error("simple message: expected light model to be selected")
|
||
}
|
||
if model != "gemini-flash" {
|
||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if usedLight {
|
||
t.Error("code block: expected primary model to be selected")
|
||
}
|
||
if model != "claude-sonnet-4-6" {
|
||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if usedLight {
|
||
t.Error("attachment: expected primary model to be selected")
|
||
}
|
||
if model != "claude-sonnet-4-6" {
|
||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
// >200 token estimate: 210 * 3 = 630 chars
|
||
msg := strings.Repeat("word ", 210)
|
||
model, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if usedLight {
|
||
t.Error("long message: expected primary model to be selected")
|
||
}
|
||
if model != "claude-sonnet-4-6" {
|
||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
history := []providers.Message{
|
||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||
}
|
||
msg := "ok"
|
||
_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||
if !usedLight {
|
||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||
history := []providers.Message{
|
||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||
}},
|
||
}
|
||
// ~55 tokens * 3 = 165 chars
|
||
msg := strings.Repeat("word ", 55)
|
||
_, usedLight := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||
if usedLight {
|
||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||
// Very low threshold: even a short message triggers heavy model
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||
_, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if usedLight {
|
||
t.Error("low threshold: medium message should use primary model")
|
||
}
|
||
}
|
||
|
||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||
// Very high threshold: even code blocks route to light
|
||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||
msg := "```go\nfmt.Println()\n```"
|
||
_, usedLight := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||
if !usedLight {
|
||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||
}
|
||
}
|
||
|
||
func TestRouter_LightModel(t *testing.T) {
|
||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||
if r.LightModel() != "my-fast-model" {
|
||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||
}
|
||
}
|
||
|
||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||
|
||
type fixedScoreClassifier struct{ score float64 }
|
||
|
||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||
|
||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||
r := newWithClassifier(
|
||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||
&fixedScoreClassifier{score: 0.2},
|
||
)
|
||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||
if !usedLight {
|
||
t.Error("low score with custom classifier: expected light model")
|
||
}
|
||
}
|
||
|
||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||
r := newWithClassifier(
|
||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||
&fixedScoreClassifier{score: 0.8},
|
||
)
|
||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||
if usedLight {
|
||
t.Error("high score with custom classifier: expected primary model")
|
||
}
|
||
}
|
||
|
||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||
// score == threshold → primary (uses >= comparison)
|
||
r := newWithClassifier(
|
||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||
&fixedScoreClassifier{score: 0.5},
|
||
)
|
||
_, usedLight := r.SelectModel("anything", nil, "heavy")
|
||
if usedLight {
|
||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||
}
|
||
}
|