mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #994 from is-Xiaoen/feat/model-routing
feat(routing): intelligent model routing based on structural complexity scoring
This commit is contained in:
@@ -37,6 +37,14 @@ type AgentInstance struct {
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
|
||||
// Router is non-nil when model routing is configured and the light model
|
||||
// was successfully resolved. It scores each incoming message and decides
|
||||
// whether to route to LightCandidates or stay with Candidates.
|
||||
Router *routing.Router
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -180,6 +188,25 @@ func NewAgentInstance(
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
}
|
||||
}
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
@@ -200,6 +227,8 @@ func NewAgentInstance(
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+49
-5
@@ -824,6 +824,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
iteration++
|
||||
|
||||
@@ -842,7 +848,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": agent.Model,
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": agent.MaxTokens,
|
||||
@@ -858,7 +864,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM with fallback chain if candidates are configured.
|
||||
// Call LLM with fallback chain if multiple candidates are configured.
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
@@ -879,10 +885,10 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
|
||||
},
|
||||
@@ -900,7 +906,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
@@ -1169,6 +1175,44 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
// selectCandidates returns the model candidates and resolved model name to use
|
||||
// for a conversation turn. When model routing is configured and the incoming
|
||||
// message scores below the complexity threshold, it returns the light model
|
||||
// candidates instead of the primary ones.
|
||||
//
|
||||
// The returned (candidates, model) pair is used for all LLM calls within one
|
||||
// turn — tool follow-up iterations use the same tier as the initial call so
|
||||
// that a multi-step tool chain doesn't switch models mid-way.
|
||||
func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
if !usedLight {
|
||||
logger.DebugCF("agent", "Model routing: primary model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"light_model": agent.Router.LightModel(),
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
|
||||
+28
-15
@@ -167,22 +167,35 @@ type SessionConfig struct {
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
// RoutingConfig controls the intelligent model routing feature.
|
||||
// When enabled, each incoming message is scored against structural features
|
||||
// (message length, code blocks, tool call history, conversation depth, attachments).
|
||||
// Messages scoring below Threshold are sent to LightModel; all others use the
|
||||
// agent's primary model. This reduces cost and latency for simple tasks without
|
||||
// requiring any keyword matching — all scoring is language-agnostic.
|
||||
type RoutingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
|
||||
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package routing
|
||||
|
||||
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
|
||||
// A higher score indicates a more complex task that benefits from a heavy model.
|
||||
// The score is compared against the configured threshold: score >= threshold selects
|
||||
// the primary (heavy) model; score < threshold selects the light model.
|
||||
//
|
||||
// Classifier is an interface so that future implementations (ML-based, embedding-based,
|
||||
// or any other approach) can be swapped in without changing routing infrastructure.
|
||||
type Classifier interface {
|
||||
Score(f Features) float64
|
||||
}
|
||||
|
||||
// RuleClassifier is the v1 implementation.
|
||||
// It uses a weighted sum of structural signals with no external dependencies,
|
||||
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
|
||||
// that the returned score always falls within the [0, 1] contract.
|
||||
//
|
||||
// Individual weights (multiple signals can fire simultaneously):
|
||||
//
|
||||
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
|
||||
// token 50-200: 0.15 — medium length; may or may not be complex
|
||||
// code block present: 0.40 — coding tasks need the heavy model
|
||||
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
|
||||
// tool calls 1-3 (recent): 0.10 — some tool activity
|
||||
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
|
||||
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
|
||||
//
|
||||
// Default threshold is 0.35, so:
|
||||
// - Pure greetings / trivial Q&A: 0.00 → light ✓
|
||||
// - Medium prose message (50–200 tokens): 0.15 → light ✓
|
||||
// - Message with code block: 0.40 → heavy ✓
|
||||
// - Long message (>200 tokens): 0.35 → heavy ✓
|
||||
// - Active tool session + medium message: 0.25 → light (acceptable)
|
||||
// - Any message with an image/audio attachment: 1.00 → heavy ✓
|
||||
type RuleClassifier struct{}
|
||||
|
||||
// Score computes the complexity score for the given feature set.
|
||||
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
|
||||
func (c *RuleClassifier) Score(f Features) float64 {
|
||||
// Hard gate: multi-modal inputs always require the heavy model.
|
||||
if f.HasAttachments {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
var score float64
|
||||
|
||||
// Token estimate — primary verbosity signal
|
||||
switch {
|
||||
case f.TokenEstimate > 200:
|
||||
score += 0.35
|
||||
case f.TokenEstimate > 50:
|
||||
score += 0.15
|
||||
}
|
||||
|
||||
// Fenced code blocks — strongest indicator of a coding/technical task
|
||||
if f.CodeBlockCount > 0 {
|
||||
score += 0.40
|
||||
}
|
||||
|
||||
// Recent tool call density — indicates an ongoing agentic workflow
|
||||
switch {
|
||||
case f.RecentToolCalls > 3:
|
||||
score += 0.25
|
||||
case f.RecentToolCalls > 0:
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Conversation depth — accumulated context implies compound task
|
||||
if f.ConversationDepth > 10 {
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
|
||||
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// lookbackWindow is the number of recent history entries scanned for tool calls.
|
||||
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
|
||||
const lookbackWindow = 6
|
||||
|
||||
// Features holds the structural signals extracted from a message and its session context.
|
||||
// Every dimension is language-agnostic by construction — no keyword or pattern matching
|
||||
// against natural-language content. This ensures consistent routing for all locales.
|
||||
type Features struct {
|
||||
// TokenEstimate is a proxy for token count.
|
||||
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
|
||||
// This avoids API calls while giving accurate estimates for all scripts.
|
||||
TokenEstimate int
|
||||
|
||||
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
|
||||
// Coding tasks almost always require the heavy model.
|
||||
CodeBlockCount int
|
||||
|
||||
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
|
||||
// history entries. A high density indicates an active agentic workflow.
|
||||
RecentToolCalls int
|
||||
|
||||
// ConversationDepth is the total number of messages in the session history.
|
||||
// Deep sessions tend to carry implicit complexity built up over many turns.
|
||||
ConversationDepth int
|
||||
|
||||
// HasAttachments is true when the message appears to contain media (images,
|
||||
// audio, video). Multi-modal inputs require vision-capable heavy models.
|
||||
HasAttachments bool
|
||||
}
|
||||
|
||||
// ExtractFeatures computes the structural feature vector for a message.
|
||||
// It is a pure function with no side effects and zero allocations beyond
|
||||
// the returned struct.
|
||||
func ExtractFeatures(msg string, history []providers.Message) Features {
|
||||
return Features{
|
||||
TokenEstimate: estimateTokens(msg),
|
||||
CodeBlockCount: countCodeBlocks(msg),
|
||||
RecentToolCalls: countRecentToolCalls(history),
|
||||
ConversationDepth: len(history),
|
||||
HasAttachments: hasAttachments(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
|
||||
// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one
|
||||
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
|
||||
// for English). Splitting the count this way avoids the 3x underestimation that a
|
||||
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
|
||||
func estimateTokens(msg string) int {
|
||||
total := utf8.RuneCountInString(msg)
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
cjk := 0
|
||||
for _, r := range msg {
|
||||
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
|
||||
cjk++
|
||||
}
|
||||
}
|
||||
return cjk + (total-cjk)/4
|
||||
}
|
||||
|
||||
// countCodeBlocks counts the number of complete fenced code blocks.
|
||||
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
|
||||
// An unclosed opening fence (odd count) is treated as zero complete blocks
|
||||
// since it may just be an inline code span or a typo.
|
||||
func countCodeBlocks(msg string) int {
|
||||
n := strings.Count(msg, "```")
|
||||
return n / 2
|
||||
}
|
||||
|
||||
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
|
||||
// entries of history. It examines the ToolCalls field rather than parsing
|
||||
// the content string, so it is robust to any message format.
|
||||
func countRecentToolCalls(history []providers.Message) int {
|
||||
start := len(history) - lookbackWindow
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, msg := range history[start:] {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
count += len(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// hasAttachments returns true when the message content contains embedded media.
|
||||
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
|
||||
// common image/audio URL extensions. This is intentionally conservative —
|
||||
// false negatives (missing an attachment) just mean the routing falls back to
|
||||
// the primary model anyway.
|
||||
func hasAttachments(msg string) bool {
|
||||
lower := strings.ToLower(msg)
|
||||
|
||||
// Base64 data URIs embedded directly in the message
|
||||
if strings.Contains(lower, "data:image/") ||
|
||||
strings.Contains(lower, "data:audio/") ||
|
||||
strings.Contains(lower, "data:video/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common image/audio extensions in URLs or file references
|
||||
mediaExts := []string{
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac",
|
||||
".mp4", ".avi", ".mov", ".webm",
|
||||
}
|
||||
for _, ext := range mediaExts {
|
||||
if strings.Contains(lower, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// defaultThreshold is used when the config threshold is zero or negative.
|
||||
// At 0.35 a message needs at least one strong signal (code block, long text,
|
||||
// or an attachment) before the heavy model is chosen.
|
||||
const defaultThreshold = 0.35
|
||||
|
||||
// RouterConfig holds the validated model routing settings.
|
||||
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
|
||||
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
|
||||
type RouterConfig struct {
|
||||
// LightModel is the model_name (from model_list) used for simple tasks.
|
||||
LightModel string
|
||||
|
||||
// Threshold is the complexity score cutoff in [0, 1].
|
||||
// score >= Threshold → primary (heavy) model.
|
||||
// score < Threshold → light model.
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// Router selects the appropriate model tier for each incoming message.
|
||||
// It is safe for concurrent use from multiple goroutines.
|
||||
type Router struct {
|
||||
cfg RouterConfig
|
||||
classifier Classifier
|
||||
}
|
||||
|
||||
// New creates a Router with the given config and the default RuleClassifier.
|
||||
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
|
||||
func New(cfg RouterConfig) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{
|
||||
cfg: cfg,
|
||||
classifier: &RuleClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
// newWithClassifier creates a Router with a custom Classifier.
|
||||
// Intended for unit tests that need to inject a deterministic scorer.
|
||||
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{cfg: cfg, classifier: c}
|
||||
}
|
||||
|
||||
// SelectModel returns the model to use for this conversation turn along with
|
||||
// the computed complexity score (for logging and debugging).
|
||||
//
|
||||
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
|
||||
// - Otherwise: returns (primaryModel, false, score)
|
||||
//
|
||||
// The caller is responsible for resolving the returned model name into
|
||||
// provider candidates (see AgentInstance.LightCandidates).
|
||||
func (r *Router) SelectModel(
|
||||
msg string,
|
||||
history []providers.Message,
|
||||
primaryModel string,
|
||||
) (model string, usedLight bool, score float64) {
|
||||
features := ExtractFeatures(msg, history)
|
||||
score = r.classifier.Score(features)
|
||||
if score < r.cfg.Threshold {
|
||||
return r.cfg.LightModel, true, score
|
||||
}
|
||||
return primaryModel, false, score
|
||||
}
|
||||
|
||||
// LightModel returns the configured light model name.
|
||||
func (r *Router) LightModel() string {
|
||||
return r.cfg.LightModel
|
||||
}
|
||||
|
||||
// Threshold returns the complexity threshold in use.
|
||||
func (r *Router) Threshold() float64 {
|
||||
return r.cfg.Threshold
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||||
f := ExtractFeatures("", nil)
|
||||
if f.TokenEstimate != 0 {
|
||||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||||
}
|
||||
if f.CodeBlockCount != 0 {
|
||||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||||
}
|
||||
if f.RecentToolCalls != 0 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||||
}
|
||||
if f.ConversationDepth != 0 {
|
||||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||||
}
|
||||
if f.HasAttachments {
|
||||
t.Error("HasAttachments: got true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||||
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
|
||||
msg := strings.Repeat("a", 30)
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 7 {
|
||||
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||||
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
|
||||
// Using a rune slice literal avoids CJK string literals in source.
|
||||
msg := string([]rune{
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60,
|
||||
})
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 9 {
|
||||
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
|
||||
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
|
||||
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 6 {
|
||||
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"no code here", 0},
|
||||
{"```go\nfmt.Println()\n```", 1},
|
||||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.CodeBlockCount != tc.want {
|
||||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||||
history := make([]providers.Message, 10)
|
||||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||||
history[9] = providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
||||
}
|
||||
// Position 3 is outside the lookback window and must NOT be counted
|
||||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||||
|
||||
f := ExtractFeatures("test", history)
|
||||
// 1 (position 8) + 2 (position 9) = 3
|
||||
if f.RecentToolCalls != 3 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||||
history := make([]providers.Message, 7)
|
||||
f := ExtractFeatures("msg", history)
|
||||
if f.ConversationDepth != 7 {
|
||||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"plain text", false},
|
||||
{"here is an image: data:image/png;base64,abc123", true},
|
||||
{"audio: data:audio/mp3;base64,xyz", true},
|
||||
{"video: data:video/mp4;base64,xyz", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"check out photo.jpg", true},
|
||||
{"see screenshot.png", true},
|
||||
{"listen to audio.mp3", true},
|
||||
{"watch clip.mp4", true},
|
||||
{"just a .go file", false},
|
||||
{"document.pdf", false}, // pdf is not in the media list
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{})
|
||||
if score != 0.0 {
|
||||
t.Errorf("zero features: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{HasAttachments: true})
|
||||
if score != 1.0 {
|
||||
t.Errorf("attachments: got %f, want 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Code block alone = 0.40, above default threshold 0.35
|
||||
score := c.Score(Features{CodeBlockCount: 1})
|
||||
if score < 0.35 {
|
||||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||||
score := c.Score(Features{TokenEstimate: 250})
|
||||
if score < 0.35 {
|
||||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// 50-200 tokens = 0.15, below threshold → light
|
||||
score := c.Score(Features{TokenEstimate: 100})
|
||||
if score >= 0.35 {
|
||||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// <50 tokens, no other signals = 0.0 → light
|
||||
score := c.Score(Features{TokenEstimate: 10})
|
||||
if score != 0.0 {
|
||||
t.Errorf("short message: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
|
||||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||||
|
||||
if scoreNone != 0.0 {
|
||||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||||
}
|
||||
if scoreLow <= scoreNone {
|
||||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||||
}
|
||||
if scoreHigh <= scoreLow {
|
||||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
shallow := c.Score(Features{ConversationDepth: 5})
|
||||
deep := c.Score(Features{ConversationDepth: 15})
|
||||
if deep <= shallow {
|
||||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Max all signals simultaneously
|
||||
f := Features{
|
||||
TokenEstimate: 500,
|
||||
CodeBlockCount: 3,
|
||||
RecentToolCalls: 10,
|
||||
ConversationDepth: 20,
|
||||
}
|
||||
score := c.Score(f)
|
||||
if score > 1.0 {
|
||||
t.Errorf("score %f exceeds 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "hi"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("simple message: expected light model to be selected")
|
||||
}
|
||||
if model != "gemini-flash" {
|
||||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("code block: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("attachment: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
// >200 token estimate: 210 * 3 = 630 chars
|
||||
msg := strings.Repeat("word ", 210)
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("long message: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||||
}
|
||||
msg := "ok"
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||||
}},
|
||||
}
|
||||
// ~55 tokens * 3 = 165 chars
|
||||
msg := strings.Repeat("word ", 55)
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||||
// Very low threshold: even a short message triggers heavy model
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("low threshold: medium message should use primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||||
// Very high threshold: even code blocks route to light
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||||
msg := "```go\nfmt.Println()\n```"
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LightModel(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||||
if r.LightModel() != "my-fast-model" {
|
||||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||||
}
|
||||
}
|
||||
|
||||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||||
|
||||
type fixedScoreClassifier struct{ score float64 }
|
||||
|
||||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||||
|
||||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.2},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if !usedLight {
|
||||
t.Error("low score with custom classifier: expected light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.8},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("high score with custom classifier: expected primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||||
// score == threshold → primary (uses >= comparison)
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.5},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.42},
|
||||
)
|
||||
_, _, score := r.SelectModel("anything", nil, "heavy")
|
||||
if score != 0.42 {
|
||||
t.Errorf("score: got %f, want 0.42", score)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user