mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
b84adacc2f
1. CJK token estimation: replace flat rune_count/3 with script-aware counting — CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) count as 1 token each, non-CJK runes at /4. This fixes a 3x underestimate for Chinese/Japanese/Korean text that could incorrectly route complex CJK messages to the light model. 2. Routing observability: SelectModel now returns the computed score as a third value. selectCandidates logs the score on both paths — Info level for light model selection, Debug level for primary model selection. 3. Added tests: TestExtractFeatures_TokenEstimate_Mixed (CJK+ASCII mix), TestRouter_SelectModel_ReturnsScore. Addresses review feedback from @mingmxren.
415 lines
14 KiB
Go
package routing
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
)
|
|
|
|
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
|
|
|
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
|
f := ExtractFeatures("", nil)
|
|
if f.TokenEstimate != 0 {
|
|
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
|
}
|
|
if f.CodeBlockCount != 0 {
|
|
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
|
}
|
|
if f.RecentToolCalls != 0 {
|
|
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
|
}
|
|
if f.ConversationDepth != 0 {
|
|
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
|
}
|
|
if f.HasAttachments {
|
|
t.Error("HasAttachments: got true, want false")
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
|
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
|
|
msg := strings.Repeat("a", 30)
|
|
f := ExtractFeatures(msg, nil)
|
|
if f.TokenEstimate != 7 {
|
|
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
|
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
|
|
// Using a rune slice literal avoids CJK string literals in source.
|
|
msg := string([]rune{
|
|
0x4F60, 0x597D, 0x4E16, 0x754C,
|
|
0x4F60, 0x597D, 0x4E16, 0x754C,
|
|
0x4F60,
|
|
})
|
|
f := ExtractFeatures(msg, nil)
|
|
if f.TokenEstimate != 9 {
|
|
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
|
|
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
|
|
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
|
|
f := ExtractFeatures(msg, nil)
|
|
if f.TokenEstimate != 6 {
|
|
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
|
cases := []struct {
|
|
msg string
|
|
want int
|
|
}{
|
|
{"no code here", 0},
|
|
{"```go\nfmt.Println()\n```", 1},
|
|
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
|
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
|
}
|
|
for _, tc := range cases {
|
|
f := ExtractFeatures(tc.msg, nil)
|
|
if f.CodeBlockCount != tc.want {
|
|
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
|
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
|
history := make([]providers.Message, 10)
|
|
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
|
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
|
history[9] = providers.Message{
|
|
Role: "assistant",
|
|
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
|
}
|
|
// Position 3 is outside the lookback window and must NOT be counted
|
|
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
|
|
|
f := ExtractFeatures("test", history)
|
|
// 1 (position 8) + 2 (position 9) = 3
|
|
if f.RecentToolCalls != 3 {
|
|
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
|
history := make([]providers.Message, 7)
|
|
f := ExtractFeatures("msg", history)
|
|
if f.ConversationDepth != 7 {
|
|
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
|
cases := []struct {
|
|
msg string
|
|
want bool
|
|
}{
|
|
{"plain text", false},
|
|
{"here is an image: data:image/png;base64,abc123", true},
|
|
{"audio: data:audio/mp3;base64,xyz", true},
|
|
{"video: data:video/mp4;base64,xyz", true},
|
|
}
|
|
for _, tc := range cases {
|
|
f := ExtractFeatures(tc.msg, nil)
|
|
if f.HasAttachments != tc.want {
|
|
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
|
cases := []struct {
|
|
msg string
|
|
want bool
|
|
}{
|
|
{"check out photo.jpg", true},
|
|
{"see screenshot.png", true},
|
|
{"listen to audio.mp3", true},
|
|
{"watch clip.mp4", true},
|
|
{"just a .go file", false},
|
|
{"document.pdf", false}, // pdf is not in the media list
|
|
}
|
|
for _, tc := range cases {
|
|
f := ExtractFeatures(tc.msg, nil)
|
|
if f.HasAttachments != tc.want {
|
|
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
|
|
|
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
score := c.Score(Features{})
|
|
if score != 0.0 {
|
|
t.Errorf("zero features: got %f, want 0.0", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
score := c.Score(Features{HasAttachments: true})
|
|
if score != 1.0 {
|
|
t.Errorf("attachments: got %f, want 1.0", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
// Code block alone = 0.40, above default threshold 0.35
|
|
score := c.Score(Features{CodeBlockCount: 1})
|
|
if score < 0.35 {
|
|
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_LongMessage(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
// >200 tokens = 0.35, exactly at default threshold → heavy
|
|
score := c.Score(Features{TokenEstimate: 250})
|
|
if score < 0.35 {
|
|
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
// 50-200 tokens = 0.15, below threshold → light
|
|
score := c.Score(Features{TokenEstimate: 100})
|
|
if score >= 0.35 {
|
|
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
// <50 tokens, no other signals = 0.0 → light
|
|
score := c.Score(Features{TokenEstimate: 10})
|
|
if score != 0.0 {
|
|
t.Errorf("short message: got %f, want 0.0", score)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
|
|
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
|
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
|
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
|
|
|
if scoreNone != 0.0 {
|
|
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
|
}
|
|
if scoreLow <= scoreNone {
|
|
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
|
}
|
|
if scoreHigh <= scoreLow {
|
|
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
shallow := c.Score(Features{ConversationDepth: 5})
|
|
deep := c.Score(Features{ConversationDepth: 15})
|
|
if deep <= shallow {
|
|
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
|
}
|
|
}
|
|
|
|
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
|
c := &RuleClassifier{}
|
|
// Max all signals simultaneously
|
|
f := Features{
|
|
TokenEstimate: 500,
|
|
CodeBlockCount: 3,
|
|
RecentToolCalls: 10,
|
|
ConversationDepth: 20,
|
|
}
|
|
score := c.Score(f)
|
|
if score > 1.0 {
|
|
t.Errorf("score %f exceeds 1.0", score)
|
|
}
|
|
}
|
|
|
|
// ── Router ───────────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_DefaultThreshold(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash"})
|
|
if r.Threshold() != defaultThreshold {
|
|
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
|
}
|
|
}
|
|
|
|
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
|
if r.Threshold() != defaultThreshold {
|
|
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
msg := "hi"
|
|
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if !usedLight {
|
|
t.Error("simple message: expected light model to be selected")
|
|
}
|
|
if model != "gemini-flash" {
|
|
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
msg := "```go\nfmt.Println(\"hello\")\n```"
|
|
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if usedLight {
|
|
t.Error("code block: expected primary model to be selected")
|
|
}
|
|
if model != "claude-sonnet-4-6" {
|
|
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
msg := "can you analyze this? data:image/png;base64,abc123"
|
|
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if usedLight {
|
|
t.Error("attachment: expected primary model to be selected")
|
|
}
|
|
if model != "claude-sonnet-4-6" {
|
|
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
// >200 token estimate: 210 * 3 = 630 chars
|
|
msg := strings.Repeat("word ", 210)
|
|
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if usedLight {
|
|
t.Error("long message: expected primary model to be selected")
|
|
}
|
|
if model != "claude-sonnet-4-6" {
|
|
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
|
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
|
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
history := []providers.Message{
|
|
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
|
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
|
}
|
|
msg := "ok"
|
|
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
|
if !usedLight {
|
|
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
|
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
|
history := []providers.Message{
|
|
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
|
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
|
}},
|
|
}
|
|
// ~55 tokens * 3 = 165 chars
|
|
msg := strings.Repeat("word ", 55)
|
|
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
|
if usedLight {
|
|
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
|
// Very low threshold: even a short message triggers heavy model
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
|
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
|
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if usedLight {
|
|
t.Error("low threshold: medium message should use primary model")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
|
// Very high threshold: even code blocks route to light
|
|
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
|
msg := "```go\nfmt.Println()\n```"
|
|
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
|
if !usedLight {
|
|
t.Error("very high threshold: code block (0.40) should route to light model")
|
|
}
|
|
}
|
|
|
|
func TestRouter_LightModel(t *testing.T) {
|
|
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
|
if r.LightModel() != "my-fast-model" {
|
|
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
|
}
|
|
}
|
|
|
|
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
|
|
|
type fixedScoreClassifier struct{ score float64 }
|
|
|
|
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
|
|
|
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
|
r := newWithClassifier(
|
|
RouterConfig{LightModel: "light", Threshold: 0.5},
|
|
&fixedScoreClassifier{score: 0.2},
|
|
)
|
|
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
|
if !usedLight {
|
|
t.Error("low score with custom classifier: expected light model")
|
|
}
|
|
}
|
|
|
|
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
|
r := newWithClassifier(
|
|
RouterConfig{LightModel: "light", Threshold: 0.5},
|
|
&fixedScoreClassifier{score: 0.8},
|
|
)
|
|
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
|
if usedLight {
|
|
t.Error("high score with custom classifier: expected primary model")
|
|
}
|
|
}
|
|
|
|
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
|
// score == threshold → primary (uses >= comparison)
|
|
r := newWithClassifier(
|
|
RouterConfig{LightModel: "light", Threshold: 0.5},
|
|
&fixedScoreClassifier{score: 0.5},
|
|
)
|
|
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
|
if usedLight {
|
|
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
|
}
|
|
}
|
|
|
|
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
|
|
r := newWithClassifier(
|
|
RouterConfig{LightModel: "light", Threshold: 0.5},
|
|
&fixedScoreClassifier{score: 0.42},
|
|
)
|
|
_, _, score := r.SelectModel("anything", nil, "heavy")
|
|
if score != 0.42 {
|
|
t.Errorf("score: got %f, want 0.42", score)
|
|
}
|
|
}
|