mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'sipeed:main' into fix/reload-config-selfkill-guard
This commit is contained in:
+57
-1
@@ -42,6 +42,9 @@ type ContextBuilder struct {
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
@@ -602,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized
|
||||
// Second pass: ensure every assistant message with tool_calls has matching
|
||||
// tool result messages following it. This is required by strict providers
|
||||
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
|
||||
// be followed by tool messages responding to each 'tool_call_id'."
|
||||
final := make([]providers.Message, 0, len(sanitized))
|
||||
for i := 0; i < len(sanitized); i++ {
|
||||
msg := sanitized[i]
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
// Collect expected tool_call IDs
|
||||
expected := make(map[string]bool, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
expected[tc.ID] = false
|
||||
}
|
||||
|
||||
// Check following messages for matching tool results
|
||||
toolMsgCount := 0
|
||||
for j := i + 1; j < len(sanitized); j++ {
|
||||
if sanitized[j].Role != "tool" {
|
||||
break
|
||||
}
|
||||
toolMsgCount++
|
||||
if _, exists := expected[sanitized[j].ToolCallID]; exists {
|
||||
expected[sanitized[j].ToolCallID] = true
|
||||
}
|
||||
}
|
||||
|
||||
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
|
||||
allFound := true
|
||||
for toolCallID, found := range expected {
|
||||
if !found {
|
||||
allFound = false
|
||||
logger.DebugCF(
|
||||
"agent",
|
||||
"Dropping assistant message with incomplete tool results",
|
||||
map[string]any{
|
||||
"missing_tool_call_id": toolCallID,
|
||||
"expected_count": len(expected),
|
||||
"found_count": toolMsgCount,
|
||||
},
|
||||
)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !allFound {
|
||||
// Skip this assistant message and its tool messages
|
||||
i += toolMsgCount
|
||||
continue
|
||||
}
|
||||
}
|
||||
final = append(final, msg)
|
||||
}
|
||||
|
||||
return final
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) AddToolResult(
|
||||
|
||||
@@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation
|
||||
// that ensures assistant messages with tool_calls have ALL matching tool results.
|
||||
// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be
|
||||
// followed by tool messages responding to each 'tool_call_id'."
|
||||
func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) {
|
||||
// Assistant expects tool results for both A and B, but only A is present
|
||||
history := []providers.Message{
|
||||
msg("user", "do two things"),
|
||||
assistantWithTools("A", "B"),
|
||||
toolResult("A"),
|
||||
// toolResult("B") is missing - this would cause DeepSeek to fail
|
||||
msg("user", "next question"),
|
||||
msg("assistant", "answer"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// The assistant message with incomplete tool results should be dropped,
|
||||
// along with its partial tool result. The remaining messages are:
|
||||
// user ("do two things"), user ("next question"), assistant ("answer")
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "user", "assistant")
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where
|
||||
// an assistant message has tool_calls but no tool results follow at all.
|
||||
func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "do something"),
|
||||
assistantWithTools("A"),
|
||||
// No tool results at all
|
||||
msg("user", "hello"),
|
||||
msg("assistant", "hi"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// The assistant message with no tool results should be dropped.
|
||||
// Remaining: user ("do something"), user ("hello"), assistant ("hi")
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "user", "assistant")
|
||||
}
|
||||
|
||||
// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that
|
||||
// incomplete tool results in the middle of a conversation are properly handled.
|
||||
func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msg("user", "first"),
|
||||
assistantWithTools("A"),
|
||||
toolResult("A"),
|
||||
msg("assistant", "done"),
|
||||
msg("user", "second"),
|
||||
assistantWithTools("B", "C"),
|
||||
toolResult("B"),
|
||||
// toolResult("C") is missing
|
||||
msg("user", "third"),
|
||||
assistantWithTools("D"),
|
||||
toolResult("D"),
|
||||
msg("assistant", "all done"),
|
||||
}
|
||||
|
||||
result := sanitizeHistoryForProvider(history)
|
||||
// First round is complete (user, assistant+tools, tool, assistant),
|
||||
// second round is incomplete and dropped (assistant+tools, partial tool),
|
||||
// third round is complete (user, assistant+tools, tool, assistant).
|
||||
// Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant
|
||||
if len(result) != 9 {
|
||||
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
|
||||
}
|
||||
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
|
||||
}
|
||||
|
||||
+55
-12
@@ -37,6 +37,14 @@ type AgentInstance struct {
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
|
||||
// Router is non-nil when model routing is configured and the light model
|
||||
// was successfully resolved. It scores each incoming message and decides
|
||||
// whether to route to LightCandidates or stay with Candidates.
|
||||
Router *routing.Router
|
||||
// LightCandidates holds the resolved provider candidates for the light model.
|
||||
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
|
||||
LightCandidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -60,17 +68,30 @@ func NewAgentInstance(
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
if cfg.Tools.IsToolEnabled("read_file") {
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("write_file") {
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("list_dir") {
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
toolsRegistry.Register(execTool)
|
||||
}
|
||||
|
||||
if cfg.Tools.IsToolEnabled("edit_file") {
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("append_file") {
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
|
||||
}
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
@@ -167,6 +188,25 @@ func NewAgentInstance(
|
||||
|
||||
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
var lightCandidates []providers.FallbackCandidate
|
||||
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
|
||||
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
|
||||
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
|
||||
if len(resolved) > 0 {
|
||||
router = routing.New(routing.RouterConfig{
|
||||
LightModel: rc.LightModel,
|
||||
Threshold: rc.Threshold,
|
||||
})
|
||||
lightCandidates = resolved
|
||||
} else {
|
||||
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
|
||||
rc.LightModel, agentID)
|
||||
}
|
||||
}
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
@@ -187,6 +227,8 @@ func NewAgentInstance(
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD
|
||||
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
|
||||
return expandHome(strings.TrimSpace(agentCfg.Workspace))
|
||||
}
|
||||
// Use the configured default workspace (respects PICOCLAW_HOME)
|
||||
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
|
||||
return expandHome(defaults.Workspace)
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
// For named agents without explicit workspace, use default workspace with agent ID suffix
|
||||
id := routing.NormalizeAgentID(agentCfg.ID)
|
||||
return filepath.Join(home, ".picoclaw", "workspace-"+id)
|
||||
return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id)
|
||||
}
|
||||
|
||||
// resolveAgentModel resolves the primary model for an agent.
|
||||
|
||||
+273
-165
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -46,6 +47,7 @@ type AgentLoop struct {
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -61,7 +63,15 @@ type processOptions struct {
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
const (
|
||||
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
metadataKeyParentPeerKind = "parent_peer_kind"
|
||||
metadataKeyParentPeerID = "parent_peer_id"
|
||||
)
|
||||
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
@@ -84,14 +94,17 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
return &AgentLoop{
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
}
|
||||
|
||||
return al
|
||||
}
|
||||
|
||||
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
|
||||
@@ -108,76 +121,106 @@ func registerSharedTools(
|
||||
}
|
||||
|
||||
// Web tools
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
if cfg.Tools.IsToolEnabled("web") {
|
||||
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
|
||||
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
|
||||
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
|
||||
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
|
||||
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
|
||||
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
|
||||
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
|
||||
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
|
||||
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
|
||||
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
|
||||
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
|
||||
Proxy: cfg.Tools.Web.Proxy,
|
||||
})
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
|
||||
} else if searchTool != nil {
|
||||
agent.Tools.Register(searchTool)
|
||||
}
|
||||
}
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
if cfg.Tools.IsToolEnabled("web_fetch") {
|
||||
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
agent.Tools.Register(fetchTool)
|
||||
}
|
||||
}
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
if cfg.Tools.IsToolEnabled("i2c") {
|
||||
agent.Tools.Register(tools.NewI2CTool())
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("spi") {
|
||||
agent.Tools.Register(tools.NewSPITool())
|
||||
}
|
||||
|
||||
// Message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
if cfg.Tools.IsToolEnabled("message") {
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
})
|
||||
})
|
||||
})
|
||||
agent.Tools.Register(messageTool)
|
||||
agent.Tools.Register(messageTool)
|
||||
}
|
||||
|
||||
// Skill discovery and installation tools
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
searchCache := skills.NewSearchCache(
|
||||
cfg.Tools.Skills.SearchCache.MaxSize,
|
||||
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
|
||||
)
|
||||
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
|
||||
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
|
||||
skills_enabled := cfg.Tools.IsToolEnabled("skills")
|
||||
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
|
||||
install_skills_enable := cfg.Tools.IsToolEnabled("install_skill")
|
||||
if skills_enabled && (find_skills_enable || install_skills_enable) {
|
||||
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
|
||||
})
|
||||
|
||||
if find_skills_enable {
|
||||
searchCache := skills.NewSearchCache(
|
||||
cfg.Tools.Skills.SearchCache.MaxSize,
|
||||
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
|
||||
)
|
||||
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
|
||||
}
|
||||
|
||||
if install_skills_enable {
|
||||
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
|
||||
})
|
||||
agent.Tools.Register(spawnTool)
|
||||
} else {
|
||||
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,7 +228,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
if al.cfg.Tools.IsToolEnabled("mcp") {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
@@ -227,6 +270,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
@@ -518,27 +562,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
route, agent, routeErr := al.resolveMessageRoute(msg)
|
||||
|
||||
// Commands are checked before requiring a successful route.
|
||||
// Global commands (/help, /show, /switch) work even when routing fails;
|
||||
// context-dependent commands check their own Runtime fields and report
|
||||
// "unavailable" when the required capability is nil.
|
||||
if response, handled := al.handleCommand(ctx, msg, agent); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Route to determine agent and session key
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: msg.Metadata["account_id"],
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: msg.Metadata["guild_id"],
|
||||
TeamID: msg.Metadata["team_id"],
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
if routeErr != nil {
|
||||
return "", routeErr
|
||||
}
|
||||
|
||||
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
|
||||
@@ -548,17 +583,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
|
||||
sessionKey := route.SessionKey
|
||||
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
|
||||
sessionKey = msg.SessionKey
|
||||
}
|
||||
// Resolve session key from route, while preserving explicit agent-scoped keys.
|
||||
scopeKey := resolveScopeKey(route, msg.SessionKey)
|
||||
sessionKey := scopeKey
|
||||
|
||||
logger.InfoCF("agent", "Routed message",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"session_key": sessionKey,
|
||||
"matched_by": route.MatchedBy,
|
||||
"agent_id": agent.ID,
|
||||
"scope_key": scopeKey,
|
||||
"session_key": sessionKey,
|
||||
"matched_by": route.MatchedBy,
|
||||
"route_agent": route.AgentID,
|
||||
"route_channel": route.Channel,
|
||||
})
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
@@ -573,6 +609,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: inboundMetadata(msg, metadataKeyAccountID),
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: inboundMetadata(msg, metadataKeyGuildID),
|
||||
TeamID: inboundMetadata(msg, metadataKeyTeamID),
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
}
|
||||
|
||||
return route, agent, nil
|
||||
}
|
||||
|
||||
func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
|
||||
if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) {
|
||||
return msgSessionKey
|
||||
}
|
||||
return route.SessionKey
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
@@ -793,6 +857,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
// Determine effective model tier for this conversation turn.
|
||||
// selectCandidates evaluates routing once and the decision is sticky for
|
||||
// all tool-follow-up iterations within the same turn so that a multi-step
|
||||
// tool chain doesn't switch models mid-way through.
|
||||
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
|
||||
|
||||
for iteration < agent.MaxIterations {
|
||||
iteration++
|
||||
|
||||
@@ -811,7 +881,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": agent.Model,
|
||||
"model": activeModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": agent.MaxTokens,
|
||||
@@ -827,7 +897,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM with fallback chain if candidates are configured.
|
||||
// Call LLM with fallback chain if multiple candidates are configured.
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
@@ -848,10 +918,10 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
|
||||
},
|
||||
@@ -869,7 +939,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
|
||||
}
|
||||
|
||||
// Retry loop for context/token errors
|
||||
@@ -1138,6 +1208,44 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return finalContent, iteration, nil
|
||||
}
|
||||
|
||||
// selectCandidates returns the model candidates and resolved model name to use
|
||||
// for a conversation turn. When model routing is configured and the incoming
|
||||
// message scores below the complexity threshold, it returns the light model
|
||||
// candidates instead of the primary ones.
|
||||
//
|
||||
// The returned (candidates, model) pair is used for all LLM calls within one
|
||||
// turn — tool follow-up iterations use the same tier as the initial call so
|
||||
// that a multi-step tool chain doesn't switch models mid-way.
|
||||
func (al *AgentLoop) selectCandidates(
|
||||
agent *AgentInstance,
|
||||
userMsg string,
|
||||
history []providers.Message,
|
||||
) (candidates []providers.FallbackCandidate, model string) {
|
||||
if agent.Router == nil || len(agent.LightCandidates) == 0 {
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
|
||||
if !usedLight {
|
||||
logger.DebugCF("agent", "Model routing: primary model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.Candidates, agent.Model
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Model routing: light model selected",
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"light_model": agent.Router.LightModel(),
|
||||
"score": score,
|
||||
"threshold": agent.Router.Threshold(),
|
||||
})
|
||||
return agent.LightCandidates, agent.Router.LightModel()
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
|
||||
newHistory := agent.Sessions.GetHistory(sessionKey)
|
||||
@@ -1429,94 +1537,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
func (al *AgentLoop) handleCommand(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
agent *AgentInstance,
|
||||
) (string, bool) {
|
||||
if !commands.HasCommandPrefix(msg.Content) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
if al.cmdRegistry == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
rt := al.buildCommandsRuntime(agent)
|
||||
executor := commands.NewExecutor(al.cmdRegistry, rt)
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel|agents]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
var commandReply string
|
||||
result := executor.Execute(ctx, commands.Request{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
SenderID: msg.SenderID,
|
||||
Text: msg.Content,
|
||||
Reply: func(text string) error {
|
||||
commandReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels|agents]", true
|
||||
switch result.Outcome {
|
||||
case commands.OutcomeHandled:
|
||||
if result.Err != nil {
|
||||
return mapCommandError(result), true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
return "Available models: configured in config.json per agent", true
|
||||
case "channels":
|
||||
if commandReply != "" {
|
||||
return commandReply, true
|
||||
}
|
||||
return "", true
|
||||
default: // OutcomePassthrough — let the message fall through to LLM
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
|
||||
rt := &commands.Runtime{
|
||||
Config: al.cfg,
|
||||
ListAgentIDs: al.registry.ListAgentIDs,
|
||||
ListDefinitions: al.cmdRegistry.Definitions,
|
||||
GetEnabledChannels: func() []string {
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
return nil
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
case "agents":
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return "No default agent configured", true
|
||||
}
|
||||
oldModel := defaultAgent.Model
|
||||
defaultAgent.Model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
return al.channelManager.GetEnabledChannels()
|
||||
},
|
||||
SwitchChannel: func(value string) error {
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
return fmt.Errorf("channel manager not initialized")
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
return fmt.Errorf("channel '%s' not found or not enabled", value)
|
||||
}
|
||||
return fmt.Sprintf("Switched target channel to %s", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, al.cfg.Agents.Defaults.Provider
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
oldModel := agent.Model
|
||||
agent.Model = value
|
||||
return oldModel, nil
|
||||
}
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
return "", false
|
||||
func mapCommandError(result commands.ExecuteResult) string {
|
||||
if result.Command == "" {
|
||||
return fmt.Sprintf("Failed to execute command: %v", result.Err)
|
||||
}
|
||||
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
|
||||
@@ -1535,10 +1636,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
|
||||
}
|
||||
|
||||
func inboundMetadata(msg bus.InboundMessage, key string) string {
|
||||
if msg.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
return msg.Metadata[key]
|
||||
}
|
||||
|
||||
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
|
||||
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
parentKind := msg.Metadata["parent_peer_kind"]
|
||||
parentID := msg.Metadata["parent_peer_id"]
|
||||
parentKind := inboundMetadata(msg, metadataKeyParentPeerKind)
|
||||
parentID := inboundMetadata(msg, metadataKeyParentPeerID)
|
||||
if parentKind == "" || parentID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
+221
-10
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -227,16 +228,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = tmpDir
|
||||
cfg.Agents.Defaults.Model = "test-model"
|
||||
cfg.Agents.Defaults.MaxTokens = 4096
|
||||
cfg.Agents.Defaults.MaxToolIterations = 10
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
@@ -323,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
type countingMockProvider struct {
|
||||
response string
|
||||
calls int
|
||||
}
|
||||
|
||||
func (m *countingMockProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
m.calls++
|
||||
return &providers.LLMResponse{
|
||||
Content: m.response,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *countingMockProvider) GetDefaultModel() string {
|
||||
return "counting-mock-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
@@ -364,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "ok"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
Peer: extractPeer(msg),
|
||||
})
|
||||
sessionKey := route.SessionKey
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
|
||||
helper := testHelper{al: al}
|
||||
_ = helper.executeAndGetResponse(t, context.Background(), msg)
|
||||
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) != 2 {
|
||||
t.Fatalf("expected session history len=2, got %d", len(history))
|
||||
}
|
||||
if history[0].Role != "user" || history[0].Content != "hello" {
|
||||
t.Fatalf("unexpected first message in session: %+v", history[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-channel-peer",
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
baseMsg := bus.InboundMessage{
|
||||
Channel: "whatsapp",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/show channel",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if showResp != "Current Channel: whatsapp" {
|
||||
t.Fatalf("unexpected /show reply: %q", showResp)
|
||||
}
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls)
|
||||
}
|
||||
|
||||
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/foo",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if fooResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
||||
}
|
||||
if provider.calls != 1 {
|
||||
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
|
||||
}
|
||||
|
||||
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/new",
|
||||
Peer: baseMsg.Peer,
|
||||
})
|
||||
if newResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /new reply: %q", newResp)
|
||||
}
|
||||
if provider.calls != 2 {
|
||||
t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Provider: "openai",
|
||||
Model: "before-switch",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &countingMockProvider{response: "LLM reply"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to after-switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
}
|
||||
|
||||
if provider.calls != 0 {
|
||||
t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
+20
-5
@@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device code response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
@@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device code response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
@@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
|
||||
return nil, fmt.Errorf("pending")
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading device token response: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AuthorizationCode string `json:"authorization_code"`
|
||||
@@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading token refresh response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
@@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading token exchange response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool {
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return filepath.Join(home, "auth.json")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw", "auth.json")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -26,6 +27,12 @@ const (
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
|
||||
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
|
||||
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*channels.BaseChannel
|
||||
session *discordgo.Session
|
||||
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
content = c.stripBotMention(content)
|
||||
}
|
||||
|
||||
// Resolve Discord refs in main content before concatenation to avoid
|
||||
// double-expanding links that appear in the referenced message.
|
||||
content = c.resolveDiscordRefs(s, content, m.GuildID)
|
||||
|
||||
// Prepend referenced (quoted) message content if this is a reply
|
||||
if m.MessageReference != nil && m.ReferencedMessage != nil {
|
||||
refContent := m.ReferencedMessage.Content
|
||||
if refContent != "" {
|
||||
refAuthor := "unknown"
|
||||
if m.ReferencedMessage.Author != nil {
|
||||
refAuthor = m.ReferencedMessage.Author.Username
|
||||
}
|
||||
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
|
||||
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
|
||||
refAuthor, refContent, content)
|
||||
}
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
|
||||
// expands Discord message links to show the linked message content.
|
||||
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
|
||||
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
|
||||
// 1. Resolve channel references: <#id> → #channel-name
|
||||
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||
parts := channelRefRe.FindStringSubmatch(match)
|
||||
if len(parts) < 2 {
|
||||
return match
|
||||
}
|
||||
// Prefer session state cache to avoid API calls
|
||||
if ch, err := s.State.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
if ch, err := s.Channel(parts[1]); err == nil {
|
||||
return "#" + ch.Name
|
||||
}
|
||||
return match
|
||||
})
|
||||
|
||||
// 2. Expand Discord message links (max 3, same guild only)
|
||||
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
|
||||
for _, m := range matches {
|
||||
if len(m) < 4 {
|
||||
continue
|
||||
}
|
||||
linkGuildID, channelID, messageID := m[1], m[2], m[3]
|
||||
// Security: only expand links from the same guild
|
||||
if linkGuildID != guildID {
|
||||
continue
|
||||
}
|
||||
msg, err := s.ChannelMessage(channelID, messageID)
|
||||
if err != nil || msg == nil || msg.Content == "" {
|
||||
continue
|
||||
}
|
||||
author := "unknown"
|
||||
if msg.Author != nil {
|
||||
author = msg.Author.Username
|
||||
}
|
||||
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// stripBotMention removes the bot mention from the message content.
|
||||
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
|
||||
func (c *DiscordChannel) stripBotMention(text string) string {
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package discord
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChannelRefRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantID string
|
||||
wantOK bool
|
||||
}{
|
||||
{"basic channel ref", "<#123456789>", "123456789", true},
|
||||
{"long id", "<#9876543210123456>", "9876543210123456", true},
|
||||
{"no match plain text", "hello world", "", false},
|
||||
{"no match partial", "<#>", "", false},
|
||||
{"no match letters", "<#abc>", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := channelRefRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 2 || matches[1] != tt.wantID {
|
||||
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 2 {
|
||||
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantGuild string
|
||||
wantChan string
|
||||
wantMsg string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
"discord.com link",
|
||||
"https://discord.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"discordapp.com link",
|
||||
"https://discordapp.com/channels/111/222/333",
|
||||
"111", "222", "333", true,
|
||||
},
|
||||
{
|
||||
"real world ids",
|
||||
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
|
||||
"9000000000000001", "9000000000000002", "9000000000000003", true,
|
||||
},
|
||||
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
|
||||
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
|
||||
{"no match plain text", "hello world", "", "", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := msgLinkRe.FindStringSubmatch(tt.input)
|
||||
if tt.wantOK {
|
||||
if len(matches) < 4 {
|
||||
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
|
||||
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
|
||||
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
|
||||
tt.input, matches[1], matches[2], matches[3],
|
||||
tt.wantGuild, tt.wantChan, tt.wantMsg)
|
||||
}
|
||||
} else {
|
||||
if len(matches) >= 4 {
|
||||
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
|
||||
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
|
||||
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
|
||||
if len(matches) != 3 {
|
||||
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
|
||||
}
|
||||
// Verify the 3rd match is 7/8/9 (not 10/11/12)
|
||||
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
|
||||
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
package channels
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
// TypingCapable — channels that can show a typing/thinking indicator.
|
||||
// StartTyping begins the indicator and returns a stop function.
|
||||
@@ -39,3 +43,10 @@ type PlaceholderRecorder interface {
|
||||
RecordTypingStop(channel, chatID string, stop func())
|
||||
RecordReactionUndo(channel, chatID string, undo func())
|
||||
}
|
||||
|
||||
// CommandRegistrarCapable is implemented by channels that can register
|
||||
// command menus with their upstream platform (e.g. Telegram BotCommand).
|
||||
// Channels that do not support platform-level command menus can ignore it.
|
||||
type CommandRegistrarCapable interface {
|
||||
RegisterCommands(ctx context.Context, defs []commands.Definition) error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
type mockRegistrar struct{}
|
||||
|
||||
func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil }
|
||||
|
||||
func TestCommandRegistrarCapable_Compiles(t *testing.T) {
|
||||
var _ CommandRegistrarCapable = mockRegistrar{}
|
||||
}
|
||||
@@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err))
|
||||
}
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody)))
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var commandRegistrationBackoff = []time.Duration{
|
||||
5 * time.Second,
|
||||
15 * time.Second,
|
||||
60 * time.Second,
|
||||
5 * time.Minute,
|
||||
10 * time.Minute,
|
||||
}
|
||||
|
||||
func commandRegistrationDelay(attempt int) time.Duration {
|
||||
if len(commandRegistrationBackoff) == 0 {
|
||||
return 0
|
||||
}
|
||||
base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)]
|
||||
// Full jitter in [0.5, 1.0) to avoid synchronized retries across instances.
|
||||
return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5))
|
||||
}
|
||||
|
||||
// RegisterCommands registers bot commands on Telegram platform.
|
||||
func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error {
|
||||
botCommands := make([]telego.BotCommand, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
if def.Name == "" || def.Description == "" {
|
||||
continue
|
||||
}
|
||||
botCommands = append(botCommands, telego.BotCommand{
|
||||
Command: def.Name,
|
||||
Description: def.Description,
|
||||
})
|
||||
}
|
||||
|
||||
current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{})
|
||||
if err != nil {
|
||||
// If we can't read current commands, fall through to set them.
|
||||
logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally",
|
||||
map[string]any{"error": err.Error()})
|
||||
} else if slices.Equal(current, botCommands) {
|
||||
logger.DebugCF("telegram", "Bot commands are up to date", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: botCommands,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) {
|
||||
if len(defs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
register := c.registerFunc
|
||||
if register == nil {
|
||||
register = c.RegisterCommands
|
||||
}
|
||||
|
||||
regCtx, cancel := context.WithCancel(ctx)
|
||||
c.commandRegCancel = cancel
|
||||
|
||||
// Registration runs asynchronously so Telegram message intake is never blocked
|
||||
// by temporary upstream API failures. Retry stops on success or channel shutdown.
|
||||
go func() {
|
||||
attempt := 0
|
||||
timer := time.NewTimer(0)
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
defer timer.Stop()
|
||||
for {
|
||||
err := register(regCtx, defs)
|
||||
if err == nil {
|
||||
logger.InfoCF("telegram", "Telegram commands registered", map[string]any{
|
||||
"count": len(defs),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
delay := commandRegistrationDelay(attempt)
|
||||
logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry_after": delay.String(),
|
||||
})
|
||||
attempt++
|
||||
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(delay)
|
||||
|
||||
select {
|
||||
case <-regCtx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
)
|
||||
|
||||
func TestStartCommandRegistration_DoesNotBlock(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
started := make(chan struct{}, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
started <- struct{}{}
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}})
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("registration did not start asynchronously")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
|
||||
var attempts atomic.Int32
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
n := attempts.Add(1)
|
||||
if n < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
|
||||
|
||||
deadline := time.Now().Add(250 * time.Millisecond)
|
||||
for time.Now().Before(deadline) {
|
||||
if attempts.Load() >= 3 {
|
||||
break
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
if attempts.Load() < 3 {
|
||||
t.Fatalf("expected at least 3 attempts, got %d", attempts.Load())
|
||||
}
|
||||
|
||||
stable := attempts.Load()
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
if attempts.Load() != stable {
|
||||
t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) {
|
||||
ch := &TelegramChannel{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
origBackoff := commandRegistrationBackoff
|
||||
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
|
||||
defer func() { commandRegistrationBackoff = origBackoff }()
|
||||
defer cancel()
|
||||
|
||||
var attempts atomic.Int32
|
||||
ch.registerFunc = func(context.Context, []commands.Definition) error {
|
||||
attempts.Add(1)
|
||||
return errors.New("always fail")
|
||||
}
|
||||
|
||||
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle
|
||||
stable := attempts.Load()
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
if attempts.Load() != stable {
|
||||
t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load())
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/commands"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -40,13 +40,15 @@ var (
|
||||
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
bot *telego.Bot
|
||||
bh *th.BotHandler
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
registerFunc func(context.Context, []commands.Definition) error
|
||||
commandRegCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
@@ -93,7 +95,6 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
@@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.initBotCommands(c.ctx); err != nil {
|
||||
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
@@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
c.bh = bh
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Help(ctx, message)
|
||||
}, th.CommandEqual("help"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
@@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions())
|
||||
|
||||
go func() {
|
||||
if err = bh.Start(); err != nil {
|
||||
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
|
||||
@@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
|
||||
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get commands: %w", err)
|
||||
}
|
||||
|
||||
commands := []telego.BotCommand{
|
||||
{
|
||||
Command: "start",
|
||||
Description: "Start the bot",
|
||||
},
|
||||
{
|
||||
Command: "help",
|
||||
Description: "Show a help message",
|
||||
},
|
||||
{
|
||||
Command: "show",
|
||||
Description: "Show current configuration",
|
||||
},
|
||||
{
|
||||
Command: "list",
|
||||
Description: "List available options",
|
||||
},
|
||||
}
|
||||
|
||||
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
|
||||
if !slices.Equal(currentCommands, commands) {
|
||||
logger.InfoC("telegram", "Updating bot commands")
|
||||
|
||||
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: commands,
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set commands: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.DebugC("telegram", "Bot commands are up to date")
|
||||
if c.commandRegCancel != nil {
|
||||
c.commandRegCancel()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -721,34 +661,34 @@ func escapeHTML(text string) string {
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message via entities.
|
||||
func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
|
||||
botUsername := c.bot.Username()
|
||||
if botUsername == "" {
|
||||
text, entities := telegramEntityTextAndList(message)
|
||||
if text == "" || len(entities) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
entities := message.Entities
|
||||
if entities == nil {
|
||||
entities = message.CaptionEntities
|
||||
botUsername := ""
|
||||
if c.bot != nil {
|
||||
botUsername = c.bot.Username()
|
||||
}
|
||||
runes := []rune(text)
|
||||
|
||||
for _, entity := range entities {
|
||||
if entity.Type == "mention" {
|
||||
// Extract the mention text from the message
|
||||
text := message.Text
|
||||
if text == "" {
|
||||
text = message.Caption
|
||||
}
|
||||
runes := []rune(text)
|
||||
end := entity.Offset + entity.Length
|
||||
if end <= len(runes) {
|
||||
mention := string(runes[entity.Offset:end])
|
||||
if strings.EqualFold(mention, "@"+botUsername) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
entityText, ok := telegramEntityText(runes, entity)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if entity.Type == "text_mention" && entity.User != nil {
|
||||
if entity.User.Username == botUsername {
|
||||
|
||||
switch entity.Type {
|
||||
case telego.EntityTypeMention:
|
||||
if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) {
|
||||
return true
|
||||
}
|
||||
case telego.EntityTypeTextMention:
|
||||
if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) {
|
||||
return true
|
||||
}
|
||||
case telego.EntityTypeBotCommand:
|
||||
if isBotCommandEntityForThisBot(entityText, botUsername) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -756,6 +696,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) {
|
||||
if message.Text != "" {
|
||||
return message.Text, message.Entities
|
||||
}
|
||||
return message.Caption, message.CaptionEntities
|
||||
}
|
||||
|
||||
func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) {
|
||||
if entity.Offset < 0 || entity.Length <= 0 {
|
||||
return "", false
|
||||
}
|
||||
end := entity.Offset + entity.Length
|
||||
if entity.Offset >= len(runes) || end > len(runes) {
|
||||
return "", false
|
||||
}
|
||||
return string(runes[entity.Offset:end]), true
|
||||
}
|
||||
|
||||
func isBotCommandEntityForThisBot(entityText, botUsername string) bool {
|
||||
if !strings.HasPrefix(entityText, "/") {
|
||||
return false
|
||||
}
|
||||
command := strings.TrimPrefix(entityText, "/")
|
||||
if command == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
at := strings.IndexRune(command, '@')
|
||||
if at == -1 {
|
||||
// A bare /command delivered to this bot is intended for this bot.
|
||||
return true
|
||||
}
|
||||
|
||||
mentionUsername := command[at+1:]
|
||||
if mentionUsername == "" || botUsername == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(mentionUsername, botUsername)
|
||||
}
|
||||
|
||||
// stripBotMention removes the @bot mention from the content.
|
||||
func (c *TelegramChannel) stripBotMention(content string) string {
|
||||
botUsername := c.bot.Username()
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.GetModelName(),
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
|
||||
c.config.Agents.Defaults.GetModelName(), provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
)
|
||||
|
||||
func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "/new",
|
||||
MessageID: 9,
|
||||
Chat: telego.Chat{
|
||||
ID: 123,
|
||||
Type: "private",
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 42,
|
||||
FirstName: "Alice",
|
||||
},
|
||||
}
|
||||
|
||||
if err := ch.handleMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "telegram" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
ta "github.com/mymmrac/telego/telegoapi"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type getMeCaller struct {
|
||||
username string
|
||||
}
|
||||
|
||||
func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) {
|
||||
if strings.HasSuffix(url, "/getMe") {
|
||||
result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username)
|
||||
return &ta.Response{Ok: true, Result: []byte(result)}, nil
|
||||
}
|
||||
return &ta.Response{Ok: true, Result: []byte("true")}, nil
|
||||
}
|
||||
|
||||
func newTestTelegramBot(t *testing.T, username string) *telego.Bot {
|
||||
t.Helper()
|
||||
|
||||
token := "123456:" + strings.Repeat("a", 35)
|
||||
bot, err := telego.NewBot(token,
|
||||
telego.WithAPICaller(getMeCaller{username: username}),
|
||||
telego.WithDiscardLogger(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBot error: %v", err)
|
||||
}
|
||||
return bot
|
||||
}
|
||||
|
||||
func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) {
|
||||
t.Helper()
|
||||
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &TelegramChannel{
|
||||
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil,
|
||||
channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}),
|
||||
),
|
||||
bot: newTestTelegramBot(t, botUsername),
|
||||
chatIDs: make(map[string]int64),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
return ch, messageBus
|
||||
}
|
||||
|
||||
func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
wantForwarded bool
|
||||
wantContent string
|
||||
}{
|
||||
{
|
||||
name: "command with bot username",
|
||||
text: "/new@testbot",
|
||||
wantForwarded: true,
|
||||
wantContent: "/new",
|
||||
},
|
||||
{
|
||||
name: "bare command",
|
||||
text: "/new",
|
||||
wantForwarded: true,
|
||||
wantContent: "/new",
|
||||
},
|
||||
{
|
||||
name: "command for another bot",
|
||||
text: "/new@otherbot",
|
||||
wantForwarded: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ch, messageBus := newGroupMentionOnlyChannel(t, "testbot")
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: tc.text,
|
||||
Entities: []telego.MessageEntity{{
|
||||
Type: telego.EntityTypeBotCommand,
|
||||
Offset: 0,
|
||||
Length: len([]rune(tc.text)),
|
||||
}},
|
||||
MessageID: 42,
|
||||
Chat: telego.Chat{
|
||||
ID: 123,
|
||||
Type: "group",
|
||||
},
|
||||
From: &telego.User{
|
||||
ID: 7,
|
||||
FirstName: "Alice",
|
||||
},
|
||||
}
|
||||
|
||||
if err := ch.handleMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("handleMessage error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if tc.wantForwarded {
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Content != tc.wantContent {
|
||||
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) {
|
||||
ch, _ := newGroupMentionOnlyChannel(t, "testbot")
|
||||
|
||||
msg := &telego.Message{
|
||||
Text: "@testbot hello",
|
||||
Entities: []telego.MessageEntity{{
|
||||
Type: telego.EntityTypeMention,
|
||||
Offset: 0,
|
||||
Length: len("@testbot"),
|
||||
}},
|
||||
}
|
||||
|
||||
if !ch.isBotMentioned(msg) {
|
||||
t.Fatal("expected mention entity to be treated as bot mention")
|
||||
}
|
||||
}
|
||||
@@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
|
||||
}
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusTooManyRequests:
|
||||
return fmt.Errorf("response_url rate limited (%d): %s: %w",
|
||||
|
||||
@@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody)))
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom upload error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return "", channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom upload error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
@@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody)))
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading wecom_app error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("wecom_app API error: %s", string(respBody)),
|
||||
)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body)))
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("reading webhook error response: %w", readErr),
|
||||
)
|
||||
}
|
||||
return channels.ClassifySendError(
|
||||
resp.StatusCode,
|
||||
fmt.Errorf("webhook API error: %s", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package whatsapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
|
||||
ch.handleIncomingMessage(map[string]any{
|
||||
"type": "message",
|
||||
"id": "mid1",
|
||||
"from": "user1",
|
||||
"chat": "chat1",
|
||||
"content": "/help",
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/help" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build whatsapp_native
|
||||
|
||||
package whatsapp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.mau.fi/whatsmeow/proto/waE2E"
|
||||
"go.mau.fi/whatsmeow/types"
|
||||
"go.mau.fi/whatsmeow/types/events"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
|
||||
messageBus := bus.NewMessageBus()
|
||||
ch := &WhatsAppNativeChannel{
|
||||
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
|
||||
runCtx: context.Background(),
|
||||
}
|
||||
|
||||
evt := &events.Message{
|
||||
Info: types.MessageInfo{
|
||||
MessageSource: types.MessageSource{
|
||||
Sender: types.NewJID("1001", types.DefaultUserServer),
|
||||
Chat: types.NewJID("1001", types.DefaultUserServer),
|
||||
},
|
||||
ID: "mid1",
|
||||
PushName: "Alice",
|
||||
},
|
||||
Message: &waE2E.Message{
|
||||
Conversation: proto.String("/new"),
|
||||
},
|
||||
}
|
||||
|
||||
ch.handleIncoming(evt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbound, ok := messageBus.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected inbound message to be forwarded")
|
||||
}
|
||||
if inbound.Channel != "whatsapp_native" {
|
||||
t.Fatalf("channel=%q", inbound.Channel)
|
||||
}
|
||||
if inbound.Content != "/new" {
|
||||
t.Fatalf("content=%q", inbound.Content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
// BuiltinDefinitions returns all built-in command definitions.
|
||||
// Each command group is defined in its own cmd_*.go file.
|
||||
// Definitions are stateless — runtime dependencies are provided
|
||||
// via the Runtime parameter passed to handlers at execution time.
|
||||
func BuiltinDefinitions() []Definition {
|
||||
return []Definition{
|
||||
startCommand(),
|
||||
helpCommand(),
|
||||
showCommand(),
|
||||
listCommand(),
|
||||
switchCommand(),
|
||||
checkCommand(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition {
|
||||
t.Helper()
|
||||
for _, def := range defs {
|
||||
if def.Name == name {
|
||||
return def
|
||||
}
|
||||
}
|
||||
t.Fatalf("missing /%s definition", name)
|
||||
return Definition{}
|
||||
}
|
||||
|
||||
func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
helpDef := findDefinitionByName(t, defs, "help")
|
||||
if helpDef.Handler == nil {
|
||||
t.Fatalf("/help handler should not be nil")
|
||||
}
|
||||
|
||||
var reply string
|
||||
err := helpDef.Handler(context.Background(), Request{
|
||||
Text: "/help",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("/help handler error: %v", err)
|
||||
}
|
||||
// Now uses auto-generated EffectiveUsage which includes agents
|
||||
if !strings.Contains(reply, "/show [model|channel|agents]") {
|
||||
t.Fatalf("/help reply missing /show usage, got %q", reply)
|
||||
}
|
||||
if !strings.Contains(reply, "/list [models|channels|agents]") {
|
||||
t.Fatalf("/help reply missing /list usage, got %q", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
cases := []string{"telegram", "whatsapp"}
|
||||
for _, channel := range cases {
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Channel: channel,
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Current Channel: " + channel
|
||||
if reply != want {
|
||||
t.Fatalf("/show channel reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
GetEnabledChannels: func() []string {
|
||||
return []string{"telegram", "slack"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/list channels",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") {
|
||||
t.Fatalf("/list channels reply=%q, want telegram and slack", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
ListAgentIDs: func() []string {
|
||||
return []string{"default", "coder"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show agents",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
|
||||
t.Fatalf("/show agents reply=%q, want agent IDs", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
ListAgentIDs: func() []string {
|
||||
return []string{"default", "coder"}
|
||||
},
|
||||
}
|
||||
defs := BuiltinDefinitions()
|
||||
ex := NewExecutor(NewRegistry(defs), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/list agents",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
|
||||
t.Fatalf("/list agents reply=%q, want agent IDs", reply)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func checkCommand() Definition {
|
||||
return Definition{
|
||||
Name: "check",
|
||||
Description: "Check channel availability",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Check if a channel is available",
|
||||
ArgsUsage: "<name>",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.SwitchChannel == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
value := nthToken(req.Text, 2)
|
||||
if value == "" {
|
||||
return req.Reply("Usage: /check channel <name>")
|
||||
}
|
||||
if err := rt.SwitchChannel(value); err != nil {
|
||||
return req.Reply(err.Error())
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value))
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func helpCommand() Definition {
|
||||
return Definition{
|
||||
Name: "help",
|
||||
Description: "Show this help message",
|
||||
Usage: "/help",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
var defs []Definition
|
||||
if rt != nil && rt.ListDefinitions != nil {
|
||||
defs = rt.ListDefinitions()
|
||||
} else {
|
||||
defs = BuiltinDefinitions()
|
||||
}
|
||||
return req.Reply(formatHelpMessage(defs))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func formatHelpMessage(defs []Definition) string {
|
||||
if len(defs) == 0 {
|
||||
return "No commands available."
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
usage := def.EffectiveUsage()
|
||||
if usage == "" {
|
||||
usage = "/" + def.Name
|
||||
}
|
||||
desc := def.Description
|
||||
if desc == "" {
|
||||
desc = "No description"
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("%s - %s", usage, desc))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func listCommand() Definition {
|
||||
return Definition{
|
||||
Name: "list",
|
||||
Description: "List available options",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "models",
|
||||
Description: "Configured models",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetModelInfo == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
name, provider := rt.GetModelInfo()
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
return req.Reply(fmt.Sprintf(
|
||||
"Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
|
||||
name, provider,
|
||||
))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channels",
|
||||
Description: "Enabled channels",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetEnabledChannels == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
enabled := rt.GetEnabledChannels()
|
||||
if len(enabled) == 0 {
|
||||
return req.Reply("No channels enabled")
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agents",
|
||||
Description: "Registered agents",
|
||||
Handler: agentsHandler(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func showCommand() Definition {
|
||||
return Definition{
|
||||
Name: "show",
|
||||
Description: "Show current configuration",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "model",
|
||||
Description: "Current model and provider",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.GetModelInfo == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
name, provider := rt.GetModelInfo()
|
||||
return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Current channel",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agents",
|
||||
Description: "Registered agents",
|
||||
Handler: agentsHandler(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package commands
|
||||
|
||||
import "context"
|
||||
|
||||
func startCommand() Definition {
|
||||
return Definition{
|
||||
Name: "start",
|
||||
Description: "Start the bot",
|
||||
Usage: "/start",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply("Hello! I am PicoClaw 🦞")
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func switchCommand() Definition {
|
||||
return Definition{
|
||||
Name: "switch",
|
||||
Description: "Switch model",
|
||||
SubCommands: []SubCommand{
|
||||
{
|
||||
Name: "model",
|
||||
Description: "Switch to a different model",
|
||||
ArgsUsage: "to <name>",
|
||||
Handler: func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.SwitchModel == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
// Parse: /switch model to <value>
|
||||
value := nthToken(req.Text, 3) // tokens: [/switch, model, to, <value>]
|
||||
if nthToken(req.Text, 2) != "to" || value == "" {
|
||||
return req.Reply("Usage: /switch model to <name>")
|
||||
}
|
||||
oldModel, err := rt.SwitchModel(value)
|
||||
if err != nil {
|
||||
return req.Reply(err.Error())
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "channel",
|
||||
Description: "Moved to /check channel",
|
||||
Handler: func(_ context.Context, req Request, _ *Runtime) error {
|
||||
return req.Reply("This command has moved. Please use: /check channel <name>")
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSwitchModel_Success(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old-model", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Switched model from old-model to gpt-4"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_MissingToKeyword(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /switch model to <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_MissingValue(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /switch model to <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_Error(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "", fmt.Errorf("model not found")
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to bad-model",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "model not found" {
|
||||
t.Fatalf("reply=%q, want error message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchModel_NilDep(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Command unavailable in current context." {
|
||||
t.Fatalf("reply=%q, want unavailable message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchChannel_Redirect(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch channel to telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "This command has moved. Please use: /check channel <name>"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_Success(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
want := "Channel 'telegram' is available and enabled"
|
||||
if reply != want {
|
||||
t.Fatalf("reply=%q, want=%q", reply, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_Error(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return fmt.Errorf("channel '%s' not found", value)
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel unknown",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "channel 'unknown' not found" {
|
||||
t.Fatalf("reply=%q, want error message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_NilDep(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel telegram",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Command unavailable in current context." {
|
||||
t.Fatalf("reply=%q, want unavailable message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckChannel_MissingValue(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchChannel: func(value string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/check channel",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /check channel <name>" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitch_BangPrefix(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
SwitchModel: func(value string) (string, error) {
|
||||
return "old", nil
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "!switch model to gpt-4",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Switched model from old to gpt-4" {
|
||||
t.Fatalf("! prefix: reply=%q, want success message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitch_NoSubCommand(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/switch",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
// Should get usage message from executor's sub-command routing
|
||||
if reply == "" {
|
||||
t.Fatal("expected usage reply for bare /switch")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SubCommand defines a single sub-command within a parent command.
|
||||
type SubCommand struct {
|
||||
Name string
|
||||
Description string
|
||||
ArgsUsage string // optional, e.g. "<session-id>"
|
||||
Handler Handler
|
||||
}
|
||||
|
||||
// Definition is the single-source metadata and behavior contract for a slash command.
|
||||
//
|
||||
// Design notes (phase 1):
|
||||
// - Every channel reads command shape from this type instead of keeping local copies.
|
||||
// - Visibility is global: all definitions are considered available to all channels.
|
||||
// - Platform menu registration (for example Telegram BotCommand) also derives from this
|
||||
// same definition so UI labels and runtime behavior stay aligned.
|
||||
type Definition struct {
|
||||
Name string
|
||||
Description string
|
||||
Usage string // for simple commands; ignored when SubCommands is set
|
||||
Aliases []string
|
||||
SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers
|
||||
Handler Handler // for simple commands without sub-commands
|
||||
}
|
||||
|
||||
// EffectiveUsage returns the usage string. When SubCommands are present,
|
||||
// it is auto-generated from sub-command names so metadata and behavior
|
||||
// cannot drift.
|
||||
func (d Definition) EffectiveUsage() string {
|
||||
if len(d.SubCommands) == 0 {
|
||||
return d.Usage
|
||||
}
|
||||
names := make([]string, 0, len(d.SubCommands))
|
||||
for _, sc := range d.SubCommands {
|
||||
name := sc.Name
|
||||
if sc.ArgsUsage != "" {
|
||||
name += " " + sc.ArgsUsage
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|"))
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) {
|
||||
d := Definition{Name: "start", Usage: "/start"}
|
||||
if got := d.EffectiveUsage(); got != "/start" {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) {
|
||||
d := Definition{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
{Name: "channel"},
|
||||
{Name: "agents"},
|
||||
},
|
||||
}
|
||||
want := "/show [model|channel|agents]"
|
||||
if got := d.EffectiveUsage(); got != want {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) {
|
||||
d := Definition{
|
||||
Name: "session",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "list"},
|
||||
{Name: "resume", ArgsUsage: "<id>"},
|
||||
},
|
||||
}
|
||||
want := "/session [list|resume <id>]"
|
||||
if got := d.EffectiveUsage(); got != want {
|
||||
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Outcome int
|
||||
|
||||
const (
|
||||
// OutcomePassthrough means this input should continue through normal agent flow.
|
||||
OutcomePassthrough Outcome = iota
|
||||
// OutcomeHandled means a command handler executed (with or without handler error).
|
||||
OutcomeHandled
|
||||
)
|
||||
|
||||
type ExecuteResult struct {
|
||||
Outcome Outcome
|
||||
Command string
|
||||
Err error
|
||||
}
|
||||
|
||||
type Executor struct {
|
||||
reg *Registry
|
||||
rt *Runtime
|
||||
}
|
||||
|
||||
func NewExecutor(reg *Registry, rt *Runtime) *Executor {
|
||||
return &Executor{reg: reg, rt: rt}
|
||||
}
|
||||
|
||||
// Execute implements a two-state command decision:
|
||||
// 1) handled: execute command immediately;
|
||||
// 2) passthrough: not a command or intentionally deferred to agent logic.
|
||||
func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult {
|
||||
cmdName, ok := parseCommandName(req.Text)
|
||||
if !ok {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough}
|
||||
}
|
||||
|
||||
if e == nil || e.reg == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
|
||||
}
|
||||
|
||||
def, found := e.reg.Lookup(cmdName)
|
||||
if !found {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
|
||||
}
|
||||
|
||||
return e.executeDefinition(ctx, req, def)
|
||||
}
|
||||
|
||||
func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult {
|
||||
// Ensure Reply is always non-nil so handlers don't need to check.
|
||||
if req.Reply == nil {
|
||||
req.Reply = func(string) error { return nil }
|
||||
}
|
||||
|
||||
// Simple command — no sub-commands
|
||||
if len(def.SubCommands) == 0 {
|
||||
if def.Handler == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
|
||||
}
|
||||
err := def.Handler(ctx, req, e.rt)
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
|
||||
// Sub-command routing
|
||||
subName := nthToken(req.Text, 1)
|
||||
if subName == "" {
|
||||
err := req.Reply("Usage: " + def.EffectiveUsage())
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
|
||||
normalized := normalizeCommandName(subName)
|
||||
for _, sc := range def.SubCommands {
|
||||
if normalizeCommandName(sc.Name) == normalized {
|
||||
if sc.Handler == nil {
|
||||
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
|
||||
}
|
||||
err := sc.Handler(ctx, req, e.rt)
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// Unknown sub-command
|
||||
err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage()))
|
||||
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{{Name: "show"}}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{{Name: "show"}}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
Aliases: []string{"display"},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "show" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "clear",
|
||||
Aliases: []string{"reset"},
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if res.Command != "clear" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "clear")
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{Name: "placeholder"},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "placeholder" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) {
|
||||
// With Lookup-based dispatch, the first registered definition for a name wins.
|
||||
// A definition with nil Handler and no SubCommands returns Passthrough.
|
||||
defs := []Definition{
|
||||
{Name: "placeholder"},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if res.Command != "placeholder" {
|
||||
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_HandlerErrorIsPropagated(t *testing.T) {
|
||||
wantErr := errors.New("handler failed")
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
return wantErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !errors.Is(res.Err, wantErr) {
|
||||
t.Fatalf("err=%v, want=%v", res.Err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) {
|
||||
called := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "help",
|
||||
Handler: func(context.Context, Request, *Runtime) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !called {
|
||||
t.Fatalf("expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) {
|
||||
modelCalled := false
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error {
|
||||
modelCalled = true
|
||||
return nil
|
||||
}},
|
||||
{Name: "channel"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Text: "/show model"})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !modelCalled {
|
||||
t.Fatal("model sub-command handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
{Name: "channel"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show",
|
||||
Reply: func(text string) error { reply = text; return nil },
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if reply != "Usage: /show [model|channel]" {
|
||||
t.Fatalf("reply=%q, want usage message", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Text: "/show foobar",
|
||||
Reply: func(text string) error { reply = text; return nil },
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if !strings.Contains(reply, "foobar") {
|
||||
t.Fatalf("reply=%q, should mention unknown sub-command", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{
|
||||
Name: "show",
|
||||
SubCommands: []SubCommand{
|
||||
{Name: "model"}, // nil Handler
|
||||
},
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(defs), nil)
|
||||
|
||||
res := ex.Execute(context.Background(), Request{Text: "/show model"})
|
||||
if res.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// agentsHandler returns a shared handler for both /show agents and /list agents.
|
||||
func agentsHandler() Handler {
|
||||
return func(_ context.Context, req Request, rt *Runtime) error {
|
||||
if rt == nil || rt.ListAgentIDs == nil {
|
||||
return req.Reply(unavailableMsg)
|
||||
}
|
||||
ids := rt.ListAgentIDs()
|
||||
if len(ids) == 0 {
|
||||
return req.Reply("No agents registered")
|
||||
}
|
||||
return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package commands
|
||||
|
||||
type Registry struct {
|
||||
defs []Definition
|
||||
index map[string]int
|
||||
}
|
||||
|
||||
// NewRegistry stores the canonical command set used by both dispatch and
|
||||
// optional platform registration adapters.
|
||||
func NewRegistry(defs []Definition) *Registry {
|
||||
stored := make([]Definition, len(defs))
|
||||
copy(stored, defs)
|
||||
|
||||
index := make(map[string]int, len(stored)*2)
|
||||
for i, def := range stored {
|
||||
registerCommandName(index, def.Name, i)
|
||||
for _, alias := range def.Aliases {
|
||||
registerCommandName(index, alias, i)
|
||||
}
|
||||
}
|
||||
|
||||
return &Registry{defs: stored, index: index}
|
||||
}
|
||||
|
||||
// Definitions returns all registered command definitions.
|
||||
// Command availability is global and no longer channel-scoped.
|
||||
func (r *Registry) Definitions() []Definition {
|
||||
out := make([]Definition, len(r.defs))
|
||||
copy(out, r.defs)
|
||||
return out
|
||||
}
|
||||
|
||||
// Lookup returns a command definition by normalized command name or alias.
|
||||
func (r *Registry) Lookup(name string) (Definition, bool) {
|
||||
key := normalizeCommandName(name)
|
||||
if key == "" {
|
||||
return Definition{}, false
|
||||
}
|
||||
idx, ok := r.index[key]
|
||||
if !ok {
|
||||
return Definition{}, false
|
||||
}
|
||||
return r.defs[idx], true
|
||||
}
|
||||
|
||||
func registerCommandName(index map[string]int, name string, defIndex int) {
|
||||
key := normalizeCommandName(name)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := index[key]; exists {
|
||||
return
|
||||
}
|
||||
index[key] = defIndex
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package commands
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRegistry_Definitions_ReturnsCopy(t *testing.T) {
|
||||
defs := []Definition{
|
||||
{Name: "help", Description: "Show help"},
|
||||
{Name: "admin", Description: "Admin command"},
|
||||
}
|
||||
r := NewRegistry(defs)
|
||||
|
||||
got := r.Definitions()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("definitions len = %d, want 2", len(got))
|
||||
}
|
||||
|
||||
got[0].Name = "mutated"
|
||||
again := r.Definitions()
|
||||
if again[0].Name != "help" {
|
||||
t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) {
|
||||
r := NewRegistry([]Definition{
|
||||
{Name: "Help", Aliases: []string{"Assist"}},
|
||||
{Name: "List"},
|
||||
})
|
||||
|
||||
def, ok := r.Lookup("help")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("HELP")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("assist")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
|
||||
def, ok = r.Lookup("ASSIST")
|
||||
if !ok || def.Name != "Help" {
|
||||
t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Handler func(ctx context.Context, req Request, rt *Runtime) error
|
||||
|
||||
type Request struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
SenderID string
|
||||
Text string
|
||||
Reply func(text string) error
|
||||
}
|
||||
|
||||
const unavailableMsg = "Command unavailable in current context."
|
||||
|
||||
var commandPrefixes = []string{"/", "!"}
|
||||
|
||||
// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then
|
||||
// normalizes to lowercase command names.
|
||||
func parseCommandName(input string) (string, bool) {
|
||||
token := nthToken(input, 0)
|
||||
if token == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
name, ok := trimCommandPrefix(token)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if i := strings.Index(name, "@"); i >= 0 {
|
||||
name = name[:i]
|
||||
}
|
||||
name = normalizeCommandName(name)
|
||||
if name == "" {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func trimCommandPrefix(token string) (string, bool) {
|
||||
for _, prefix := range commandPrefixes {
|
||||
if strings.HasPrefix(token, prefix) {
|
||||
return strings.TrimPrefix(token, prefix), true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// HasCommandPrefix returns true if the input starts with a recognized
|
||||
// command prefix (e.g. "/" or "!").
|
||||
func HasCommandPrefix(input string) bool {
|
||||
token := nthToken(input, 0)
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := trimCommandPrefix(token)
|
||||
return ok
|
||||
}
|
||||
|
||||
// nthToken returns the 0-indexed token from whitespace-split input.
|
||||
func nthToken(input string, n int) string {
|
||||
parts := strings.Fields(strings.TrimSpace(input))
|
||||
if n >= len(parts) {
|
||||
return ""
|
||||
}
|
||||
return parts[n]
|
||||
}
|
||||
|
||||
func normalizeCommandName(name string) string {
|
||||
return strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package commands
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHasCommandPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"/help", true},
|
||||
{"!help", true},
|
||||
{"/switch model to gpt-4", true},
|
||||
{"!switch model to gpt-4", true},
|
||||
{"hello", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"hello /world", false},
|
||||
{"/", true},
|
||||
{"!", true},
|
||||
{" /help", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := HasCommandPrefix(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
import "github.com/sipeed/picoclaw/pkg/config"
|
||||
|
||||
// Runtime provides runtime dependencies to command handlers. It is constructed
|
||||
// per-request by the agent loop so that per-request state (like session scope)
|
||||
// can coexist with long-lived callbacks (like GetModelInfo).
|
||||
type Runtime struct {
|
||||
Config *config.Config
|
||||
GetModelInfo func() (name, provider string)
|
||||
ListAgentIDs func() []string
|
||||
ListDefinitions func() []Definition
|
||||
GetEnabledChannels func() []string
|
||||
SwitchModel func(value string) (oldModel string, err error)
|
||||
SwitchChannel func(value string) error
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShowListHandlers_ChannelPolicy(t *testing.T) {
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil)
|
||||
|
||||
var telegramReply string
|
||||
handled := ex.Execute(context.Background(), Request{
|
||||
Channel: "telegram",
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
telegramReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if handled.Outcome != OutcomeHandled {
|
||||
t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled)
|
||||
}
|
||||
if telegramReply != "Current Channel: telegram" {
|
||||
t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram")
|
||||
}
|
||||
|
||||
var whatsappReply string
|
||||
handledWhatsApp := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/show channel",
|
||||
Reply: func(text string) error {
|
||||
whatsappReply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if handledWhatsApp.Outcome != OutcomeHandled {
|
||||
t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled)
|
||||
}
|
||||
if handledWhatsApp.Command != "show" {
|
||||
t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show")
|
||||
}
|
||||
if whatsappReply != "Current Channel: whatsapp" {
|
||||
t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp")
|
||||
}
|
||||
|
||||
passthrough := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/foo",
|
||||
})
|
||||
if passthrough.Outcome != OutcomePassthrough {
|
||||
t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough)
|
||||
}
|
||||
if passthrough.Command != "foo" {
|
||||
t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) {
|
||||
rt := &Runtime{
|
||||
GetEnabledChannels: func() []string {
|
||||
return []string{"telegram"}
|
||||
},
|
||||
}
|
||||
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
|
||||
|
||||
var reply string
|
||||
res := ex.Execute(context.Background(), Request{
|
||||
Channel: "whatsapp",
|
||||
Text: "/list channels",
|
||||
Reply: func(text string) error {
|
||||
reply = text
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if res.Outcome != OutcomeHandled {
|
||||
t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled)
|
||||
}
|
||||
if res.Command != "list" {
|
||||
t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list")
|
||||
}
|
||||
if !strings.Contains(reply, "telegram") {
|
||||
t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply)
|
||||
}
|
||||
}
|
||||
+121
-35
@@ -167,22 +167,35 @@ type SessionConfig struct {
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
// RoutingConfig controls the intelligent model routing feature.
|
||||
// When enabled, each incoming message is scored against structural features
|
||||
// (message length, code blocks, tool call history, conversation depth, attachments).
|
||||
// Messages scoring below Threshold are sent to LightModel; all others use the
|
||||
// agent's primary model. This reduces cost and latency for simple tasks without
|
||||
// requiring any keyword matching — all scoring is language-agnostic.
|
||||
type RoutingConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
|
||||
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
|
||||
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
Routing *RoutingConfig `json:"routing,omitempty"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
@@ -526,6 +539,10 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
Enabled bool `json:"enabled" env:"ENABLED"`
|
||||
}
|
||||
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
@@ -550,6 +567,12 @@ type PerplexityConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type SearXNGConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type GLMSearchConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
|
||||
@@ -561,11 +584,13 @@ type GLMSearchConfig struct {
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Brave BraveConfig `json:"brave"`
|
||||
Tavily TavilyConfig `json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
GLMSearch GLMSearchConfig `json:"glm_search"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
|
||||
Brave BraveConfig ` json:"brave"`
|
||||
Tavily TavilyConfig ` json:"tavily"`
|
||||
DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig ` json:"perplexity"`
|
||||
SearXNG SearXNGConfig ` json:"searxng"`
|
||||
GLMSearch GLMSearchConfig ` json:"glm_search"`
|
||||
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
|
||||
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
|
||||
@@ -573,19 +598,29 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
|
||||
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
|
||||
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
|
||||
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
|
||||
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
|
||||
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
|
||||
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
|
||||
Registries SkillsRegistriesConfig ` json:"registries"`
|
||||
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
|
||||
SearchCache SearchCacheConfig ` json:"search_cache"`
|
||||
}
|
||||
|
||||
type MediaCleanupConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
|
||||
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
|
||||
Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
|
||||
ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"`
|
||||
MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"`
|
||||
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
@@ -597,12 +632,19 @@ type ToolsConfig struct {
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
Registries SkillsRegistriesConfig `json:"registries"`
|
||||
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
|
||||
SearchCache SearchCacheConfig `json:"search_cache"`
|
||||
AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"`
|
||||
EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"`
|
||||
FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"`
|
||||
I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"`
|
||||
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
|
||||
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
|
||||
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
|
||||
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
|
||||
WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"`
|
||||
}
|
||||
|
||||
type SearchCacheConfig struct {
|
||||
@@ -648,8 +690,7 @@ type MCPServerConfig struct {
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
// Enabled globally enables/disables MCP integration
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
@@ -835,3 +876,48 @@ func (c *Config) ValidateModelList() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
switch name {
|
||||
case "web":
|
||||
return t.Web.Enabled
|
||||
case "cron":
|
||||
return t.Cron.Enabled
|
||||
case "exec":
|
||||
return t.Exec.Enabled
|
||||
case "skills":
|
||||
return t.Skills.Enabled
|
||||
case "media_cleanup":
|
||||
return t.MediaCleanup.Enabled
|
||||
case "append_file":
|
||||
return t.AppendFile.Enabled
|
||||
case "edit_file":
|
||||
return t.EditFile.Enabled
|
||||
case "find_skills":
|
||||
return t.FindSkills.Enabled
|
||||
case "i2c":
|
||||
return t.I2C.Enabled
|
||||
case "install_skill":
|
||||
return t.InstallSkill.Enabled
|
||||
case "list_dir":
|
||||
return t.ListDir.Enabled
|
||||
case "message":
|
||||
return t.Message.Enabled
|
||||
case "read_file":
|
||||
return t.ReadFile.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spi":
|
||||
return t.SPI.Enabled
|
||||
case "subagent":
|
||||
return t.Subagent.Enabled
|
||||
case "web_fetch":
|
||||
return t.WebFetch.Enabled
|
||||
case "write_file":
|
||||
return t.WriteFile.Enabled
|
||||
case "mcp":
|
||||
return t.MCP.Enabled
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
+63
-2
@@ -336,11 +336,16 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
Enabled: true,
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
MaxAge: 30,
|
||||
Interval: 5,
|
||||
},
|
||||
Web: WebToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Proxy: "",
|
||||
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
|
||||
Brave: BraveConfig{
|
||||
@@ -357,6 +362,11 @@ func DefaultConfig() *Config {
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
SearXNG: SearXNGConfig{
|
||||
Enabled: false,
|
||||
BaseURL: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
GLMSearch: GLMSearchConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
@@ -366,12 +376,22 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ExecTimeoutMinutes: 5,
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
EnableDenyPatterns: true,
|
||||
TimeoutSeconds: 60,
|
||||
},
|
||||
Skills: SkillsToolsConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Registries: SkillsRegistriesConfig{
|
||||
ClawHub: ClawHubRegistryConfig{
|
||||
Enabled: true,
|
||||
@@ -385,9 +405,50 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
ToolConfig: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
AppendFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
EditFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
FindSkills: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
I2C: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
InstallSkill: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ListDir: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Message: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
ReadFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SPI: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
Subagent: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
WebFetch: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
WriteFile: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
+13
-3
@@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
mcpCfg := config.MCPConfig{
|
||||
Enabled: true,
|
||||
ToolConfig: config.ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"test-server": {
|
||||
Enabled: true,
|
||||
@@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) {
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
|
||||
err := mgr.LoadFromMCPConfig(
|
||||
context.Background(),
|
||||
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
|
||||
"/tmp",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
|
||||
}
|
||||
|
||||
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
|
||||
err = mgr.LoadFromMCPConfig(
|
||||
context.Background(),
|
||||
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
|
||||
"/tmp",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when no servers configured, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading loadCodeAssist response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("loadCodeAssist failed: %s", string(body))
|
||||
}
|
||||
@@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf(
|
||||
"fetchAvailableModels failed (HTTP %d): %s",
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package routing
|
||||
|
||||
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
|
||||
// A higher score indicates a more complex task that benefits from a heavy model.
|
||||
// The score is compared against the configured threshold: score >= threshold selects
|
||||
// the primary (heavy) model; score < threshold selects the light model.
|
||||
//
|
||||
// Classifier is an interface so that future implementations (ML-based, embedding-based,
|
||||
// or any other approach) can be swapped in without changing routing infrastructure.
|
||||
type Classifier interface {
|
||||
Score(f Features) float64
|
||||
}
|
||||
|
||||
// RuleClassifier is the v1 implementation.
|
||||
// It uses a weighted sum of structural signals with no external dependencies,
|
||||
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
|
||||
// that the returned score always falls within the [0, 1] contract.
|
||||
//
|
||||
// Individual weights (multiple signals can fire simultaneously):
|
||||
//
|
||||
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
|
||||
// token 50-200: 0.15 — medium length; may or may not be complex
|
||||
// code block present: 0.40 — coding tasks need the heavy model
|
||||
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
|
||||
// tool calls 1-3 (recent): 0.10 — some tool activity
|
||||
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
|
||||
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
|
||||
//
|
||||
// Default threshold is 0.35, so:
|
||||
// - Pure greetings / trivial Q&A: 0.00 → light ✓
|
||||
// - Medium prose message (50–200 tokens): 0.15 → light ✓
|
||||
// - Message with code block: 0.40 → heavy ✓
|
||||
// - Long message (>200 tokens): 0.35 → heavy ✓
|
||||
// - Active tool session + medium message: 0.25 → light (acceptable)
|
||||
// - Any message with an image/audio attachment: 1.00 → heavy ✓
|
||||
type RuleClassifier struct{}
|
||||
|
||||
// Score computes the complexity score for the given feature set.
|
||||
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
|
||||
func (c *RuleClassifier) Score(f Features) float64 {
|
||||
// Hard gate: multi-modal inputs always require the heavy model.
|
||||
if f.HasAttachments {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
var score float64
|
||||
|
||||
// Token estimate — primary verbosity signal
|
||||
switch {
|
||||
case f.TokenEstimate > 200:
|
||||
score += 0.35
|
||||
case f.TokenEstimate > 50:
|
||||
score += 0.15
|
||||
}
|
||||
|
||||
// Fenced code blocks — strongest indicator of a coding/technical task
|
||||
if f.CodeBlockCount > 0 {
|
||||
score += 0.40
|
||||
}
|
||||
|
||||
// Recent tool call density — indicates an ongoing agentic workflow
|
||||
switch {
|
||||
case f.RecentToolCalls > 3:
|
||||
score += 0.25
|
||||
case f.RecentToolCalls > 0:
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Conversation depth — accumulated context implies compound task
|
||||
if f.ConversationDepth > 10 {
|
||||
score += 0.10
|
||||
}
|
||||
|
||||
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
|
||||
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
|
||||
if score > 1.0 {
|
||||
score = 1.0
|
||||
}
|
||||
return score
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// lookbackWindow is the number of recent history entries scanned for tool calls.
|
||||
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
|
||||
const lookbackWindow = 6
|
||||
|
||||
// Features holds the structural signals extracted from a message and its session context.
|
||||
// Every dimension is language-agnostic by construction — no keyword or pattern matching
|
||||
// against natural-language content. This ensures consistent routing for all locales.
|
||||
type Features struct {
|
||||
// TokenEstimate is a proxy for token count.
|
||||
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
|
||||
// This avoids API calls while giving accurate estimates for all scripts.
|
||||
TokenEstimate int
|
||||
|
||||
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
|
||||
// Coding tasks almost always require the heavy model.
|
||||
CodeBlockCount int
|
||||
|
||||
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
|
||||
// history entries. A high density indicates an active agentic workflow.
|
||||
RecentToolCalls int
|
||||
|
||||
// ConversationDepth is the total number of messages in the session history.
|
||||
// Deep sessions tend to carry implicit complexity built up over many turns.
|
||||
ConversationDepth int
|
||||
|
||||
// HasAttachments is true when the message appears to contain media (images,
|
||||
// audio, video). Multi-modal inputs require vision-capable heavy models.
|
||||
HasAttachments bool
|
||||
}
|
||||
|
||||
// ExtractFeatures computes the structural feature vector for a message.
|
||||
// It is a pure function with no side effects and zero allocations beyond
|
||||
// the returned struct.
|
||||
func ExtractFeatures(msg string, history []providers.Message) Features {
|
||||
return Features{
|
||||
TokenEstimate: estimateTokens(msg),
|
||||
CodeBlockCount: countCodeBlocks(msg),
|
||||
RecentToolCalls: countRecentToolCalls(history),
|
||||
ConversationDepth: len(history),
|
||||
HasAttachments: hasAttachments(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
|
||||
// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one
|
||||
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
|
||||
// for English). Splitting the count this way avoids the 3x underestimation that a
|
||||
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
|
||||
func estimateTokens(msg string) int {
|
||||
total := utf8.RuneCountInString(msg)
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
cjk := 0
|
||||
for _, r := range msg {
|
||||
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
|
||||
cjk++
|
||||
}
|
||||
}
|
||||
return cjk + (total-cjk)/4
|
||||
}
|
||||
|
||||
// countCodeBlocks counts the number of complete fenced code blocks.
|
||||
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
|
||||
// An unclosed opening fence (odd count) is treated as zero complete blocks
|
||||
// since it may just be an inline code span or a typo.
|
||||
func countCodeBlocks(msg string) int {
|
||||
n := strings.Count(msg, "```")
|
||||
return n / 2
|
||||
}
|
||||
|
||||
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
|
||||
// entries of history. It examines the ToolCalls field rather than parsing
|
||||
// the content string, so it is robust to any message format.
|
||||
func countRecentToolCalls(history []providers.Message) int {
|
||||
start := len(history) - lookbackWindow
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, msg := range history[start:] {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
count += len(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// hasAttachments returns true when the message content contains embedded media.
|
||||
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
|
||||
// common image/audio URL extensions. This is intentionally conservative —
|
||||
// false negatives (missing an attachment) just mean the routing falls back to
|
||||
// the primary model anyway.
|
||||
func hasAttachments(msg string) bool {
|
||||
lower := strings.ToLower(msg)
|
||||
|
||||
// Base64 data URIs embedded directly in the message
|
||||
if strings.Contains(lower, "data:image/") ||
|
||||
strings.Contains(lower, "data:audio/") ||
|
||||
strings.Contains(lower, "data:video/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Common image/audio extensions in URLs or file references
|
||||
mediaExts := []string{
|
||||
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
|
||||
".mp3", ".wav", ".ogg", ".m4a", ".flac",
|
||||
".mp4", ".avi", ".mov", ".webm",
|
||||
}
|
||||
for _, ext := range mediaExts {
|
||||
if strings.Contains(lower, ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// defaultThreshold is used when the config threshold is zero or negative.
|
||||
// At 0.35 a message needs at least one strong signal (code block, long text,
|
||||
// or an attachment) before the heavy model is chosen.
|
||||
const defaultThreshold = 0.35
|
||||
|
||||
// RouterConfig holds the validated model routing settings.
|
||||
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
|
||||
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
|
||||
type RouterConfig struct {
|
||||
// LightModel is the model_name (from model_list) used for simple tasks.
|
||||
LightModel string
|
||||
|
||||
// Threshold is the complexity score cutoff in [0, 1].
|
||||
// score >= Threshold → primary (heavy) model.
|
||||
// score < Threshold → light model.
|
||||
Threshold float64
|
||||
}
|
||||
|
||||
// Router selects the appropriate model tier for each incoming message.
|
||||
// It is safe for concurrent use from multiple goroutines.
|
||||
type Router struct {
|
||||
cfg RouterConfig
|
||||
classifier Classifier
|
||||
}
|
||||
|
||||
// New creates a Router with the given config and the default RuleClassifier.
|
||||
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
|
||||
func New(cfg RouterConfig) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{
|
||||
cfg: cfg,
|
||||
classifier: &RuleClassifier{},
|
||||
}
|
||||
}
|
||||
|
||||
// newWithClassifier creates a Router with a custom Classifier.
|
||||
// Intended for unit tests that need to inject a deterministic scorer.
|
||||
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
|
||||
if cfg.Threshold <= 0 {
|
||||
cfg.Threshold = defaultThreshold
|
||||
}
|
||||
return &Router{cfg: cfg, classifier: c}
|
||||
}
|
||||
|
||||
// SelectModel returns the model to use for this conversation turn along with
|
||||
// the computed complexity score (for logging and debugging).
|
||||
//
|
||||
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
|
||||
// - Otherwise: returns (primaryModel, false, score)
|
||||
//
|
||||
// The caller is responsible for resolving the returned model name into
|
||||
// provider candidates (see AgentInstance.LightCandidates).
|
||||
func (r *Router) SelectModel(
|
||||
msg string,
|
||||
history []providers.Message,
|
||||
primaryModel string,
|
||||
) (model string, usedLight bool, score float64) {
|
||||
features := ExtractFeatures(msg, history)
|
||||
score = r.classifier.Score(features)
|
||||
if score < r.cfg.Threshold {
|
||||
return r.cfg.LightModel, true, score
|
||||
}
|
||||
return primaryModel, false, score
|
||||
}
|
||||
|
||||
// LightModel returns the configured light model name.
|
||||
func (r *Router) LightModel() string {
|
||||
return r.cfg.LightModel
|
||||
}
|
||||
|
||||
// Threshold returns the complexity threshold in use.
|
||||
func (r *Router) Threshold() float64 {
|
||||
return r.cfg.Threshold
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// ── ExtractFeatures ──────────────────────────────────────────────────────────
|
||||
|
||||
func TestExtractFeatures_EmptyMessage(t *testing.T) {
|
||||
f := ExtractFeatures("", nil)
|
||||
if f.TokenEstimate != 0 {
|
||||
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
|
||||
}
|
||||
if f.CodeBlockCount != 0 {
|
||||
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
|
||||
}
|
||||
if f.RecentToolCalls != 0 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
|
||||
}
|
||||
if f.ConversationDepth != 0 {
|
||||
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
|
||||
}
|
||||
if f.HasAttachments {
|
||||
t.Error("HasAttachments: got true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate(t *testing.T) {
|
||||
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
|
||||
msg := strings.Repeat("a", 30)
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 7 {
|
||||
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
|
||||
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
|
||||
// Using a rune slice literal avoids CJK string literals in source.
|
||||
msg := string([]rune{
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60, 0x597D, 0x4E16, 0x754C,
|
||||
0x4F60,
|
||||
})
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 9 {
|
||||
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
|
||||
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
|
||||
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
|
||||
f := ExtractFeatures(msg, nil)
|
||||
if f.TokenEstimate != 6 {
|
||||
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_CodeBlocks(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"no code here", 0},
|
||||
{"```go\nfmt.Println()\n```", 1},
|
||||
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
|
||||
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.CodeBlockCount != tc.want {
|
||||
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
|
||||
// History longer than lookbackWindow — only last lookbackWindow entries count.
|
||||
history := make([]providers.Message, 10)
|
||||
// Put 2 tool calls at positions 8 and 9 (within the last 6)
|
||||
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
|
||||
history[9] = providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
|
||||
}
|
||||
// Position 3 is outside the lookback window and must NOT be counted
|
||||
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
|
||||
|
||||
f := ExtractFeatures("test", history)
|
||||
// 1 (position 8) + 2 (position 9) = 3
|
||||
if f.RecentToolCalls != 3 {
|
||||
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_ConversationDepth(t *testing.T) {
|
||||
history := make([]providers.Message, 7)
|
||||
f := ExtractFeatures("msg", history)
|
||||
if f.ConversationDepth != 7 {
|
||||
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"plain text", false},
|
||||
{"here is an image: data:image/png;base64,abc123", true},
|
||||
{"audio: data:audio/mp3;base64,xyz", true},
|
||||
{"video: data:video/mp4;base64,xyz", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want bool
|
||||
}{
|
||||
{"check out photo.jpg", true},
|
||||
{"see screenshot.png", true},
|
||||
{"listen to audio.mp3", true},
|
||||
{"watch clip.mp4", true},
|
||||
{"just a .go file", false},
|
||||
{"document.pdf", false}, // pdf is not in the media list
|
||||
}
|
||||
for _, tc := range cases {
|
||||
f := ExtractFeatures(tc.msg, nil)
|
||||
if f.HasAttachments != tc.want {
|
||||
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── RuleClassifier ───────────────────────────────────────────────────────────
|
||||
|
||||
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{})
|
||||
if score != 0.0 {
|
||||
t.Errorf("zero features: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
score := c.Score(Features{HasAttachments: true})
|
||||
if score != 1.0 {
|
||||
t.Errorf("attachments: got %f, want 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Code block alone = 0.40, above default threshold 0.35
|
||||
score := c.Score(Features{CodeBlockCount: 1})
|
||||
if score < 0.35 {
|
||||
t.Errorf("code block: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_LongMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// >200 tokens = 0.35, exactly at default threshold → heavy
|
||||
score := c.Score(Features{TokenEstimate: 250})
|
||||
if score < 0.35 {
|
||||
t.Errorf("long message: score %f is below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_MediumMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// 50-200 tokens = 0.15, below threshold → light
|
||||
score := c.Score(Features{TokenEstimate: 100})
|
||||
if score >= 0.35 {
|
||||
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ShortMessage(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// <50 tokens, no other signals = 0.0 → light
|
||||
score := c.Score(Features{TokenEstimate: 10})
|
||||
if score != 0.0 {
|
||||
t.Errorf("short message: got %f, want 0.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
|
||||
scoreNone := c.Score(Features{RecentToolCalls: 0})
|
||||
scoreLow := c.Score(Features{RecentToolCalls: 2})
|
||||
scoreHigh := c.Score(Features{RecentToolCalls: 5})
|
||||
|
||||
if scoreNone != 0.0 {
|
||||
t.Errorf("no tools: got %f, want 0.0", scoreNone)
|
||||
}
|
||||
if scoreLow <= scoreNone {
|
||||
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
|
||||
}
|
||||
if scoreHigh <= scoreLow {
|
||||
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_DeepConversation(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
shallow := c.Score(Features{ConversationDepth: 5})
|
||||
deep := c.Score(Features{ConversationDepth: 15})
|
||||
if deep <= shallow {
|
||||
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
|
||||
c := &RuleClassifier{}
|
||||
// Max all signals simultaneously
|
||||
f := Features{
|
||||
TokenEstimate: 500,
|
||||
CodeBlockCount: 3,
|
||||
RecentToolCalls: 10,
|
||||
ConversationDepth: 20,
|
||||
}
|
||||
score := c.Score(f)
|
||||
if score > 1.0 {
|
||||
t.Errorf("score %f exceeds 1.0", score)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Router ───────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestRouter_DefaultThreshold(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash"})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
|
||||
if r.Threshold() != defaultThreshold {
|
||||
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "hi"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("simple message: expected light model to be selected")
|
||||
}
|
||||
if model != "gemini-flash" {
|
||||
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "```go\nfmt.Println(\"hello\")\n```"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("code block: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
msg := "can you analyze this? data:image/png;base64,abc123"
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("attachment: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
// >200 token estimate: 210 * 3 = 630 chars
|
||||
msg := strings.Repeat("word ", 210)
|
||||
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("long message: expected primary model to be selected")
|
||||
}
|
||||
if model != "claude-sonnet-4-6" {
|
||||
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
|
||||
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
|
||||
// Routing is conservative: only promote to heavy when the signal is unambiguous.
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
|
||||
}
|
||||
msg := "ok"
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
|
||||
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
|
||||
history := []providers.Message{
|
||||
{Role: "assistant", ToolCalls: []providers.ToolCall{
|
||||
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
|
||||
}},
|
||||
}
|
||||
// ~55 tokens * 3 = 165 chars
|
||||
msg := strings.Repeat("word ", 55)
|
||||
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
|
||||
// Very low threshold: even a short message triggers heavy model
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
|
||||
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if usedLight {
|
||||
t.Error("low threshold: medium message should use primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
|
||||
// Very high threshold: even code blocks route to light
|
||||
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
|
||||
msg := "```go\nfmt.Println()\n```"
|
||||
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
|
||||
if !usedLight {
|
||||
t.Error("very high threshold: code block (0.40) should route to light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_LightModel(t *testing.T) {
|
||||
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
|
||||
if r.LightModel() != "my-fast-model" {
|
||||
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
|
||||
}
|
||||
}
|
||||
|
||||
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
|
||||
|
||||
type fixedScoreClassifier struct{ score float64 }
|
||||
|
||||
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
|
||||
|
||||
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.2},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if !usedLight {
|
||||
t.Error("low score with custom classifier: expected light model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.8},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("high score with custom classifier: expected primary model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
|
||||
// score == threshold → primary (uses >= comparison)
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.5},
|
||||
)
|
||||
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
|
||||
if usedLight {
|
||||
t.Error("score == threshold: expected primary model (>= threshold → primary)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
|
||||
r := newWithClassifier(
|
||||
RouterConfig{LightModel: "light", Threshold: 0.5},
|
||||
&fixedScoreClassifier{score: 0.42},
|
||||
)
|
||||
_, _, score := r.SelectModel("anything", nil, "heavy")
|
||||
if score != 0.42 {
|
||||
t.Errorf("score: got %f, want 0.42", score)
|
||||
}
|
||||
}
|
||||
@@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall(
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
|
||||
tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
@@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall(
|
||||
// --- HTTP helper ---
|
||||
|
||||
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
req, err := c.newGetRequest(ctx, urlStr, "application/json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
resp, err := utils.DoRequestWithRetry(c.client, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", accept)
|
||||
if c.authToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.authToken)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) {
|
||||
req, err := c.newGetRequest(ctx, urlStr, "application/zip")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := utils.DoRequestWithRetry(c.client, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
errBody := make([]byte, 512)
|
||||
n, _ := io.ReadFull(resp.Body, errBody)
|
||||
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
cleanup := func() {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
|
||||
src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1)
|
||||
written, err := io.Copy(tmpFile, src)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download write failed: %w", err)
|
||||
}
|
||||
|
||||
if written > int64(c.maxZipSize) {
|
||||
cleanup()
|
||||
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
@@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) {
|
||||
assert.Equal(t, "clawhub", results[0].RegistryName)
|
||||
}
|
||||
|
||||
func TestClawHubRegistrySearchRetries429(t *testing.T) {
|
||||
attempts := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("rate limited"))
|
||||
return
|
||||
}
|
||||
|
||||
slug := "github"
|
||||
name := "GitHub Integration"
|
||||
summary := "Interact with GitHub repos"
|
||||
version := "1.0.0"
|
||||
|
||||
json.NewEncoder(w).Encode(clawhubSearchResponse{
|
||||
Results: []clawhubSearchResult{
|
||||
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
results, err := reg.Search(context.Background(), "github", 5)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
assert.Equal(t, 2, attempts)
|
||||
assert.Equal(t, "github", results[0].Slug)
|
||||
}
|
||||
|
||||
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
|
||||
@@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
|
||||
assert.Contains(t, string(readmeContent), "# Test Skill")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) {
|
||||
zipBuf := createTestZip(t, map[string]string{
|
||||
"SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill",
|
||||
})
|
||||
|
||||
downloadAttempts := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/skills/retry-skill":
|
||||
json.NewEncoder(w).Encode(clawhubSkillResponse{
|
||||
Slug: "retry-skill",
|
||||
DisplayName: "Retry Skill",
|
||||
Summary: "A retry test skill",
|
||||
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
|
||||
})
|
||||
case "/api/v1/download":
|
||||
downloadAttempts++
|
||||
if downloadAttempts == 1 {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte("rate limited"))
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "retry-skill", r.URL.Query().Get("slug"))
|
||||
w.Header().Set("Content-Type", "application/zip")
|
||||
w.Write(zipBuf)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
targetDir := filepath.Join(tmpDir, "retry-skill")
|
||||
|
||||
reg := newTestRegistry(srv.URL, "")
|
||||
result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "1.0.0", result.Version)
|
||||
assert.Equal(t, 2, downloadAttempts)
|
||||
|
||||
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(skillContent), "Hello skill")
|
||||
}
|
||||
|
||||
func TestClawHubRegistryAuthToken(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
|
||||
+6
-1
@@ -132,9 +132,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
|
||||
timeout := 60 * time.Second
|
||||
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
return &ExecTool{
|
||||
workingDir: workingDir,
|
||||
timeout: 60 * time.Second,
|
||||
timeout: timeout,
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
|
||||
+71
-1
@@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type SearXNGSearchProvider struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general",
|
||||
strings.TrimSuffix(p.baseURL, "/"),
|
||||
url.QueryEscape(query))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
Engine string `json:"engine"`
|
||||
Score float64 `json:"score"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
// Limit results to requested count
|
||||
if len(result.Results) > count {
|
||||
result.Results = result.Results[:count]
|
||||
}
|
||||
|
||||
// Format results in standard PicoClaw format
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query))
|
||||
for i, r := range result.Results {
|
||||
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title))
|
||||
b.WriteString(fmt.Sprintf(" %s\n", r.URL))
|
||||
if r.Content != "" {
|
||||
b.WriteString(fmt.Sprintf(" %s\n", r.Content))
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
type GLMSearchProvider struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -495,6 +557,9 @@ type WebSearchToolOptions struct {
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
SearXNGBaseURL string
|
||||
SearXNGMaxResults int
|
||||
SearXNGEnabled bool
|
||||
GLMSearchAPIKey string
|
||||
GLMSearchBaseURL string
|
||||
GLMSearchEngine string
|
||||
@@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
|
||||
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
|
||||
if err != nil {
|
||||
@@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
}
|
||||
} else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" {
|
||||
provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL}
|
||||
if opts.SearXNGMaxResults > 0 {
|
||||
maxResults = opts.SearXNGMaxResults
|
||||
}
|
||||
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
|
||||
client, err := createHTTPClient(opts.Proxy, searchTimeout)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user