Merge branch 'sipeed:main' into fix/reload-config-selfkill-guard

This commit is contained in:
mosir
2026-03-06 18:27:46 +08:00
committed by GitHub
72 changed files with 4512 additions and 556 deletions
+57 -1
View File
@@ -42,6 +42,9 @@ type ContextBuilder struct {
}
func getGlobalConfigDir() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return home
}
home, err := os.UserHomeDir()
if err != nil {
return ""
@@ -602,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
}
}
return sanitized
// Second pass: ensure every assistant message with tool_calls has matching
// tool result messages following it. This is required by strict providers
// like DeepSeek that enforce: "An assistant message with 'tool_calls' must
// be followed by tool messages responding to each 'tool_call_id'."
final := make([]providers.Message, 0, len(sanitized))
for i := 0; i < len(sanitized); i++ {
msg := sanitized[i]
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
// Collect expected tool_call IDs
expected := make(map[string]bool, len(msg.ToolCalls))
for _, tc := range msg.ToolCalls {
expected[tc.ID] = false
}
// Check following messages for matching tool results
toolMsgCount := 0
for j := i + 1; j < len(sanitized); j++ {
if sanitized[j].Role != "tool" {
break
}
toolMsgCount++
if _, exists := expected[sanitized[j].ToolCallID]; exists {
expected[sanitized[j].ToolCallID] = true
}
}
// If any tool_call_id is missing, drop this assistant message and its partial tool messages
allFound := true
for toolCallID, found := range expected {
if !found {
allFound = false
logger.DebugCF(
"agent",
"Dropping assistant message with incomplete tool results",
map[string]any{
"missing_tool_call_id": toolCallID,
"expected_count": len(expected),
"found_count": toolMsgCount,
},
)
break
}
}
if !allFound {
// Skip this assistant message and its tool messages
i += toolMsgCount
continue
}
}
final = append(final, msg)
}
return final
}
func (cb *ContextBuilder) AddToolResult(
+74
View File
@@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
}
}
}
// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation
// that ensures assistant messages with tool_calls have ALL matching tool results.
// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be
// followed by tool messages responding to each 'tool_call_id'."
func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) {
// Assistant expects tool results for both A and B, but only A is present
history := []providers.Message{
msg("user", "do two things"),
assistantWithTools("A", "B"),
toolResult("A"),
// toolResult("B") is missing - this would cause DeepSeek to fail
msg("user", "next question"),
msg("assistant", "answer"),
}
result := sanitizeHistoryForProvider(history)
// The assistant message with incomplete tool results should be dropped,
// along with its partial tool result. The remaining messages are:
// user ("do two things"), user ("next question"), assistant ("answer")
if len(result) != 3 {
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "user", "assistant")
}
// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where
// an assistant message has tool_calls but no tool results follow at all.
func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) {
history := []providers.Message{
msg("user", "do something"),
assistantWithTools("A"),
// No tool results at all
msg("user", "hello"),
msg("assistant", "hi"),
}
result := sanitizeHistoryForProvider(history)
// The assistant message with no tool results should be dropped.
// Remaining: user ("do something"), user ("hello"), assistant ("hi")
if len(result) != 3 {
t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "user", "assistant")
}
// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that
// incomplete tool results in the middle of a conversation are properly handled.
func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) {
history := []providers.Message{
msg("user", "first"),
assistantWithTools("A"),
toolResult("A"),
msg("assistant", "done"),
msg("user", "second"),
assistantWithTools("B", "C"),
toolResult("B"),
// toolResult("C") is missing
msg("user", "third"),
assistantWithTools("D"),
toolResult("D"),
msg("assistant", "all done"),
}
result := sanitizeHistoryForProvider(history)
// First round is complete (user, assistant+tools, tool, assistant),
// second round is incomplete and dropped (assistant+tools, partial tool),
// third round is complete (user, assistant+tools, tool, assistant).
// Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant
if len(result) != 9 {
t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
}
assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant")
}
+55 -12
View File
@@ -37,6 +37,14 @@ type AgentInstance struct {
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
// Router is non-nil when model routing is configured and the light model
// was successfully resolved. It scores each incoming message and decides
// whether to route to LightCandidates or stay with Candidates.
Router *routing.Router
// LightCandidates holds the resolved provider candidates for the light model.
// Pre-computed at agent creation to avoid repeated model_list lookups at runtime.
LightCandidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -60,17 +68,30 @@ func NewAgentInstance(
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
if cfg.Tools.IsToolEnabled("read_file") {
toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("list_dir") {
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
}
if cfg.Tools.IsToolEnabled("edit_file") {
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
}
if cfg.Tools.IsToolEnabled("append_file") {
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
}
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -167,6 +188,25 @@ func NewAgentInstance(
candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
var lightCandidates []providers.FallbackCandidate
if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" {
lightModelCfg := providers.ModelConfig{Primary: rc.LightModel}
resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList)
if len(resolved) > 0 {
router = routing.New(routing.RouterConfig{
LightModel: rc.LightModel,
Threshold: rc.Threshold,
})
lightCandidates = resolved
} else {
log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q",
rc.LightModel, agentID)
}
}
return &AgentInstance{
ID: agentID,
Name: agentName,
@@ -187,6 +227,8 @@ func NewAgentInstance(
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
Router: router,
LightCandidates: lightCandidates,
}
}
@@ -195,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
return expandHome(strings.TrimSpace(agentCfg.Workspace))
}
// Use the configured default workspace (respects PICOCLAW_HOME)
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
return expandHome(defaults.Workspace)
}
home, _ := os.UserHomeDir()
// For named agents without explicit workspace, use default workspace with agent ID suffix
id := routing.NormalizeAgentID(agentCfg.ID)
return filepath.Join(home, ".picoclaw", "workspace-"+id)
return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id)
}
// resolveAgentModel resolves the primary model for an agent.
+273 -165
View File
@@ -21,6 +21,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -46,6 +47,7 @@ type AgentLoop struct {
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
cmdRegistry *commands.Registry
}
// processOptions configures how a message is processed
@@ -61,7 +63,15 @@ type processOptions struct {
NoHistory bool // If true, don't load session history (for heartbeat)
}
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
const (
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
sessionKeyAgentPrefix = "agent:"
metadataKeyAccountID = "account_id"
metadataKeyGuildID = "guild_id"
metadataKeyTeamID = "team_id"
metadataKeyParentPeerKind = "parent_peer_kind"
metadataKeyParentPeerID = "parent_peer_id"
)
func NewAgentLoop(
cfg *config.Config,
@@ -84,14 +94,17 @@ func NewAgentLoop(
stateManager = state.NewManager(defaultAgent.Workspace)
}
return &AgentLoop{
al := &AgentLoop{
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
}
return al
}
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
@@ -108,76 +121,106 @@ func registerSharedTools(
}
// Web tools
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
if cfg.Tools.IsToolEnabled("web") {
searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
})
if err != nil {
logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
} else if searchTool != nil {
agent.Tools.Register(searchTool)
}
}
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
if cfg.Tools.IsToolEnabled("web_fetch") {
fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
if err != nil {
logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
} else {
agent.Tools.Register(fetchTool)
}
}
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
agent.Tools.Register(tools.NewSPITool())
if cfg.Tools.IsToolEnabled("i2c") {
agent.Tools.Register(tools.NewI2CTool())
}
if cfg.Tools.IsToolEnabled("spi") {
agent.Tools.Register(tools.NewSPITool())
}
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
if cfg.Tools.IsToolEnabled("message") {
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
})
})
agent.Tools.Register(messageTool)
agent.Tools.Register(messageTool)
}
// Skill discovery and installation tools
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
install_skills_enable := cfg.Tools.IsToolEnabled("install_skill")
if skills_enabled && (find_skills_enable || install_skills_enable) {
registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
})
if find_skills_enable {
searchCache := skills.NewSearchCache(
cfg.Tools.Skills.SearchCache.MaxSize,
time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
)
agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
}
if install_skills_enable {
agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
}
}
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
}
}
}
@@ -185,7 +228,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
// Initialize MCP servers for all agents
if al.cfg.Tools.MCP.Enabled {
if al.cfg.Tools.IsToolEnabled("mcp") {
mcpManager := mcp.NewManager()
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
@@ -227,6 +270,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
if !ok {
continue
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
agent.Tools.Register(mcpTool)
totalRegistrations++
@@ -518,27 +562,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Check for commands
if response, handled := al.handleCommand(ctx, msg); handled {
route, agent, routeErr := al.resolveMessageRoute(msg)
// Commands are checked before requiring a successful route.
// Global commands (/help, /show, /switch) work even when routing fails;
// context-dependent commands check their own Runtime fields and report
// "unavailable" when the required capability is nil.
if response, handled := al.handleCommand(ctx, msg, agent); handled {
return response, nil
}
// Route to determine agent and session key
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: msg.Metadata["account_id"],
Peer: extractPeer(msg),
ParentPeer: extractParentPeer(msg),
GuildID: msg.Metadata["guild_id"],
TeamID: msg.Metadata["team_id"],
})
agent, ok := al.registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
}
if agent == nil {
return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
if routeErr != nil {
return "", routeErr
}
// Reset message-tool state for this round so we don't skip publishing due to a previous round.
@@ -548,17 +583,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
}
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
sessionKey := route.SessionKey
if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") {
sessionKey = msg.SessionKey
}
// Resolve session key from route, while preserving explicit agent-scoped keys.
scopeKey := resolveScopeKey(route, msg.SessionKey)
sessionKey := scopeKey
logger.InfoCF("agent", "Routed message",
map[string]any{
"agent_id": agent.ID,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
"agent_id": agent.ID,
"scope_key": scopeKey,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
"route_agent": route.AgentID,
"route_channel": route.Channel,
})
return al.runAgentLoop(ctx, agent, processOptions{
@@ -573,6 +609,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
})
}
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: inboundMetadata(msg, metadataKeyAccountID),
Peer: extractPeer(msg),
ParentPeer: extractParentPeer(msg),
GuildID: inboundMetadata(msg, metadataKeyGuildID),
TeamID: inboundMetadata(msg, metadataKeyTeamID),
})
agent, ok := al.registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
}
if agent == nil {
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
}
return route, agent, nil
}
func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) {
return msgSessionKey
}
return route.SessionKey
}
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
@@ -793,6 +857,12 @@ func (al *AgentLoop) runLLMIteration(
iteration := 0
var finalContent string
// Determine effective model tier for this conversation turn.
// selectCandidates evaluates routing once and the decision is sticky for
// all tool-follow-up iterations within the same turn so that a multi-step
// tool chain doesn't switch models mid-way through.
activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
for iteration < agent.MaxIterations {
iteration++
@@ -811,7 +881,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": agent.Model,
"model": activeModel,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": agent.MaxTokens,
@@ -827,7 +897,7 @@ func (al *AgentLoop) runLLMIteration(
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM with fallback chain if candidates are configured.
// Call LLM with fallback chain if multiple candidates are configured.
var response *providers.LLMResponse
var err error
@@ -848,10 +918,10 @@ func (al *AgentLoop) runLLMIteration(
}
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
agent.Candidates,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
},
@@ -869,7 +939,7 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts)
return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
}
// Retry loop for context/token errors
@@ -1138,6 +1208,44 @@ func (al *AgentLoop) runLLMIteration(
return finalContent, iteration, nil
}
// selectCandidates returns the model candidates and resolved model name to use
// for a conversation turn. When model routing is configured and the incoming
// message scores below the complexity threshold, it returns the light model
// candidates instead of the primary ones.
//
// The returned (candidates, model) pair is used for all LLM calls within one
// turn — tool follow-up iterations use the same tier as the initial call so
// that a multi-step tool chain doesn't switch models mid-way.
func (al *AgentLoop) selectCandidates(
agent *AgentInstance,
userMsg string,
history []providers.Message,
) (candidates []providers.FallbackCandidate, model string) {
if agent.Router == nil || len(agent.LightCandidates) == 0 {
return agent.Candidates, agent.Model
}
_, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model)
if !usedLight {
logger.DebugCF("agent", "Model routing: primary model selected",
map[string]any{
"agent_id": agent.ID,
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.Candidates, agent.Model
}
logger.InfoCF("agent", "Model routing: light model selected",
map[string]any{
"agent_id": agent.ID,
"light_model": agent.Router.LightModel(),
"score": score,
"threshold": agent.Router.Threshold(),
})
return agent.LightCandidates, agent.Router.LightModel()
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
@@ -1429,94 +1537,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
return totalChars * 2 / 5
}
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
content := strings.TrimSpace(msg.Content)
if !strings.HasPrefix(content, "/") {
func (al *AgentLoop) handleCommand(
ctx context.Context,
msg bus.InboundMessage,
agent *AgentInstance,
) (string, bool) {
if !commands.HasCommandPrefix(msg.Content) {
return "", false
}
parts := strings.Fields(content)
if len(parts) == 0 {
if al.cmdRegistry == nil {
return "", false
}
cmd := parts[0]
args := parts[1:]
rt := al.buildCommandsRuntime(agent)
executor := commands.NewExecutor(al.cmdRegistry, rt)
switch cmd {
case "/show":
if len(args) < 1 {
return "Usage: /show [model|channel|agents]", true
}
switch args[0] {
case "model":
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
return fmt.Sprintf("Current model: %s", defaultAgent.Model), true
case "channel":
return fmt.Sprintf("Current channel: %s", msg.Channel), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown show target: %s", args[0]), true
}
var commandReply string
result := executor.Execute(ctx, commands.Request{
Channel: msg.Channel,
ChatID: msg.ChatID,
SenderID: msg.SenderID,
Text: msg.Content,
Reply: func(text string) error {
commandReply = text
return nil
},
})
case "/list":
if len(args) < 1 {
return "Usage: /list [models|channels|agents]", true
switch result.Outcome {
case commands.OutcomeHandled:
if result.Err != nil {
return mapCommandError(result), true
}
switch args[0] {
case "models":
return "Available models: configured in config.json per agent", true
case "channels":
if commandReply != "" {
return commandReply, true
}
return "", true
default: // OutcomePassthrough — let the message fall through to LLM
return "", false
}
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime {
rt := &commands.Runtime{
Config: al.cfg,
ListAgentIDs: al.registry.ListAgentIDs,
ListDefinitions: al.cmdRegistry.Definitions,
GetEnabledChannels: func() []string {
if al.channelManager == nil {
return "Channel manager not initialized", true
return nil
}
channels := al.channelManager.GetEnabledChannels()
if len(channels) == 0 {
return "No channels enabled", true
}
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
case "agents":
agentIDs := al.registry.ListAgentIDs()
return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true
default:
return fmt.Sprintf("Unknown list target: %s", args[0]), true
}
case "/switch":
if len(args) < 3 || args[1] != "to" {
return "Usage: /switch [model|channel] to <name>", true
}
target := args[0]
value := args[2]
switch target {
case "model":
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
return "No default agent configured", true
}
oldModel := defaultAgent.Model
defaultAgent.Model = value
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
case "channel":
return al.channelManager.GetEnabledChannels()
},
SwitchChannel: func(value string) error {
if al.channelManager == nil {
return "Channel manager not initialized", true
return fmt.Errorf("channel manager not initialized")
}
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
return fmt.Errorf("channel '%s' not found or not enabled", value)
}
return fmt.Sprintf("Switched target channel to %s", value), true
default:
return fmt.Sprintf("Unknown switch target: %s", target), true
return nil
},
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, al.cfg.Agents.Defaults.Provider
}
rt.SwitchModel = func(value string) (string, error) {
oldModel := agent.Model
agent.Model = value
return oldModel, nil
}
}
return rt
}
return "", false
func mapCommandError(result commands.ExecuteResult) string {
if result.Command == "" {
return fmt.Sprintf("Failed to execute command: %v", result.Err)
}
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
}
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
@@ -1535,10 +1636,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
}
func inboundMetadata(msg bus.InboundMessage, key string) string {
if msg.Metadata == nil {
return ""
}
return msg.Metadata[key]
}
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
parentKind := msg.Metadata["parent_peer_kind"]
parentID := msg.Metadata["parent_peer_id"]
parentKind := inboundMetadata(msg, metadataKeyParentPeerKind)
parentID := inboundMetadata(msg, metadataKeyParentPeerID)
if parentKind == "" || parentID == "" {
return nil
}
+221 -10
View File
@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -227,16 +228,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Workspace = tmpDir
cfg.Agents.Defaults.Model = "test-model"
cfg.Agents.Defaults.MaxTokens = 4096
cfg.Agents.Defaults.MaxToolIterations = 10
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
@@ -323,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
type countingMockProvider struct {
response string
calls int
}
func (m *countingMockProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *countingMockProvider) GetDefaultModel() string {
return "counting-mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
@@ -364,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
const responseTimeout = 3 * time.Second
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "ok"}
al := NewAgentLoop(cfg, msgBus, provider)
msg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
route := al.registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
Peer: extractPeer(msg),
})
sessionKey := route.SessionKey
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
helper := testHelper{al: al}
_ = helper.executeAndGetResponse(t, context.Background(), msg)
history := defaultAgent.Sessions.GetHistory(sessionKey)
if len(history) != 2 {
t.Fatalf("expected session history len=2, got %d", len(history))
}
if history[0].Role != "user" || history[0].Content != "hello" {
t.Fatalf("unexpected first message in session: %+v", history[0])
}
}
func TestProcessMessage_CommandOutcomes(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Session: config.SessionConfig{
DMScope: "per-channel-peer",
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
baseMsg := bus.InboundMessage{
Channel: "whatsapp",
SenderID: "user1",
ChatID: "chat1",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/show channel",
Peer: baseMsg.Peer,
})
if showResp != "Current Channel: whatsapp" {
t.Fatalf("unexpected /show reply: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls)
}
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/foo",
Peer: baseMsg.Peer,
})
if fooResp != "LLM reply" {
t.Fatalf("unexpected /foo reply: %q", fooResp)
}
if provider.calls != 1 {
t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls)
}
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: baseMsg.Channel,
SenderID: baseMsg.SenderID,
ChatID: baseMsg.ChatID,
Content: "/new",
Peer: baseMsg.Peer,
})
if newResp != "LLM reply" {
t.Fatalf("unexpected /new reply: %q", newResp)
}
if provider.calls != 2 {
t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls)
}
}
func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Provider: "openai",
Model: "before-switch",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &countingMockProvider{response: "LLM reply"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/switch model to after-switch",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") {
t.Fatalf("unexpected /switch reply: %q", switchResp)
}
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "/show model",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
})
if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") {
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
}
if provider.calls != 0 {
t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+20 -5
View File
@@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device code response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
@@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device code response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
@@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading device token response: %w", err)
}
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
@@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading token refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
@@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading token exchange response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
+3
View File
@@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool {
}
func authFilePath() string {
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
return filepath.Join(home, "auth.json")
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
+70
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
@@ -26,6 +27,12 @@ const (
sendTimeout = 10 * time.Second
)
var (
// Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call)
channelRefRe = regexp.MustCompile(`<#(\d+)>`)
msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`)
)
type DiscordChannel struct {
*channels.BaseChannel
session *discordgo.Session
@@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = c.stripBotMention(content)
}
// Resolve Discord refs in main content before concatenation to avoid
// double-expanding links that appear in the referenced message.
content = c.resolveDiscordRefs(s, content, m.GuildID)
// Prepend referenced (quoted) message content if this is a reply
if m.MessageReference != nil && m.ReferencedMessage != nil {
refContent := m.ReferencedMessage.Content
if refContent != "" {
refAuthor := "unknown"
if m.ReferencedMessage.Author != nil {
refAuthor = m.ReferencedMessage.Author.Username
}
refContent = c.resolveDiscordRefs(s, refContent, m.GuildID)
content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s",
refAuthor, refContent, content)
}
}
senderID := m.Author.ID
mediaPaths := make([]string, 0, len(m.Attachments))
@@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
return nil
}
// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and
// expands Discord message links to show the linked message content.
// Only links pointing to the same guild are expanded to prevent cross-guild leakage.
func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string {
// 1. Resolve channel references: <#id> → #channel-name
text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string {
parts := channelRefRe.FindStringSubmatch(match)
if len(parts) < 2 {
return match
}
// Prefer session state cache to avoid API calls
if ch, err := s.State.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
if ch, err := s.Channel(parts[1]); err == nil {
return "#" + ch.Name
}
return match
})
// 2. Expand Discord message links (max 3, same guild only)
matches := msgLinkRe.FindAllStringSubmatch(text, 3)
for _, m := range matches {
if len(m) < 4 {
continue
}
linkGuildID, channelID, messageID := m[1], m[2], m[3]
// Security: only expand links from the same guild
if linkGuildID != guildID {
continue
}
msg, err := s.ChannelMessage(channelID, messageID)
if err != nil || msg == nil || msg.Content == "" {
continue
}
author := "unknown"
if msg.Author != nil {
author = msg.Author.Username
}
text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content)
}
return text
}
// stripBotMention removes the bot mention from the message content.
// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
func (c *DiscordChannel) stripBotMention(text string) string {
@@ -0,0 +1,98 @@
package discord
import (
"testing"
)
func TestChannelRefRegex(t *testing.T) {
tests := []struct {
name string
input string
wantID string
wantOK bool
}{
{"basic channel ref", "<#123456789>", "123456789", true},
{"long id", "<#9876543210123456>", "9876543210123456", true},
{"no match plain text", "hello world", "", false},
{"no match partial", "<#>", "", false},
{"no match letters", "<#abc>", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := channelRefRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 2 || matches[1] != tt.wantID {
t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID)
}
} else {
if len(matches) >= 2 {
t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex(t *testing.T) {
tests := []struct {
name string
input string
wantGuild string
wantChan string
wantMsg string
wantOK bool
}{
{
"discord.com link",
"https://discord.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"discordapp.com link",
"https://discordapp.com/channels/111/222/333",
"111", "222", "333", true,
},
{
"real world ids",
"check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please",
"9000000000000001", "9000000000000002", "9000000000000003", true,
},
{"no match http", "http://discord.com/channels/1/2/3", "", "", "", false},
{"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false},
{"no match plain text", "hello world", "", "", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := msgLinkRe.FindStringSubmatch(tt.input)
if tt.wantOK {
if len(matches) < 4 {
t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s",
tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg)
}
if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg {
t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s",
tt.input, matches[1], matches[2], matches[3],
tt.wantGuild, tt.wantChan, tt.wantMsg)
}
} else {
if len(matches) >= 4 {
t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches)
}
}
})
}
}
func TestMsgLinkRegex_MultipleMatches(t *testing.T) {
input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12"
matches := msgLinkRe.FindAllStringSubmatch(input, 3)
if len(matches) != 3 {
t.Fatalf("expected 3 matches (capped), got %d", len(matches))
}
// Verify the 3rd match is 7/8/9 (not 10/11/12)
if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" {
t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2])
}
}
+12 -1
View File
@@ -1,6 +1,10 @@
package channels
import "context"
import (
"context"
"github.com/sipeed/picoclaw/pkg/commands"
)
// TypingCapable — channels that can show a typing/thinking indicator.
// StartTyping begins the indicator and returns a stop function.
@@ -39,3 +43,10 @@ type PlaceholderRecorder interface {
RecordTypingStop(channel, chatID string, stop func())
RecordReactionUndo(channel, chatID string, undo func())
}
// CommandRegistrarCapable is implemented by channels that can register
// command menus with their upstream platform (e.g. Telegram BotCommand).
// Channels that do not support platform-level command menus can ignore it.
type CommandRegistrarCapable interface {
RegisterCommands(ctx context.Context, defs []commands.Definition) error
}
+16
View File
@@ -0,0 +1,16 @@
package channels
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/commands"
)
type mockRegistrar struct{}
func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil }
func TestCommandRegistrarCapable_Compiles(t *testing.T) {
var _ CommandRegistrarCapable = mockRegistrar{}
}
+4 -1
View File
@@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err))
}
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody)))
}
@@ -0,0 +1,116 @@
package telegram
import (
"context"
"math/rand"
"slices"
"time"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/logger"
)
var commandRegistrationBackoff = []time.Duration{
5 * time.Second,
15 * time.Second,
60 * time.Second,
5 * time.Minute,
10 * time.Minute,
}
func commandRegistrationDelay(attempt int) time.Duration {
if len(commandRegistrationBackoff) == 0 {
return 0
}
base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)]
// Full jitter in [0.5, 1.0) to avoid synchronized retries across instances.
return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5))
}
// RegisterCommands registers bot commands on Telegram platform.
func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error {
botCommands := make([]telego.BotCommand, 0, len(defs))
for _, def := range defs {
if def.Name == "" || def.Description == "" {
continue
}
botCommands = append(botCommands, telego.BotCommand{
Command: def.Name,
Description: def.Description,
})
}
current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{})
if err != nil {
// If we can't read current commands, fall through to set them.
logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally",
map[string]any{"error": err.Error()})
} else if slices.Equal(current, botCommands) {
logger.DebugCF("telegram", "Bot commands are up to date", nil)
return nil
}
return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
Commands: botCommands,
})
}
func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) {
if len(defs) == 0 {
return
}
register := c.registerFunc
if register == nil {
register = c.RegisterCommands
}
regCtx, cancel := context.WithCancel(ctx)
c.commandRegCancel = cancel
// Registration runs asynchronously so Telegram message intake is never blocked
// by temporary upstream API failures. Retry stops on success or channel shutdown.
go func() {
attempt := 0
timer := time.NewTimer(0)
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
defer timer.Stop()
for {
err := register(regCtx, defs)
if err == nil {
logger.InfoCF("telegram", "Telegram commands registered", map[string]any{
"count": len(defs),
})
return
}
delay := commandRegistrationDelay(attempt)
logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{
"error": err.Error(),
"retry_after": delay.String(),
})
attempt++
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(delay)
select {
case <-regCtx.Done():
return
case <-timer.C:
}
}
}()
}
@@ -0,0 +1,96 @@
package telegram
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/commands"
)
func TestStartCommandRegistration_DoesNotBlock(t *testing.T) {
ch := &TelegramChannel{}
started := make(chan struct{}, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch.registerFunc = func(context.Context, []commands.Definition) error {
started <- struct{}{}
return errors.New("temporary failure")
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}})
select {
case <-started:
case <-time.After(time.Second):
t.Fatal("registration did not start asynchronously")
}
}
func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) {
ch := &TelegramChannel{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
origBackoff := commandRegistrationBackoff
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
defer func() { commandRegistrationBackoff = origBackoff }()
var attempts atomic.Int32
ch.registerFunc = func(context.Context, []commands.Definition) error {
n := attempts.Add(1)
if n < 3 {
return errors.New("temporary failure")
}
return nil
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
deadline := time.Now().Add(250 * time.Millisecond)
for time.Now().Before(deadline) {
if attempts.Load() >= 3 {
break
}
time.Sleep(5 * time.Millisecond)
}
if attempts.Load() < 3 {
t.Fatalf("expected at least 3 attempts, got %d", attempts.Load())
}
stable := attempts.Load()
time.Sleep(30 * time.Millisecond)
if attempts.Load() != stable {
t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load())
}
}
func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) {
ch := &TelegramChannel{}
ctx, cancel := context.WithCancel(context.Background())
origBackoff := commandRegistrationBackoff
commandRegistrationBackoff = []time.Duration{5 * time.Millisecond}
defer func() { commandRegistrationBackoff = origBackoff }()
defer cancel()
var attempts atomic.Int32
ch.registerFunc = func(context.Context, []commands.Definition) error {
attempts.Add(1)
return errors.New("always fail")
}
ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}})
time.Sleep(20 * time.Millisecond)
cancel()
time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle
stable := attempts.Load()
time.Sleep(30 * time.Millisecond)
if attempts.Load() != stable {
t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load())
}
}
+75 -95
View File
@@ -7,7 +7,6 @@ import (
"net/url"
"os"
"regexp"
"slices"
"strconv"
"strings"
"time"
@@ -18,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -40,13 +40,15 @@ var (
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
bh *th.BotHandler
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
bot *telego.Bot
bh *th.BotHandler
config *config.Config
chatIDs map[string]int64
ctx context.Context
cancel context.CancelFunc
registerFunc func(context.Context, []commands.Definition) error
commandRegCancel context.CancelFunc
}
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
@@ -93,7 +95,6 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
return &TelegramChannel{
BaseChannel: base,
commands: NewTelegramCommands(bot, cfg),
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
@@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
c.ctx, c.cancel = context.WithCancel(ctx)
if err := c.initBotCommands(c.ctx); err != nil {
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
"error": err.Error(),
})
}
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
@@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
c.bh = bh
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Help(ctx, message)
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
}, th.CommandEqual("show"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.List(ctx, message)
}, th.CommandEqual("list"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
@@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
"username": c.bot.Username(),
})
c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions())
go func() {
if err = bh.Start(); err != nil {
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
@@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
if c.cancel != nil {
c.cancel()
}
return nil
}
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("get commands: %w", err)
}
commands := []telego.BotCommand{
{
Command: "start",
Description: "Start the bot",
},
{
Command: "help",
Description: "Show a help message",
},
{
Command: "show",
Description: "Show current configuration",
},
{
Command: "list",
Description: "List available options",
},
}
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
if !slices.Equal(currentCommands, commands) {
logger.InfoC("telegram", "Updating bot commands")
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
Commands: commands,
Scope: tu.ScopeDefault(),
})
if err != nil {
return fmt.Errorf("set commands: %w", err)
}
} else {
logger.DebugC("telegram", "Bot commands are up to date")
if c.commandRegCancel != nil {
c.commandRegCancel()
}
return nil
@@ -721,34 +661,34 @@ func escapeHTML(text string) string {
// isBotMentioned checks if the bot is mentioned in the message via entities.
func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
botUsername := c.bot.Username()
if botUsername == "" {
text, entities := telegramEntityTextAndList(message)
if text == "" || len(entities) == 0 {
return false
}
entities := message.Entities
if entities == nil {
entities = message.CaptionEntities
botUsername := ""
if c.bot != nil {
botUsername = c.bot.Username()
}
runes := []rune(text)
for _, entity := range entities {
if entity.Type == "mention" {
// Extract the mention text from the message
text := message.Text
if text == "" {
text = message.Caption
}
runes := []rune(text)
end := entity.Offset + entity.Length
if end <= len(runes) {
mention := string(runes[entity.Offset:end])
if strings.EqualFold(mention, "@"+botUsername) {
return true
}
}
entityText, ok := telegramEntityText(runes, entity)
if !ok {
continue
}
if entity.Type == "text_mention" && entity.User != nil {
if entity.User.Username == botUsername {
switch entity.Type {
case telego.EntityTypeMention:
if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) {
return true
}
case telego.EntityTypeTextMention:
if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) {
return true
}
case telego.EntityTypeBotCommand:
if isBotCommandEntityForThisBot(entityText, botUsername) {
return true
}
}
@@ -756,6 +696,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
return false
}
func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) {
if message.Text != "" {
return message.Text, message.Entities
}
return message.Caption, message.CaptionEntities
}
func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) {
if entity.Offset < 0 || entity.Length <= 0 {
return "", false
}
end := entity.Offset + entity.Length
if entity.Offset >= len(runes) || end > len(runes) {
return "", false
}
return string(runes[entity.Offset:end]), true
}
func isBotCommandEntityForThisBot(entityText, botUsername string) bool {
if !strings.HasPrefix(entityText, "/") {
return false
}
command := strings.TrimPrefix(entityText, "/")
if command == "" {
return false
}
at := strings.IndexRune(command, '@')
if at == -1 {
// A bare /command delivered to this bot is intended for this bot.
return true
}
mentionUsername := command[at+1:]
if mentionUsername == "" || botUsername == "" {
return false
}
return strings.EqualFold(mentionUsername, botUsername)
}
// stripBotMention removes the @bot mention from the content.
func (c *TelegramChannel) stripBotMention(content string) string {
botUsername := c.bot.Username()
-156
View File
@@ -1,156 +0,0 @@
package telegram
import (
"context"
"fmt"
"strings"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/config"
)
type TelegramCommander interface {
Help(ctx context.Context, message telego.Message) error
Start(ctx context.Context, message telego.Message) error
Show(ctx context.Context, message telego.Message) error
List(ctx context.Context, message telego.Message) error
}
type cmd struct {
bot *telego.Bot
config *config.Config
}
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
return &cmd{
bot: bot,
config: cfg,
}
}
func commandArgs(text string) string {
parts := strings.SplitN(text, " ", 2)
if len(parts) < 2 {
return ""
}
return strings.TrimSpace(parts[1])
}
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
msg := `/start - Start the bot
/help - Show this help message
/show [model|channel] - Show current configuration
/list [models|channels] - List available options
`
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: msg,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Hello! I am PicoClaw 🦞",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /show [model|channel]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "model":
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
c.config.Agents.Defaults.GetModelName(),
c.config.Agents.Defaults.Provider)
case "channel":
response = "Current Channel: telegram"
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) List(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /list [models|channels]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "models":
provider := c.config.Agents.Defaults.Provider
if provider == "" {
provider = "configured default"
}
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
c.config.Agents.Defaults.GetModelName(), provider)
case "channels":
var enabled []string
if c.config.Channels.Telegram.Enabled {
enabled = append(enabled, "telegram")
}
if c.config.Channels.WhatsApp.Enabled {
enabled = append(enabled, "whatsapp")
}
if c.config.Channels.Feishu.Enabled {
enabled = append(enabled, "feishu")
}
if c.config.Channels.Discord.Enabled {
enabled = append(enabled, "discord")
}
if c.config.Channels.Slack.Enabled {
enabled = append(enabled, "slack")
}
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
@@ -0,0 +1,52 @@
package telegram
import (
"context"
"testing"
"time"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
)
func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
msg := &telego.Message{
Text: "/new",
MessageID: 9,
Chat: telego.Chat{
ID: 123,
Type: "private",
},
From: &telego.User{
ID: 42,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "telegram" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
@@ -0,0 +1,147 @@
package telegram
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
type getMeCaller struct {
username string
}
func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) {
if strings.HasSuffix(url, "/getMe") {
result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username)
return &ta.Response{Ok: true, Result: []byte(result)}, nil
}
return &ta.Response{Ok: true, Result: []byte("true")}, nil
}
func newTestTelegramBot(t *testing.T, username string) *telego.Bot {
t.Helper()
token := "123456:" + strings.Repeat("a", 35)
bot, err := telego.NewBot(token,
telego.WithAPICaller(getMeCaller{username: username}),
telego.WithDiscardLogger(),
)
if err != nil {
t.Fatalf("NewBot error: %v", err)
}
return bot
}
func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) {
t.Helper()
messageBus := bus.NewMessageBus()
ch := &TelegramChannel{
BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil,
channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}),
),
bot: newTestTelegramBot(t, botUsername),
chatIDs: make(map[string]int64),
ctx: context.Background(),
}
return ch, messageBus
}
func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
tests := []struct {
name string
text string
wantForwarded bool
wantContent string
}{
{
name: "command with bot username",
text: "/new@testbot",
wantForwarded: true,
wantContent: "/new",
},
{
name: "bare command",
text: "/new",
wantForwarded: true,
wantContent: "/new",
},
{
name: "command for another bot",
text: "/new@otherbot",
wantForwarded: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ch, messageBus := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: tc.text,
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeBotCommand,
Offset: 0,
Length: len([]rune(tc.text)),
}},
MessageID: 42,
Chat: telego.Chat{
ID: 123,
Type: "group",
},
From: &telego.User{
ID: 7,
FirstName: "Alice",
},
}
if err := ch.handleMessage(context.Background(), msg); err != nil {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
}
})
}
}
func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) {
ch, _ := newGroupMentionOnlyChannel(t, "testbot")
msg := &telego.Message{
Text: "@testbot hello",
Entities: []telego.MessageEntity{{
Type: telego.EntityTypeMention,
Offset: 0,
Length: len("@testbot"),
}},
}
if !ch.isBotMentioned(msg) {
t.Fatal("expected mention entity to be treated as bot mention")
}
}
+4 -1
View File
@@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro
return nil
}
respBody, _ := io.ReadAll(resp.Body)
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err)
}
switch {
case resp.StatusCode == http.StatusTooManyRequests:
return fmt.Errorf("response_url rate limited (%d): %s: %w",
+22 -4
View File
@@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody)))
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom upload error response: %w", readErr),
)
}
return "", channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom upload error: %s", string(respBody)),
)
}
var result struct {
@@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody)))
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading wecom_app error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("wecom_app API error: %s", string(respBody)),
)
}
respBody, err := io.ReadAll(resp.Body)
+11 -2
View File
@@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body)))
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("reading webhook error response: %w", readErr),
)
}
return channels.ClassifySendError(
resp.StatusCode,
fmt.Errorf("webhook API error: %s", string(body)),
)
}
body, err := io.ReadAll(resp.Body)
@@ -0,0 +1,41 @@
package whatsapp
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppChannel{
BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil),
ctx: context.Background(),
}
ch.handleIncomingMessage(map[string]any{
"type": "message",
"id": "mid1",
"from": "user1",
"chat": "chat1",
"content": "/help",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/help" {
t.Fatalf("content=%q", inbound.Content)
}
}
@@ -0,0 +1,56 @@
//go:build whatsapp_native
package whatsapp
import (
"context"
"testing"
"time"
"go.mau.fi/whatsmeow/proto/waE2E"
"go.mau.fi/whatsmeow/types"
"go.mau.fi/whatsmeow/types/events"
"google.golang.org/protobuf/proto"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
messageBus := bus.NewMessageBus()
ch := &WhatsAppNativeChannel{
BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil),
runCtx: context.Background(),
}
evt := &events.Message{
Info: types.MessageInfo{
MessageSource: types.MessageSource{
Sender: types.NewJID("1001", types.DefaultUserServer),
Chat: types.NewJID("1001", types.DefaultUserServer),
},
ID: "mid1",
PushName: "Alice",
},
Message: &waE2E.Message{
Conversation: proto.String("/new"),
},
}
ch.handleIncoming(evt)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
+16
View File
@@ -0,0 +1,16 @@
package commands
// BuiltinDefinitions returns all built-in command definitions.
// Each command group is defined in its own cmd_*.go file.
// Definitions are stateless — runtime dependencies are provided
// via the Runtime parameter passed to handlers at execution time.
func BuiltinDefinitions() []Definition {
return []Definition{
startCommand(),
helpCommand(),
showCommand(),
listCommand(),
switchCommand(),
checkCommand(),
}
}
+145
View File
@@ -0,0 +1,145 @@
package commands
import (
"context"
"strings"
"testing"
)
func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition {
t.Helper()
for _, def := range defs {
if def.Name == name {
return def
}
}
t.Fatalf("missing /%s definition", name)
return Definition{}
}
func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) {
defs := BuiltinDefinitions()
helpDef := findDefinitionByName(t, defs, "help")
if helpDef.Handler == nil {
t.Fatalf("/help handler should not be nil")
}
var reply string
err := helpDef.Handler(context.Background(), Request{
Text: "/help",
Reply: func(text string) error {
reply = text
return nil
},
}, nil)
if err != nil {
t.Fatalf("/help handler error: %v", err)
}
// Now uses auto-generated EffectiveUsage which includes agents
if !strings.Contains(reply, "/show [model|channel|agents]") {
t.Fatalf("/help reply missing /show usage, got %q", reply)
}
if !strings.Contains(reply, "/list [models|channels|agents]") {
t.Fatalf("/help reply missing /list usage, got %q", reply)
}
}
func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) {
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), nil)
cases := []string{"telegram", "whatsapp"}
for _, channel := range cases {
var reply string
res := ex.Execute(context.Background(), Request{
Channel: channel,
Text: "/show channel",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled)
}
want := "Current Channel: " + channel
if reply != want {
t.Fatalf("/show channel reply=%q, want=%q", reply, want)
}
}
}
func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) {
rt := &Runtime{
GetEnabledChannels: func() []string {
return []string{"telegram", "slack"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/list channels",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") {
t.Fatalf("/list channels reply=%q, want telegram and slack", reply)
}
}
func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) {
rt := &Runtime{
ListAgentIDs: func() []string {
return []string{"default", "coder"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show agents",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
t.Fatalf("/show agents reply=%q, want agent IDs", reply)
}
}
func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) {
rt := &Runtime{
ListAgentIDs: func() []string {
return []string{"default", "coder"}
},
}
defs := BuiltinDefinitions()
ex := NewExecutor(NewRegistry(defs), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/list agents",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") {
t.Fatalf("/list agents reply=%q, want agent IDs", reply)
}
}
+33
View File
@@ -0,0 +1,33 @@
package commands
import (
"context"
"fmt"
)
func checkCommand() Definition {
return Definition{
Name: "check",
Description: "Check channel availability",
SubCommands: []SubCommand{
{
Name: "channel",
Description: "Check if a channel is available",
ArgsUsage: "<name>",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.SwitchChannel == nil {
return req.Reply(unavailableMsg)
}
value := nthToken(req.Text, 2)
if value == "" {
return req.Reply("Usage: /check channel <name>")
}
if err := rt.SwitchChannel(value); err != nil {
return req.Reply(err.Error())
}
return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value))
},
},
},
}
}
+44
View File
@@ -0,0 +1,44 @@
package commands
import (
"context"
"fmt"
"strings"
)
func helpCommand() Definition {
return Definition{
Name: "help",
Description: "Show this help message",
Usage: "/help",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
var defs []Definition
if rt != nil && rt.ListDefinitions != nil {
defs = rt.ListDefinitions()
} else {
defs = BuiltinDefinitions()
}
return req.Reply(formatHelpMessage(defs))
},
}
}
func formatHelpMessage(defs []Definition) string {
if len(defs) == 0 {
return "No commands available."
}
lines := make([]string, 0, len(defs))
for _, def := range defs {
usage := def.EffectiveUsage()
if usage == "" {
usage = "/" + def.Name
}
desc := def.Description
if desc == "" {
desc = "No description"
}
lines = append(lines, fmt.Sprintf("%s - %s", usage, desc))
}
return strings.Join(lines, "\n")
}
+52
View File
@@ -0,0 +1,52 @@
package commands
import (
"context"
"fmt"
"strings"
)
func listCommand() Definition {
return Definition{
Name: "list",
Description: "List available options",
SubCommands: []SubCommand{
{
Name: "models",
Description: "Configured models",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetModelInfo == nil {
return req.Reply(unavailableMsg)
}
name, provider := rt.GetModelInfo()
if provider == "" {
provider = "configured default"
}
return req.Reply(fmt.Sprintf(
"Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
name, provider,
))
},
},
{
Name: "channels",
Description: "Enabled channels",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetEnabledChannels == nil {
return req.Reply(unavailableMsg)
}
enabled := rt.GetEnabledChannels()
if len(enabled) == 0 {
return req.Reply("No channels enabled")
}
return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")))
},
},
{
Name: "agents",
Description: "Registered agents",
Handler: agentsHandler(),
},
},
}
}
+38
View File
@@ -0,0 +1,38 @@
package commands
import (
"context"
"fmt"
)
func showCommand() Definition {
return Definition{
Name: "show",
Description: "Show current configuration",
SubCommands: []SubCommand{
{
Name: "model",
Description: "Current model and provider",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.GetModelInfo == nil {
return req.Reply(unavailableMsg)
}
name, provider := rt.GetModelInfo()
return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider))
},
},
{
Name: "channel",
Description: "Current channel",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel))
},
},
{
Name: "agents",
Description: "Registered agents",
Handler: agentsHandler(),
},
},
}
}
+14
View File
@@ -0,0 +1,14 @@
package commands
import "context"
func startCommand() Definition {
return Definition{
Name: "start",
Description: "Start the bot",
Usage: "/start",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply("Hello! I am PicoClaw 🦞")
},
}
}
+42
View File
@@ -0,0 +1,42 @@
package commands
import (
"context"
"fmt"
)
func switchCommand() Definition {
return Definition{
Name: "switch",
Description: "Switch model",
SubCommands: []SubCommand{
{
Name: "model",
Description: "Switch to a different model",
ArgsUsage: "to <name>",
Handler: func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.SwitchModel == nil {
return req.Reply(unavailableMsg)
}
// Parse: /switch model to <value>
value := nthToken(req.Text, 3) // tokens: [/switch, model, to, <value>]
if nthToken(req.Text, 2) != "to" || value == "" {
return req.Reply("Usage: /switch model to <name>")
}
oldModel, err := rt.SwitchModel(value)
if err != nil {
return req.Reply(err.Error())
}
return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value))
},
},
{
Name: "channel",
Description: "Moved to /check channel",
Handler: func(_ context.Context, req Request, _ *Runtime) error {
return req.Reply("This command has moved. Please use: /check channel <name>")
},
},
},
}
}
+279
View File
@@ -0,0 +1,279 @@
package commands
import (
"context"
"fmt"
"testing"
)
func TestSwitchModel_Success(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old-model", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "Switched model from old-model to gpt-4"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestSwitchModel_MissingToKeyword(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /switch model to <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitchModel_MissingValue(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /switch model to <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitchModel_Error(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "", fmt.Errorf("model not found")
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to bad-model",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "model not found" {
t.Fatalf("reply=%q, want error message", reply)
}
}
func TestSwitchModel_NilDep(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Command unavailable in current context." {
t.Fatalf("reply=%q, want unavailable message", reply)
}
}
func TestSwitchChannel_Redirect(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch channel to telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "This command has moved. Please use: /check channel <name>"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestCheckChannel_Success(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
want := "Channel 'telegram' is available and enabled"
if reply != want {
t.Fatalf("reply=%q, want=%q", reply, want)
}
}
func TestCheckChannel_Error(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return fmt.Errorf("channel '%s' not found", value)
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel unknown",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "channel 'unknown' not found" {
t.Fatalf("reply=%q, want error message", reply)
}
}
func TestCheckChannel_NilDep(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel telegram",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Command unavailable in current context." {
t.Fatalf("reply=%q, want unavailable message", reply)
}
}
func TestCheckChannel_MissingValue(t *testing.T) {
rt := &Runtime{
SwitchChannel: func(value string) error {
return nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/check channel",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /check channel <name>" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestSwitch_BangPrefix(t *testing.T) {
rt := &Runtime{
SwitchModel: func(value string) (string, error) {
return "old", nil
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "!switch model to gpt-4",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Switched model from old to gpt-4" {
t.Fatalf("! prefix: reply=%q, want success message", reply)
}
}
func TestSwitch_NoSubCommand(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{})
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/switch",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
// Should get usage message from executor's sub-command routing
if reply == "" {
t.Fatal("expected usage reply for bare /switch")
}
}
+48
View File
@@ -0,0 +1,48 @@
package commands
import (
"fmt"
"strings"
)
// SubCommand defines a single sub-command within a parent command.
type SubCommand struct {
Name string
Description string
ArgsUsage string // optional, e.g. "<session-id>"
Handler Handler
}
// Definition is the single-source metadata and behavior contract for a slash command.
//
// Design notes (phase 1):
// - Every channel reads command shape from this type instead of keeping local copies.
// - Visibility is global: all definitions are considered available to all channels.
// - Platform menu registration (for example Telegram BotCommand) also derives from this
// same definition so UI labels and runtime behavior stay aligned.
type Definition struct {
Name string
Description string
Usage string // for simple commands; ignored when SubCommands is set
Aliases []string
SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers
Handler Handler // for simple commands without sub-commands
}
// EffectiveUsage returns the usage string. When SubCommands are present,
// it is auto-generated from sub-command names so metadata and behavior
// cannot drift.
func (d Definition) EffectiveUsage() string {
if len(d.SubCommands) == 0 {
return d.Usage
}
names := make([]string, 0, len(d.SubCommands))
for _, sc := range d.SubCommands {
name := sc.Name
if sc.ArgsUsage != "" {
name += " " + sc.ArgsUsage
}
names = append(names, name)
}
return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|"))
}
+41
View File
@@ -0,0 +1,41 @@
package commands
import (
"testing"
)
func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) {
d := Definition{Name: "start", Usage: "/start"}
if got := d.EffectiveUsage(); got != "/start" {
t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start")
}
}
func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) {
d := Definition{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
{Name: "channel"},
{Name: "agents"},
},
}
want := "/show [model|channel|agents]"
if got := d.EffectiveUsage(); got != want {
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
}
}
func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) {
d := Definition{
Name: "session",
SubCommands: []SubCommand{
{Name: "list"},
{Name: "resume", ArgsUsage: "<id>"},
},
}
want := "/session [list|resume <id>]"
if got := d.EffectiveUsage(); got != want {
t.Fatalf("EffectiveUsage()=%q, want %q", got, want)
}
}
+89
View File
@@ -0,0 +1,89 @@
package commands
import (
"context"
"fmt"
)
type Outcome int
const (
// OutcomePassthrough means this input should continue through normal agent flow.
OutcomePassthrough Outcome = iota
// OutcomeHandled means a command handler executed (with or without handler error).
OutcomeHandled
)
type ExecuteResult struct {
Outcome Outcome
Command string
Err error
}
type Executor struct {
reg *Registry
rt *Runtime
}
func NewExecutor(reg *Registry, rt *Runtime) *Executor {
return &Executor{reg: reg, rt: rt}
}
// Execute implements a two-state command decision:
// 1) handled: execute command immediately;
// 2) passthrough: not a command or intentionally deferred to agent logic.
func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult {
cmdName, ok := parseCommandName(req.Text)
if !ok {
return ExecuteResult{Outcome: OutcomePassthrough}
}
if e == nil || e.reg == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
}
def, found := e.reg.Lookup(cmdName)
if !found {
return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName}
}
return e.executeDefinition(ctx, req, def)
}
func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult {
// Ensure Reply is always non-nil so handlers don't need to check.
if req.Reply == nil {
req.Reply = func(string) error { return nil }
}
// Simple command — no sub-commands
if len(def.SubCommands) == 0 {
if def.Handler == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
}
err := def.Handler(ctx, req, e.rt)
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
// Sub-command routing
subName := nthToken(req.Text, 1)
if subName == "" {
err := req.Reply("Usage: " + def.EffectiveUsage())
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
normalized := normalizeCommandName(subName)
for _, sc := range def.SubCommands {
if normalizeCommandName(sc.Name) == normalized {
if sc.Handler == nil {
return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name}
}
err := sc.Handler(ctx, req, e.rt)
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
}
// Unknown sub-command
err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage()))
return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err}
}
+260
View File
@@ -0,0 +1,260 @@
package commands
import (
"context"
"errors"
"strings"
"testing"
)
func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{{Name: "show"}}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) {
defs := []Definition{{Name: "show"}}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) {
called := false
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{
Name: "show",
Aliases: []string{"display"},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "show" {
t.Fatalf("command=%q, want=%q", res.Command, "show")
}
}
func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) {
called := false
defs := []Definition{
{
Name: "clear",
Aliases: []string{"reset"},
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if res.Command != "clear" {
t.Fatalf("command=%q, want=%q", res.Command, "clear")
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{Name: "placeholder"},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "placeholder" {
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
}
}
func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) {
// With Lookup-based dispatch, the first registered definition for a name wins.
// A definition with nil Handler and no SubCommands returns Passthrough.
defs := []Definition{
{Name: "placeholder"},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
if res.Command != "placeholder" {
t.Fatalf("command=%q, want=%q", res.Command, "placeholder")
}
}
func TestExecutor_HandlerErrorIsPropagated(t *testing.T) {
wantErr := errors.New("handler failed")
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
return wantErr
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !errors.Is(res.Err, wantErr) {
t.Fatalf("err=%v, want=%v", res.Err, wantErr)
}
}
func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) {
called := false
defs := []Definition{
{
Name: "help",
Handler: func(context.Context, Request, *Runtime) error {
called = true
return nil
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !called {
t.Fatalf("expected handler to be called")
}
}
func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) {
modelCalled := false
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error {
modelCalled = true
return nil
}},
{Name: "channel"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Text: "/show model"})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !modelCalled {
t.Fatal("model sub-command handler was not called")
}
}
func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
{Name: "channel"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show",
Reply: func(text string) error { reply = text; return nil },
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if reply != "Usage: /show [model|channel]" {
t.Fatalf("reply=%q, want usage message", reply)
}
}
func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"},
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
var reply string
res := ex.Execute(context.Background(), Request{
Text: "/show foobar",
Reply: func(text string) error { reply = text; return nil },
})
if res.Outcome != OutcomeHandled {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if !strings.Contains(reply, "foobar") {
t.Fatalf("reply=%q, should mention unknown sub-command", reply)
}
}
func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) {
defs := []Definition{
{
Name: "show",
SubCommands: []SubCommand{
{Name: "model"}, // nil Handler
},
},
}
ex := NewExecutor(NewRegistry(defs), nil)
res := ex.Execute(context.Background(), Request{Text: "/show model"})
if res.Outcome != OutcomePassthrough {
t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough)
}
}
+21
View File
@@ -0,0 +1,21 @@
package commands
import (
"context"
"fmt"
"strings"
)
// agentsHandler returns a shared handler for both /show agents and /list agents.
func agentsHandler() Handler {
return func(_ context.Context, req Request, rt *Runtime) error {
if rt == nil || rt.ListAgentIDs == nil {
return req.Reply(unavailableMsg)
}
ids := rt.ListAgentIDs()
if len(ids) == 0 {
return req.Reply("No agents registered")
}
return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", ")))
}
}
+55
View File
@@ -0,0 +1,55 @@
package commands
type Registry struct {
defs []Definition
index map[string]int
}
// NewRegistry stores the canonical command set used by both dispatch and
// optional platform registration adapters.
func NewRegistry(defs []Definition) *Registry {
stored := make([]Definition, len(defs))
copy(stored, defs)
index := make(map[string]int, len(stored)*2)
for i, def := range stored {
registerCommandName(index, def.Name, i)
for _, alias := range def.Aliases {
registerCommandName(index, alias, i)
}
}
return &Registry{defs: stored, index: index}
}
// Definitions returns all registered command definitions.
// Command availability is global and no longer channel-scoped.
func (r *Registry) Definitions() []Definition {
out := make([]Definition, len(r.defs))
copy(out, r.defs)
return out
}
// Lookup returns a command definition by normalized command name or alias.
func (r *Registry) Lookup(name string) (Definition, bool) {
key := normalizeCommandName(name)
if key == "" {
return Definition{}, false
}
idx, ok := r.index[key]
if !ok {
return Definition{}, false
}
return r.defs[idx], true
}
func registerCommandName(index map[string]int, name string, defIndex int) {
key := normalizeCommandName(name)
if key == "" {
return
}
if _, exists := index[key]; exists {
return
}
index[key] = defIndex
}
+49
View File
@@ -0,0 +1,49 @@
package commands
import "testing"
func TestRegistry_Definitions_ReturnsCopy(t *testing.T) {
defs := []Definition{
{Name: "help", Description: "Show help"},
{Name: "admin", Description: "Admin command"},
}
r := NewRegistry(defs)
got := r.Definitions()
if len(got) != 2 {
t.Fatalf("definitions len = %d, want 2", len(got))
}
got[0].Name = "mutated"
again := r.Definitions()
if again[0].Name != "help" {
t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name)
}
}
func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) {
r := NewRegistry([]Definition{
{Name: "Help", Aliases: []string{"Assist"}},
{Name: "List"},
})
def, ok := r.Lookup("help")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("HELP")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("assist")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def)
}
def, ok = r.Lookup("ASSIST")
if !ok || def.Name != "Help" {
t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def)
}
}
+75
View File
@@ -0,0 +1,75 @@
package commands
import (
"context"
"strings"
)
type Handler func(ctx context.Context, req Request, rt *Runtime) error
type Request struct {
Channel string
ChatID string
SenderID string
Text string
Reply func(text string) error
}
const unavailableMsg = "Command unavailable in current context."
var commandPrefixes = []string{"/", "!"}
// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then
// normalizes to lowercase command names.
func parseCommandName(input string) (string, bool) {
token := nthToken(input, 0)
if token == "" {
return "", false
}
name, ok := trimCommandPrefix(token)
if !ok {
return "", false
}
if i := strings.Index(name, "@"); i >= 0 {
name = name[:i]
}
name = normalizeCommandName(name)
if name == "" {
return "", false
}
return name, true
}
func trimCommandPrefix(token string) (string, bool) {
for _, prefix := range commandPrefixes {
if strings.HasPrefix(token, prefix) {
return strings.TrimPrefix(token, prefix), true
}
}
return "", false
}
// HasCommandPrefix returns true if the input starts with a recognized
// command prefix (e.g. "/" or "!").
func HasCommandPrefix(input string) bool {
token := nthToken(input, 0)
if token == "" {
return false
}
_, ok := trimCommandPrefix(token)
return ok
}
// nthToken returns the 0-indexed token from whitespace-split input.
func nthToken(input string, n int) string {
parts := strings.Fields(strings.TrimSpace(input))
if n >= len(parts) {
return ""
}
return parts[n]
}
func normalizeCommandName(name string) string {
return strings.ToLower(strings.TrimSpace(name))
}
+28
View File
@@ -0,0 +1,28 @@
package commands
import "testing"
func TestHasCommandPrefix(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"/help", true},
{"!help", true},
{"/switch model to gpt-4", true},
{"!switch model to gpt-4", true},
{"hello", false},
{"", false},
{" ", false},
{"hello /world", false},
{"/", true},
{"!", true},
{" /help", true},
}
for _, tt := range tests {
got := HasCommandPrefix(tt.input)
if got != tt.want {
t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
+16
View File
@@ -0,0 +1,16 @@
package commands
import "github.com/sipeed/picoclaw/pkg/config"
// Runtime provides runtime dependencies to command handlers. It is constructed
// per-request by the agent loop so that per-request state (like session scope)
// can coexist with long-lived callbacks (like GetModelInfo).
type Runtime struct {
Config *config.Config
GetModelInfo func() (name, provider string)
ListAgentIDs func() []string
ListDefinitions func() []Definition
GetEnabledChannels func() []string
SwitchModel func(value string) (oldModel string, err error)
SwitchChannel func(value string) error
}
+85
View File
@@ -0,0 +1,85 @@
package commands
import (
"context"
"strings"
"testing"
)
func TestShowListHandlers_ChannelPolicy(t *testing.T) {
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil)
var telegramReply string
handled := ex.Execute(context.Background(), Request{
Channel: "telegram",
Text: "/show channel",
Reply: func(text string) error {
telegramReply = text
return nil
},
})
if handled.Outcome != OutcomeHandled {
t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled)
}
if telegramReply != "Current Channel: telegram" {
t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram")
}
var whatsappReply string
handledWhatsApp := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/show channel",
Reply: func(text string) error {
whatsappReply = text
return nil
},
})
if handledWhatsApp.Outcome != OutcomeHandled {
t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled)
}
if handledWhatsApp.Command != "show" {
t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show")
}
if whatsappReply != "Current Channel: whatsapp" {
t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp")
}
passthrough := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/foo",
})
if passthrough.Outcome != OutcomePassthrough {
t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough)
}
if passthrough.Command != "foo" {
t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo")
}
}
func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) {
rt := &Runtime{
GetEnabledChannels: func() []string {
return []string{"telegram"}
},
}
ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt)
var reply string
res := ex.Execute(context.Background(), Request{
Channel: "whatsapp",
Text: "/list channels",
Reply: func(text string) error {
reply = text
return nil
},
})
if res.Outcome != OutcomeHandled {
t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled)
}
if res.Command != "list" {
t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list")
}
if !strings.Contains(reply, "telegram") {
t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply)
}
}
+121 -35
View File
@@ -167,22 +167,35 @@ type SessionConfig struct {
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
// RoutingConfig controls the intelligent model routing feature.
// When enabled, each incoming message is scored against structural features
// (message length, code blocks, tool call history, conversation depth, attachments).
// Messages scoring below Threshold are sent to LightModel; all others use the
// agent's primary model. This reduces cost and latency for simple tasks without
// requiring any keyword matching — all scoring is language-agnostic.
type RoutingConfig struct {
Enabled bool `json:"enabled"`
LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
}
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
@@ -526,6 +539,10 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
type ToolConfig struct {
Enabled bool `json:"enabled" env:"ENABLED"`
}
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
@@ -550,6 +567,12 @@ type PerplexityConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type SearXNGConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"`
BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"`
}
type GLMSearchConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
@@ -561,11 +584,13 @@ type GLMSearchConfig struct {
}
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
GLMSearch GLMSearchConfig `json:"glm_search"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"`
Brave BraveConfig ` json:"brave"`
Tavily TavilyConfig ` json:"tavily"`
DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"`
Perplexity PerplexityConfig ` json:"perplexity"`
SearXNG SearXNGConfig ` json:"searxng"`
GLMSearch GLMSearchConfig ` json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
@@ -573,19 +598,29 @@ type WebToolsConfig struct {
}
type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
}
type ExecConfig struct {
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"`
EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"`
CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"`
CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"`
TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s)
}
type SkillsToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"`
Registries SkillsRegistriesConfig ` json:"registries"`
MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig ` json:"search_cache"`
}
type MediaCleanupConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"`
MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"`
Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"`
}
type ToolsConfig struct {
@@ -597,12 +632,19 @@ type ToolsConfig struct {
Skills SkillsToolsConfig `json:"skills"`
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
Registries SkillsRegistriesConfig `json:"registries"`
MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
SearchCache SearchCacheConfig `json:"search_cache"`
AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"`
EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"`
FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"`
I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"`
InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"`
ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"`
Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"`
ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"`
}
type SearchCacheConfig struct {
@@ -648,8 +690,7 @@ type MCPServerConfig struct {
// MCPConfig defines configuration for all MCP servers
type MCPConfig struct {
// Enabled globally enables/disables MCP integration
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
@@ -835,3 +876,48 @@ func (c *Config) ValidateModelList() error {
}
return nil
}
func (t *ToolsConfig) IsToolEnabled(name string) bool {
switch name {
case "web":
return t.Web.Enabled
case "cron":
return t.Cron.Enabled
case "exec":
return t.Exec.Enabled
case "skills":
return t.Skills.Enabled
case "media_cleanup":
return t.MediaCleanup.Enabled
case "append_file":
return t.AppendFile.Enabled
case "edit_file":
return t.EditFile.Enabled
case "find_skills":
return t.FindSkills.Enabled
case "i2c":
return t.I2C.Enabled
case "install_skill":
return t.InstallSkill.Enabled
case "list_dir":
return t.ListDir.Enabled
case "message":
return t.Message.Enabled
case "read_file":
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
return t.Subagent.Enabled
case "web_fetch":
return t.WebFetch.Enabled
case "write_file":
return t.WriteFile.Enabled
case "mcp":
return t.MCP.Enabled
default:
return true
}
}
+63 -2
View File
@@ -336,11 +336,16 @@ func DefaultConfig() *Config {
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
Enabled: true,
ToolConfig: ToolConfig{
Enabled: true,
},
MaxAge: 30,
Interval: 5,
},
Web: WebToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Proxy: "",
FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Brave: BraveConfig{
@@ -357,6 +362,11 @@ func DefaultConfig() *Config {
APIKey: "",
MaxResults: 5,
},
SearXNG: SearXNGConfig{
Enabled: false,
BaseURL: "",
MaxResults: 5,
},
GLMSearch: GLMSearchConfig{
Enabled: false,
APIKey: "",
@@ -366,12 +376,22 @@ func DefaultConfig() *Config {
},
},
Cron: CronToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
ExecTimeoutMinutes: 5,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
EnableDenyPatterns: true,
TimeoutSeconds: 60,
},
Skills: SkillsToolsConfig{
ToolConfig: ToolConfig{
Enabled: true,
},
Registries: SkillsRegistriesConfig{
ClawHub: ClawHubRegistryConfig{
Enabled: true,
@@ -385,9 +405,50 @@ func DefaultConfig() *Config {
},
},
MCP: MCPConfig{
Enabled: false,
ToolConfig: ToolConfig{
Enabled: false,
},
Servers: map[string]MCPServerConfig{},
},
AppendFile: ToolConfig{
Enabled: true,
},
EditFile: ToolConfig{
Enabled: true,
},
FindSkills: ToolConfig{
Enabled: true,
},
I2C: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
InstallSkill: ToolConfig{
Enabled: true,
},
ListDir: ToolConfig{
Enabled: true,
},
Message: ToolConfig{
Enabled: true,
},
ReadFile: ToolConfig{
Enabled: true,
},
Spawn: ToolConfig{
Enabled: true,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
Subagent: ToolConfig{
Enabled: true,
},
WebFetch: ToolConfig{
Enabled: true,
},
WriteFile: ToolConfig{
Enabled: true,
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
+13 -3
View File
@@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
mgr := NewManager()
mcpCfg := config.MCPConfig{
Enabled: true,
ToolConfig: config.ToolConfig{
Enabled: true,
},
Servers: map[string]config.MCPServerConfig{
"test-server": {
Enabled: true,
@@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) {
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
mgr := NewManager()
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
err := mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
}
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
err = mgr.LoadFromMCPConfig(
context.Background(),
config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}},
"/tmp",
)
if err != nil {
t.Fatalf("expected nil error when no servers configured, got: %v", err)
}
+8 -2
View File
@@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) {
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading loadCodeAssist response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("loadCodeAssist failed: %s", string(body))
}
@@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(
"fetchAvailableModels failed (HTTP %d): %s",
+80
View File
@@ -0,0 +1,80 @@
package routing
// Classifier evaluates a feature set and returns a complexity score in [0, 1].
// A higher score indicates a more complex task that benefits from a heavy model.
// The score is compared against the configured threshold: score >= threshold selects
// the primary (heavy) model; score < threshold selects the light model.
//
// Classifier is an interface so that future implementations (ML-based, embedding-based,
// or any other approach) can be swapped in without changing routing infrastructure.
type Classifier interface {
Score(f Features) float64
}
// RuleClassifier is the v1 implementation.
// It uses a weighted sum of structural signals with no external dependencies,
// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so
// that the returned score always falls within the [0, 1] contract.
//
// Individual weights (multiple signals can fire simultaneously):
//
// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex
// token 50-200: 0.15 — medium length; may or may not be complex
// code block present: 0.40 — coding tasks need the heavy model
// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow
// tool calls 1-3 (recent): 0.10 — some tool activity
// conversation depth > 10: 0.10 — long sessions carry implicit complexity
// attachments present: 1.00 — hard gate; multi-modal always needs heavy model
//
// Default threshold is 0.35, so:
// - Pure greetings / trivial Q&A: 0.00 → light ✓
// - Medium prose message (50200 tokens): 0.15 → light ✓
// - Message with code block: 0.40 → heavy ✓
// - Long message (>200 tokens): 0.35 → heavy ✓
// - Active tool session + medium message: 0.25 → light (acceptable)
// - Any message with an image/audio attachment: 1.00 → heavy ✓
type RuleClassifier struct{}
// Score computes the complexity score for the given feature set.
// The returned value is in [0, 1]. Attachments short-circuit to 1.0.
func (c *RuleClassifier) Score(f Features) float64 {
// Hard gate: multi-modal inputs always require the heavy model.
if f.HasAttachments {
return 1.0
}
var score float64
// Token estimate — primary verbosity signal
switch {
case f.TokenEstimate > 200:
score += 0.35
case f.TokenEstimate > 50:
score += 0.15
}
// Fenced code blocks — strongest indicator of a coding/technical task
if f.CodeBlockCount > 0 {
score += 0.40
}
// Recent tool call density — indicates an ongoing agentic workflow
switch {
case f.RecentToolCalls > 3:
score += 0.25
case f.RecentToolCalls > 0:
score += 0.10
}
// Conversation depth — accumulated context implies compound task
if f.ConversationDepth > 10 {
score += 0.10
}
// Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire
// simultaneously (e.g., long message + code block + tool chain = 1.10 raw).
if score > 1.0 {
score = 1.0
}
return score
}
+127
View File
@@ -0,0 +1,127 @@
package routing
import (
"strings"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// lookbackWindow is the number of recent history entries scanned for tool calls.
// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant).
const lookbackWindow = 6
// Features holds the structural signals extracted from a message and its session context.
// Every dimension is language-agnostic by construction — no keyword or pattern matching
// against natural-language content. This ensures consistent routing for all locales.
type Features struct {
// TokenEstimate is a proxy for token count.
// CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each.
// This avoids API calls while giving accurate estimates for all scripts.
TokenEstimate int
// CodeBlockCount is the number of fenced code blocks (``` pairs) in the message.
// Coding tasks almost always require the heavy model.
CodeBlockCount int
// RecentToolCalls is the count of tool_call messages in the last lookbackWindow
// history entries. A high density indicates an active agentic workflow.
RecentToolCalls int
// ConversationDepth is the total number of messages in the session history.
// Deep sessions tend to carry implicit complexity built up over many turns.
ConversationDepth int
// HasAttachments is true when the message appears to contain media (images,
// audio, video). Multi-modal inputs require vision-capable heavy models.
HasAttachments bool
}
// ExtractFeatures computes the structural feature vector for a message.
// It is a pure function with no side effects and zero allocations beyond
// the returned struct.
func ExtractFeatures(msg string, history []providers.Message) Features {
return Features{
TokenEstimate: estimateTokens(msg),
CodeBlockCount: countCodeBlocks(msg),
RecentToolCalls: countRecentToolCalls(history),
ConversationDepth: len(history),
HasAttachments: hasAttachments(msg),
}
}
// estimateTokens returns a token count proxy that handles both CJK and Latin text.
// CJK runes (U+2E80U+9FFF, U+F900U+FAFF, U+AC00U+D7AF) map to roughly one
// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token
// for English). Splitting the count this way avoids the 3x underestimation that a
// flat rune_count/3 would produce for Chinese, Japanese, and Korean text.
func estimateTokens(msg string) int {
total := utf8.RuneCountInString(msg)
if total == 0 {
return 0
}
cjk := 0
for _, r := range msg {
if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF {
cjk++
}
}
return cjk + (total-cjk)/4
}
// countCodeBlocks counts the number of complete fenced code blocks.
// Each ``` delimiter increments a counter; pairs of delimiters form one block.
// An unclosed opening fence (odd count) is treated as zero complete blocks
// since it may just be an inline code span or a typo.
func countCodeBlocks(msg string) int {
n := strings.Count(msg, "```")
return n / 2
}
// countRecentToolCalls counts messages with tool calls in the last lookbackWindow
// entries of history. It examines the ToolCalls field rather than parsing
// the content string, so it is robust to any message format.
func countRecentToolCalls(history []providers.Message) int {
start := len(history) - lookbackWindow
if start < 0 {
start = 0
}
count := 0
for _, msg := range history[start:] {
if len(msg.ToolCalls) > 0 {
count += len(msg.ToolCalls)
}
}
return count
}
// hasAttachments returns true when the message content contains embedded media.
// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and
// common image/audio URL extensions. This is intentionally conservative —
// false negatives (missing an attachment) just mean the routing falls back to
// the primary model anyway.
func hasAttachments(msg string) bool {
lower := strings.ToLower(msg)
// Base64 data URIs embedded directly in the message
if strings.Contains(lower, "data:image/") ||
strings.Contains(lower, "data:audio/") ||
strings.Contains(lower, "data:video/") {
return true
}
// Common image/audio extensions in URLs or file references
mediaExts := []string{
".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp",
".mp3", ".wav", ".ogg", ".m4a", ".flac",
".mp4", ".avi", ".mov", ".webm",
}
for _, ext := range mediaExts {
if strings.Contains(lower, ext) {
return true
}
}
return false
}
+82
View File
@@ -0,0 +1,82 @@
package routing
import (
"github.com/sipeed/picoclaw/pkg/providers"
)
// defaultThreshold is used when the config threshold is zero or negative.
// At 0.35 a message needs at least one strong signal (code block, long text,
// or an attachment) before the heavy model is chosen.
const defaultThreshold = 0.35
// RouterConfig holds the validated model routing settings.
// It mirrors config.RoutingConfig but lives in pkg/routing to keep the
// dependency graph simple: pkg/agent resolves config → routing, not the reverse.
type RouterConfig struct {
// LightModel is the model_name (from model_list) used for simple tasks.
LightModel string
// Threshold is the complexity score cutoff in [0, 1].
// score >= Threshold → primary (heavy) model.
// score < Threshold → light model.
Threshold float64
}
// Router selects the appropriate model tier for each incoming message.
// It is safe for concurrent use from multiple goroutines.
type Router struct {
cfg RouterConfig
classifier Classifier
}
// New creates a Router with the given config and the default RuleClassifier.
// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
func New(cfg RouterConfig) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{
cfg: cfg,
classifier: &RuleClassifier{},
}
}
// newWithClassifier creates a Router with a custom Classifier.
// Intended for unit tests that need to inject a deterministic scorer.
func newWithClassifier(cfg RouterConfig, c Classifier) *Router {
if cfg.Threshold <= 0 {
cfg.Threshold = defaultThreshold
}
return &Router{cfg: cfg, classifier: c}
}
// SelectModel returns the model to use for this conversation turn along with
// the computed complexity score (for logging and debugging).
//
// - If score < cfg.Threshold: returns (cfg.LightModel, true, score)
// - Otherwise: returns (primaryModel, false, score)
//
// The caller is responsible for resolving the returned model name into
// provider candidates (see AgentInstance.LightCandidates).
func (r *Router) SelectModel(
msg string,
history []providers.Message,
primaryModel string,
) (model string, usedLight bool, score float64) {
features := ExtractFeatures(msg, history)
score = r.classifier.Score(features)
if score < r.cfg.Threshold {
return r.cfg.LightModel, true, score
}
return primaryModel, false, score
}
// LightModel returns the configured light model name.
func (r *Router) LightModel() string {
return r.cfg.LightModel
}
// Threshold returns the complexity threshold in use.
func (r *Router) Threshold() float64 {
return r.cfg.Threshold
}
+414
View File
@@ -0,0 +1,414 @@
package routing
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ── ExtractFeatures ──────────────────────────────────────────────────────────
func TestExtractFeatures_EmptyMessage(t *testing.T) {
f := ExtractFeatures("", nil)
if f.TokenEstimate != 0 {
t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate)
}
if f.CodeBlockCount != 0 {
t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount)
}
if f.RecentToolCalls != 0 {
t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls)
}
if f.ConversationDepth != 0 {
t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth)
}
if f.HasAttachments {
t.Error("HasAttachments: got true, want false")
}
}
func TestExtractFeatures_TokenEstimate(t *testing.T) {
// 30 ASCII runes: 0 CJK + 30/4 = 7 tokens
msg := strings.Repeat("a", 30)
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 7 {
t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) {
// 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token).
// Using a rune slice literal avoids CJK string literals in source.
msg := string([]rune{
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60, 0x597D, 0x4E16, 0x754C,
0x4F60,
})
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 9 {
t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate)
}
}
func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) {
// Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens.
msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok"
f := ExtractFeatures(msg, nil)
if f.TokenEstimate != 6 {
t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate)
}
}
func TestExtractFeatures_CodeBlocks(t *testing.T) {
cases := []struct {
msg string
want int
}{
{"no code here", 0},
{"```go\nfmt.Println()\n```", 1},
{"```python\npass\n```\n```js\nconsole.log()\n```", 2},
{"```unclosed", 0}, // odd number of fences = 0 complete blocks
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.CodeBlockCount != tc.want {
t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want)
}
}
}
func TestExtractFeatures_RecentToolCalls(t *testing.T) {
// History longer than lookbackWindow — only last lookbackWindow entries count.
history := make([]providers.Message, 10)
// Put 2 tool calls at positions 8 and 9 (within the last 6)
history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}}
history[9] = providers.Message{
Role: "assistant",
ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}},
}
// Position 3 is outside the lookback window and must NOT be counted
history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}}
f := ExtractFeatures("test", history)
// 1 (position 8) + 2 (position 9) = 3
if f.RecentToolCalls != 3 {
t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls)
}
}
func TestExtractFeatures_ConversationDepth(t *testing.T) {
history := make([]providers.Message, 7)
f := ExtractFeatures("msg", history)
if f.ConversationDepth != 7 {
t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth)
}
}
func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"plain text", false},
{"here is an image: data:image/png;base64,abc123", true},
{"audio: data:audio/mp3;base64,xyz", true},
{"video: data:video/mp4;base64,xyz", true},
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
func TestExtractFeatures_HasAttachments_Extension(t *testing.T) {
cases := []struct {
msg string
want bool
}{
{"check out photo.jpg", true},
{"see screenshot.png", true},
{"listen to audio.mp3", true},
{"watch clip.mp4", true},
{"just a .go file", false},
{"document.pdf", false}, // pdf is not in the media list
}
for _, tc := range cases {
f := ExtractFeatures(tc.msg, nil)
if f.HasAttachments != tc.want {
t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want)
}
}
}
// ── RuleClassifier ───────────────────────────────────────────────────────────
func TestRuleClassifier_ZeroFeatures(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{})
if score != 0.0 {
t.Errorf("zero features: got %f, want 0.0", score)
}
}
func TestRuleClassifier_AttachmentsHardGate(t *testing.T) {
c := &RuleClassifier{}
score := c.Score(Features{HasAttachments: true})
if score != 1.0 {
t.Errorf("attachments: got %f, want 1.0", score)
}
}
func TestRuleClassifier_CodeBlockAlone(t *testing.T) {
c := &RuleClassifier{}
// Code block alone = 0.40, above default threshold 0.35
score := c.Score(Features{CodeBlockCount: 1})
if score < 0.35 {
t.Errorf("code block: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_LongMessage(t *testing.T) {
c := &RuleClassifier{}
// >200 tokens = 0.35, exactly at default threshold → heavy
score := c.Score(Features{TokenEstimate: 250})
if score < 0.35 {
t.Errorf("long message: score %f is below default threshold 0.35", score)
}
}
func TestRuleClassifier_MediumMessage(t *testing.T) {
c := &RuleClassifier{}
// 50-200 tokens = 0.15, below threshold → light
score := c.Score(Features{TokenEstimate: 100})
if score >= 0.35 {
t.Errorf("medium message: score %f should be below default threshold 0.35", score)
}
}
func TestRuleClassifier_ShortMessage(t *testing.T) {
c := &RuleClassifier{}
// <50 tokens, no other signals = 0.0 → light
score := c.Score(Features{TokenEstimate: 10})
if score != 0.0 {
t.Errorf("short message: got %f, want 0.0", score)
}
}
func TestRuleClassifier_ToolCallDensity(t *testing.T) {
c := &RuleClassifier{}
scoreNone := c.Score(Features{RecentToolCalls: 0})
scoreLow := c.Score(Features{RecentToolCalls: 2})
scoreHigh := c.Score(Features{RecentToolCalls: 5})
if scoreNone != 0.0 {
t.Errorf("no tools: got %f, want 0.0", scoreNone)
}
if scoreLow <= scoreNone {
t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone)
}
if scoreHigh <= scoreLow {
t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow)
}
}
func TestRuleClassifier_DeepConversation(t *testing.T) {
c := &RuleClassifier{}
shallow := c.Score(Features{ConversationDepth: 5})
deep := c.Score(Features{ConversationDepth: 15})
if deep <= shallow {
t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow)
}
}
func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) {
c := &RuleClassifier{}
// Max all signals simultaneously
f := Features{
TokenEstimate: 500,
CodeBlockCount: 3,
RecentToolCalls: 10,
ConversationDepth: 20,
}
score := c.Score(f)
if score > 1.0 {
t.Errorf("score %f exceeds 1.0", score)
}
}
// ── Router ───────────────────────────────────────────────────────────────────
func TestRouter_DefaultThreshold(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash"})
if r.Threshold() != defaultThreshold {
t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1})
if r.Threshold() != defaultThreshold {
t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold)
}
}
func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "hi"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("simple message: expected light model to be selected")
}
if model != "gemini-flash" {
t.Errorf("simple message: model got %q, want %q", model, "gemini-flash")
}
}
func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "```go\nfmt.Println(\"hello\")\n```"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("code block: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
msg := "can you analyze this? data:image/png;base64,abc123"
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("attachment: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) {
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
// >200 token estimate: 210 * 3 = 630 chars
msg := strings.Repeat("word ", 210)
model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("long message: expected primary model to be selected")
}
if model != "claude-sonnet-4-6" {
t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6")
}
}
func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) {
// Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior.
// Routing is conservative: only promote to heavy when the signal is unambiguous.
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}},
{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}},
}
msg := "ok"
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if !usedLight {
t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)")
}
}
func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) {
// Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35})
history := []providers.Message{
{Role: "assistant", ToolCalls: []providers.ToolCall{
{Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"},
}},
}
// ~55 tokens * 3 = 165 chars
msg := strings.Repeat("word ", 55)
_, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6")
if usedLight {
t.Error("tool chain + medium message: expected primary model (score >= 0.35)")
}
}
func TestRouter_SelectModel_CustomThreshold(t *testing.T) {
// Very low threshold: even a short message triggers heavy model
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05})
msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if usedLight {
t.Error("low threshold: medium message should use primary model")
}
}
func TestRouter_SelectModel_HighThreshold(t *testing.T) {
// Very high threshold: even code blocks route to light
r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99})
msg := "```go\nfmt.Println()\n```"
_, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6")
if !usedLight {
t.Error("very high threshold: code block (0.40) should route to light model")
}
}
func TestRouter_LightModel(t *testing.T) {
r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35})
if r.LightModel() != "my-fast-model" {
t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model")
}
}
// ── newWithClassifier (internal testing hook) ─────────────────────────────────
type fixedScoreClassifier struct{ score float64 }
func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score }
func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.2},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if !usedLight {
t.Error("low score with custom classifier: expected light model")
}
}
func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.8},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("high score with custom classifier: expected primary model")
}
}
func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) {
// score == threshold → primary (uses >= comparison)
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.5},
)
_, usedLight, _ := r.SelectModel("anything", nil, "heavy")
if usedLight {
t.Error("score == threshold: expected primary model (>= threshold → primary)")
}
}
func TestRouter_SelectModel_ReturnsScore(t *testing.T) {
r := newWithClassifier(
RouterConfig{LightModel: "light", Threshold: 0.5},
&fixedScoreClassifier{score: 0.42},
)
_, _, score := r.SelectModel("anything", nil, "heavy")
if score != 0.42 {
t.Errorf("score: got %f, want 0.42", score)
}
}
+64 -16
View File
@@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall(
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String())
if err != nil {
return nil, fmt.Errorf("download failed: %w", err)
}
@@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall(
// --- HTTP helper ---
func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
req, err := c.newGetRequest(ctx, urlStr, "application/json")
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
resp, err := c.client.Do(req)
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return nil, err
}
@@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err
return body, nil
}
func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", accept)
if c.authToken != "" {
req.Header.Set("Authorization", "Bearer "+c.authToken)
}
return req, nil
}
func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) {
req, err := c.newGetRequest(ctx, urlStr, "application/zip")
if err != nil {
return "", err
}
resp, err := utils.DoRequestWithRetry(c.client, req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
errBody := make([]byte, 512)
n, _ := io.ReadFull(resp.Body, errBody)
return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
}
tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
cleanup := func() {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1)
written, err := io.Copy(tmpFile, src)
if err != nil {
cleanup()
return "", fmt.Errorf("download write failed: %w", err)
}
if written > int64(c.maxZipSize) {
cleanup()
return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize)
}
if err := tmpFile.Close(); err != nil {
_ = os.Remove(tmpPath)
return "", fmt.Errorf("failed to close temp file: %w", err)
}
return tmpPath, nil
}
+81
View File
@@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) {
assert.Equal(t, "clawhub", results[0].RegistryName)
}
func TestClawHubRegistrySearchRetries429(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
slug := "github"
name := "GitHub Integration"
summary := "Interact with GitHub repos"
version := "1.0.0"
json.NewEncoder(w).Encode(clawhubSearchResponse{
Results: []clawhubSearchResult{
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
},
})
}))
defer srv.Close()
reg := newTestRegistry(srv.URL, "")
results, err := reg.Search(context.Background(), "github", 5)
require.NoError(t, err)
require.Len(t, results, 1)
assert.Equal(t, 2, attempts)
assert.Equal(t, "github", results[0].Slug)
}
func TestClawHubRegistryGetSkillMeta(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
@@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
assert.Contains(t, string(readmeContent), "# Test Skill")
}
func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) {
zipBuf := createTestZip(t, map[string]string{
"SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill",
})
downloadAttempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v1/skills/retry-skill":
json.NewEncoder(w).Encode(clawhubSkillResponse{
Slug: "retry-skill",
DisplayName: "Retry Skill",
Summary: "A retry test skill",
LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
})
case "/api/v1/download":
downloadAttempts++
if downloadAttempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
assert.Equal(t, "retry-skill", r.URL.Query().Get("slug"))
w.Header().Set("Content-Type", "application/zip")
w.Write(zipBuf)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer srv.Close()
tmpDir := t.TempDir()
targetDir := filepath.Join(tmpDir, "retry-skill")
reg := newTestRegistry(srv.URL, "")
result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, "1.0.0", result.Version)
assert.Equal(t, 2, downloadAttempts)
skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
require.NoError(t, err)
assert.Contains(t, string(skillContent), "Hello skill")
}
func TestClawHubRegistryAuthToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
+6 -1
View File
@@ -132,9 +132,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
timeout := 60 * time.Second
if config != nil && config.Tools.Exec.TimeoutSeconds > 0 {
timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second
}
return &ExecTool{
workingDir: workingDir,
timeout: 60 * time.Second,
timeout: timeout,
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
+71 -1
View File
@@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type SearXNGSearchProvider struct {
baseURL string
}
func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general",
strings.TrimSuffix(p.baseURL, "/"),
url.QueryEscape(query))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode)
}
var result struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
Engine string `json:"engine"`
Score float64 `json:"score"`
} `json:"results"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
// Limit results to requested count
if len(result.Results) > count {
result.Results = result.Results[:count]
}
// Format results in standard PicoClaw format
var b strings.Builder
b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query))
for i, r := range result.Results {
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title))
b.WriteString(fmt.Sprintf(" %s\n", r.URL))
if r.Content != "" {
b.WriteString(fmt.Sprintf(" %s\n", r.Content))
}
}
return b.String(), nil
}
type GLMSearchProvider struct {
apiKey string
baseURL string
@@ -495,6 +557,9 @@ type WebSearchToolOptions struct {
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
SearXNGBaseURL string
SearXNGMaxResults int
SearXNGEnabled bool
GLMSearchAPIKey string
GLMSearchBaseURL string
GLMSearchEngine string
@@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search
// Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
if err != nil {
@@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" {
provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL}
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
client, err := createHTTPClient(opts.Proxy, searchTimeout)
if err != nil {