diff --git a/.gitignore b/.gitignore index 02ef18d1f..a52b8d25a 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,9 @@ ralph/ .ralph/ tasks/ +# Plans +docs/plans/ + # Editors .vscode/ .idea/ diff --git a/README.md b/README.md index 7a31f9364..3774055b4 100644 --- a/README.md +++ b/README.md @@ -353,6 +353,13 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, DingTalk, LINE, or We picoclaw gateway ``` +**4. Telegram command menu (auto-registered at startup)** + +PicoClaw now keeps command definitions in one shared registry. On startup, Telegram will automatically register supported bot commands (for example `/start`, `/help`, `/show`, `/list`) so command menu and runtime behavior stay in sync. +Telegram command menu registration remains channel-local discovery UX; generic command execution is handled centrally in the agent loop via the commands executor. + +If command registration fails (network/API transient errors), the channel still starts and PicoClaw retries registration in the background. +
@@ -750,6 +757,12 @@ For advanced/test setups, you can override the builtin skills root with: export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### Unified Command Execution Policy + +- Generic slash commands are executed through a single path in `pkg/agent/loop.go` via `commands.Executor`. +- Channel adapters no longer consume generic commands locally; they forward inbound text to the bus/agent path. Telegram still auto-registers supported commands at startup. +- Unknown slash command (for example `/foo`) passes through to normal LLM processing. +- Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing. ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. @@ -1205,6 +1218,10 @@ picoclaw agent -m "Hello" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" diff --git a/README.zh.md b/README.zh.md index bd90173f9..dc32b67e0 100644 --- a/README.zh.md +++ b/README.zh.md @@ -307,6 +307,13 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方 | **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) | | **MaixCam** | ⭐ 简单 | 专为 AI 摄像头设计的硬件集成通道 | [查看文档](docs/channels/maixcam/README.zh.md) | +### Telegram 命令注册(启动时自动同步) + +PicoClaw 现在使用统一的命令定义来源。启动时会自动将 Telegram 支持的命令(例如 `/start`、`/help`、`/show`、`/list`)注册到 Bot 命令菜单,确保菜单展示与实际行为一致。 +Telegram 侧保留的是命令菜单注册能力;通用命令的实际执行统一走 Agent Loop 中的 commands executor。 + +如果注册因网络或 API 短暂异常失败,不会阻塞 channel 启动;系统会在后台自动重试。 + ## ClawdChat 加入 Agent 社交网络 只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 @@ -376,6 +383,12 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### 统一命令执行策略 + +- 通用斜杠命令通过 `pkg/agent/loop.go` 中的 `commands.Executor` 统一执行。 +- Channel 适配器不再在本地消费通用命令;它们只负责把入站文本转发到 bus/agent 路径。Telegram 仍会在启动时自动注册其支持的命令菜单。 +- 未注册的斜杠命令(例如 `/foo`)会透传给 LLM 按普通输入处理。 +- 已注册但当前 channel 不支持的命令(例如 WhatsApp 上的 `/show`)会返回明确的用户可见错误,并停止后续处理。 ### 心跳 / 周期性任务 (Heartbeat) PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: @@ -715,6 +728,10 @@ picoclaw agent -m "你好" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" diff --git a/go.mod b/go.mod index 238bd405c..6fa3a900c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 + github.com/h2non/filetype v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 github.com/modelcontextprotocol/go-sdk v1.3.0 @@ -37,7 +38,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 132bb3c98..966668227 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" @@ -46,6 +47,7 @@ type AgentLoop struct { channelManager *channels.Manager mediaStore media.MediaStore transcriber voice.Transcriber + cmdRegistry *commands.Registry } // processOptions configures how a message is processed @@ -61,7 +63,15 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." +const ( + defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." + sessionKeyAgentPrefix = "agent:" + metadataKeyAccountID = "account_id" + metadataKeyGuildID = "guild_id" + metadataKeyTeamID = "team_id" + metadataKeyParentPeerKind = "parent_peer_kind" + metadataKeyParentPeerID = "parent_peer_id" +) func NewAgentLoop( cfg *config.Config, @@ -84,14 +94,17 @@ func NewAgentLoop( stateManager = state.NewManager(defaultAgent.Workspace) } - return &AgentLoop{ + al := &AgentLoop{ bus: msgBus, cfg: cfg, registry: registry, state: stateManager, summarizing: sync.Map{}, fallback: fallbackChain, + cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), } + + return al } // registerSharedTools registers tools that are shared across all agents (web, message, spawn). @@ -549,27 +562,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Check for commands - if response, handled := al.handleCommand(ctx, msg); handled { + route, agent, routeErr := al.resolveMessageRoute(msg) + + // Commands are checked before requiring a successful route. + // Global commands (/help, /show, /switch) work even when routing fails; + // context-dependent commands check their own Runtime fields and report + // "unavailable" when the required capability is nil. + if response, handled := al.handleCommand(ctx, msg, agent); handled { return response, nil } - // Route to determine agent and session key - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - AccountID: msg.Metadata["account_id"], - Peer: extractPeer(msg), - ParentPeer: extractParentPeer(msg), - GuildID: msg.Metadata["guild_id"], - TeamID: msg.Metadata["team_id"], - }) - - agent, ok := al.registry.GetAgent(route.AgentID) - if !ok { - agent = al.registry.GetDefaultAgent() - } - if agent == nil { - return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + if routeErr != nil { + return "", routeErr } // Reset message-tool state for this round so we don't skip publishing due to a previous round. @@ -579,17 +583,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } } - // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) - sessionKey := route.SessionKey - if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { - sessionKey = msg.SessionKey - } + // Resolve session key from route, while preserving explicit agent-scoped keys. + scopeKey := resolveScopeKey(route, msg.SessionKey) + sessionKey := scopeKey logger.InfoCF("agent", "Routed message", map[string]any{ - "agent_id": agent.ID, - "session_key": sessionKey, - "matched_by": route.MatchedBy, + "agent_id": agent.ID, + "scope_key": scopeKey, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + "route_agent": route.AgentID, + "route_channel": route.Channel, }) return al.runAgentLoop(ctx, agent, processOptions{ @@ -604,6 +609,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }) } +func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: inboundMetadata(msg, metadataKeyAccountID), + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: inboundMetadata(msg, metadataKeyGuildID), + TeamID: inboundMetadata(msg, metadataKeyTeamID), + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + if agent == nil { + return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + } + + return route, agent, nil +} + +func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { + if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { + return msgSessionKey + } + return route.SessionKey +} + func (al *AgentLoop) processSystemMessage( ctx context.Context, msg bus.InboundMessage, @@ -1504,94 +1537,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int { return totalChars * 2 / 5 } -func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { - content := strings.TrimSpace(msg.Content) - if !strings.HasPrefix(content, "/") { +func (al *AgentLoop) handleCommand( + ctx context.Context, + msg bus.InboundMessage, + agent *AgentInstance, +) (string, bool) { + if !commands.HasCommandPrefix(msg.Content) { return "", false } - parts := strings.Fields(content) - if len(parts) == 0 { + if al.cmdRegistry == nil { return "", false } - cmd := parts[0] - args := parts[1:] + rt := al.buildCommandsRuntime(agent) + executor := commands.NewExecutor(al.cmdRegistry, rt) - switch cmd { - case "/show": - if len(args) < 1 { - return "Usage: /show [model|channel|agents]", true - } - switch args[0] { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - return fmt.Sprintf("Current model: %s", defaultAgent.Model), true - case "channel": - return fmt.Sprintf("Current channel: %s", msg.Channel), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown show target: %s", args[0]), true - } + var commandReply string + result := executor.Execute(ctx, commands.Request{ + Channel: msg.Channel, + ChatID: msg.ChatID, + SenderID: msg.SenderID, + Text: msg.Content, + Reply: func(text string) error { + commandReply = text + return nil + }, + }) - case "/list": - if len(args) < 1 { - return "Usage: /list [models|channels|agents]", true + switch result.Outcome { + case commands.OutcomeHandled: + if result.Err != nil { + return mapCommandError(result), true } - switch args[0] { - case "models": - return "Available models: configured in config.json per agent", true - case "channels": + if commandReply != "" { + return commandReply, true + } + return "", true + default: // OutcomePassthrough — let the message fall through to LLM + return "", false + } +} + +func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime { + rt := &commands.Runtime{ + Config: al.cfg, + ListAgentIDs: al.registry.ListAgentIDs, + ListDefinitions: al.cmdRegistry.Definitions, + GetEnabledChannels: func() []string { if al.channelManager == nil { - return "Channel manager not initialized", true + return nil } - channels := al.channelManager.GetEnabledChannels() - if len(channels) == 0 { - return "No channels enabled", true - } - return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown list target: %s", args[0]), true - } - - case "/switch": - if len(args) < 3 || args[1] != "to" { - return "Usage: /switch [model|channel] to ", true - } - target := args[0] - value := args[2] - - switch target { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - oldModel := defaultAgent.Model - defaultAgent.Model = value - return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true - case "channel": + return al.channelManager.GetEnabledChannels() + }, + SwitchChannel: func(value string) error { if al.channelManager == nil { - return "Channel manager not initialized", true + return fmt.Errorf("channel manager not initialized") } if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { - return fmt.Sprintf("Channel '%s' not found or not enabled", value), true + return fmt.Errorf("channel '%s' not found or not enabled", value) } - return fmt.Sprintf("Switched target channel to %s", value), true - default: - return fmt.Sprintf("Unknown switch target: %s", target), true + return nil + }, + } + if agent != nil { + rt.GetModelInfo = func() (string, string) { + return agent.Model, al.cfg.Agents.Defaults.Provider + } + rt.SwitchModel = func(value string) (string, error) { + oldModel := agent.Model + agent.Model = value + return oldModel, nil } } + return rt +} - return "", false +func mapCommandError(result commands.ExecuteResult) string { + if result.Command == "" { + return fmt.Sprintf("Failed to execute command: %v", result.Err) + } + return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err) } // extractPeer extracts the routing peer from the inbound message's structured Peer field. @@ -1610,10 +1636,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } +func inboundMetadata(msg bus.InboundMessage, key string) string { + if msg.Metadata == nil { + return "" + } + return msg.Metadata[key] +} + // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { - parentKind := msg.Metadata["parent_peer_kind"] - parentID := msg.Metadata["parent_peer_id"] + parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) + parentID := inboundMetadata(msg, metadataKeyParentPeerID) if parentKind == "" || parentID == "" { return nil } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index aa7d59b5a..2e456fa60 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -318,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string { return "mock-model" } +type countingMockProvider struct { + response string + calls int +} + +func (m *countingMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *countingMockProvider) GetDefaultModel() string { + return "counting-mock-model" +} + // mockCustomTool is a simple mock tool for registration testing type mockCustomTool struct{} @@ -359,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms const responseTimeout = 3 * time.Second +func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "ok"} + al := NewAgentLoop(cfg, msgBus, provider) + + msg := bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + Peer: extractPeer(msg), + }) + sessionKey := route.SessionKey + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + + helper := testHelper{al: al} + _ = helper.executeAndGetResponse(t, context.Background(), msg) + + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) != 2 { + t.Fatalf("expected session history len=2, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Fatalf("unexpected first message in session: %+v", history[0]) + } +} + +func TestProcessMessage_CommandOutcomes(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Session: config.SessionConfig{ + DMScope: "per-channel-peer", + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + baseMsg := bus.InboundMessage{ + Channel: "whatsapp", + SenderID: "user1", + ChatID: "chat1", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/show channel", + Peer: baseMsg.Peer, + }) + if showResp != "Current Channel: whatsapp" { + t.Fatalf("unexpected /show reply: %q", showResp) + } + if provider.calls != 0 { + t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls) + } + + fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/foo", + Peer: baseMsg.Peer, + }) + if fooResp != "LLM reply" { + t.Fatalf("unexpected /foo reply: %q", fooResp) + } + if provider.calls != 1 { + t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls) + } + + newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/new", + Peer: baseMsg.Peer, + }) + if newResp != "LLM reply" { + t.Fatalf("unexpected /new reply: %q", newResp) + } + if provider.calls != 2 { + t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls) + } +} + +func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Provider: "openai", + Model: "before-switch", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/switch model to after-switch", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") { + t.Fatalf("unexpected /switch reply: %q", switchResp) + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/show model", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") { + t.Fatalf("unexpected /show model reply after switch: %q", showResp) + } + + if provider.calls != 0 { + t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go index 74caeeac5..b3a493761 100644 --- a/pkg/channels/interfaces.go +++ b/pkg/channels/interfaces.go @@ -1,6 +1,10 @@ package channels -import "context" +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/commands" +) // TypingCapable — channels that can show a typing/thinking indicator. // StartTyping begins the indicator and returns a stop function. @@ -39,3 +43,10 @@ type PlaceholderRecorder interface { RecordTypingStop(channel, chatID string, stop func()) RecordReactionUndo(channel, chatID string, undo func()) } + +// CommandRegistrarCapable is implemented by channels that can register +// command menus with their upstream platform (e.g. Telegram BotCommand). +// Channels that do not support platform-level command menus can ignore it. +type CommandRegistrarCapable interface { + RegisterCommands(ctx context.Context, defs []commands.Definition) error +} diff --git a/pkg/channels/interfaces_command_test.go b/pkg/channels/interfaces_command_test.go new file mode 100644 index 000000000..de5502644 --- /dev/null +++ b/pkg/channels/interfaces_command_test.go @@ -0,0 +1,16 @@ +package channels + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/commands" +) + +type mockRegistrar struct{} + +func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil } + +func TestCommandRegistrarCapable_Compiles(t *testing.T) { + var _ CommandRegistrarCapable = mockRegistrar{} +} diff --git a/pkg/channels/telegram/command_registration.go b/pkg/channels/telegram/command_registration.go new file mode 100644 index 000000000..d3152ec3d --- /dev/null +++ b/pkg/channels/telegram/command_registration.go @@ -0,0 +1,116 @@ +package telegram + +import ( + "context" + "math/rand" + "slices" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/commands" + "github.com/sipeed/picoclaw/pkg/logger" +) + +var commandRegistrationBackoff = []time.Duration{ + 5 * time.Second, + 15 * time.Second, + 60 * time.Second, + 5 * time.Minute, + 10 * time.Minute, +} + +func commandRegistrationDelay(attempt int) time.Duration { + if len(commandRegistrationBackoff) == 0 { + return 0 + } + base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)] + // Full jitter in [0.5, 1.0) to avoid synchronized retries across instances. + return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5)) +} + +// RegisterCommands registers bot commands on Telegram platform. +func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error { + botCommands := make([]telego.BotCommand, 0, len(defs)) + for _, def := range defs { + if def.Name == "" || def.Description == "" { + continue + } + botCommands = append(botCommands, telego.BotCommand{ + Command: def.Name, + Description: def.Description, + }) + } + + current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{}) + if err != nil { + // If we can't read current commands, fall through to set them. + logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally", + map[string]any{"error": err.Error()}) + } else if slices.Equal(current, botCommands) { + logger.DebugCF("telegram", "Bot commands are up to date", nil) + return nil + } + + return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ + Commands: botCommands, + }) +} + +func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) { + if len(defs) == 0 { + return + } + + register := c.registerFunc + if register == nil { + register = c.RegisterCommands + } + + regCtx, cancel := context.WithCancel(ctx) + c.commandRegCancel = cancel + + // Registration runs asynchronously so Telegram message intake is never blocked + // by temporary upstream API failures. Retry stops on success or channel shutdown. + go func() { + attempt := 0 + timer := time.NewTimer(0) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + defer timer.Stop() + for { + err := register(regCtx, defs) + if err == nil { + logger.InfoCF("telegram", "Telegram commands registered", map[string]any{ + "count": len(defs), + }) + return + } + + delay := commandRegistrationDelay(attempt) + logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{ + "error": err.Error(), + "retry_after": delay.String(), + }) + attempt++ + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(delay) + + select { + case <-regCtx.Done(): + return + case <-timer.C: + } + } + }() +} diff --git a/pkg/channels/telegram/command_registration_test.go b/pkg/channels/telegram/command_registration_test.go new file mode 100644 index 000000000..26f891b2e --- /dev/null +++ b/pkg/channels/telegram/command_registration_test.go @@ -0,0 +1,96 @@ +package telegram + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/commands" +) + +func TestStartCommandRegistration_DoesNotBlock(t *testing.T) { + ch := &TelegramChannel{} + started := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch.registerFunc = func(context.Context, []commands.Definition) error { + started <- struct{}{} + return errors.New("temporary failure") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}}) + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("registration did not start asynchronously") + } +} + +func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = origBackoff }() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + n := attempts.Add(1) + if n < 3 { + return errors.New("temporary failure") + } + return nil + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + deadline := time.Now().Add(250 * time.Millisecond) + for time.Now().Before(deadline) { + if attempts.Load() >= 3 { + break + } + time.Sleep(5 * time.Millisecond) + } + if attempts.Load() < 3 { + t.Fatalf("expected at least 3 attempts, got %d", attempts.Load()) + } + + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load()) + } +} + +func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = origBackoff }() + defer cancel() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + attempts.Add(1) + return errors.New("always fail") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + time.Sleep(20 * time.Millisecond) + cancel() + time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load()) + } +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index f328f32b8..a2035853c 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "regexp" - "slices" "strconv" "strings" "time" @@ -18,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" @@ -40,13 +40,15 @@ var ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *th.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc + bot *telego.Bot + bh *th.BotHandler + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc + + registerFunc func(context.Context, []commands.Definition) error + commandRegCancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -93,7 +95,6 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return &TelegramChannel{ BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), bot: bot, config: cfg, chatIDs: make(map[string]int64), @@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) - if err := c.initBotCommands(c.ctx); err != nil { - logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{ - "error": err.Error(), - }) - } - updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) @@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } c.bh = bh - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Start(ctx, message) - }, th.CommandEqual("start")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Help(ctx, message) - }, th.CommandEqual("help")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Show(ctx, message) - }, th.CommandEqual("show")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.List(ctx, message) - }, th.CommandEqual("list")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) @@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { "username": c.bot.Username(), }) + c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions()) + go func() { if err = bh.Start(); err != nil { logger.ErrorCF("telegram", "Bot handler failed", map[string]any{ @@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } - - return nil -} - -func (c *TelegramChannel) initBotCommands(ctx context.Context) error { - currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{ - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("get commands: %w", err) - } - - commands := []telego.BotCommand{ - { - Command: "start", - Description: "Start the bot", - }, - { - Command: "help", - Description: "Show a help message", - }, - { - Command: "show", - Description: "Show current configuration", - }, - { - Command: "list", - Description: "List available options", - }, - } - - // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed - if !slices.Equal(currentCommands, commands) { - logger.InfoC("telegram", "Updating bot commands") - - err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ - Commands: commands, - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("set commands: %w", err) - } - } else { - logger.DebugC("telegram", "Bot commands are up to date") + if c.commandRegCancel != nil { + c.commandRegCancel() } return nil @@ -721,34 +661,34 @@ func escapeHTML(text string) string { // isBotMentioned checks if the bot is mentioned in the message via entities. func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { - botUsername := c.bot.Username() - if botUsername == "" { + text, entities := telegramEntityTextAndList(message) + if text == "" || len(entities) == 0 { return false } - entities := message.Entities - if entities == nil { - entities = message.CaptionEntities + botUsername := "" + if c.bot != nil { + botUsername = c.bot.Username() } + runes := []rune(text) for _, entity := range entities { - if entity.Type == "mention" { - // Extract the mention text from the message - text := message.Text - if text == "" { - text = message.Caption - } - runes := []rune(text) - end := entity.Offset + entity.Length - if end <= len(runes) { - mention := string(runes[entity.Offset:end]) - if strings.EqualFold(mention, "@"+botUsername) { - return true - } - } + entityText, ok := telegramEntityText(runes, entity) + if !ok { + continue } - if entity.Type == "text_mention" && entity.User != nil { - if entity.User.Username == botUsername { + + switch entity.Type { + case telego.EntityTypeMention: + if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) { + return true + } + case telego.EntityTypeTextMention: + if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) { + return true + } + case telego.EntityTypeBotCommand: + if isBotCommandEntityForThisBot(entityText, botUsername) { return true } } @@ -756,6 +696,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { return false } +func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) { + if message.Text != "" { + return message.Text, message.Entities + } + return message.Caption, message.CaptionEntities +} + +func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) { + if entity.Offset < 0 || entity.Length <= 0 { + return "", false + } + end := entity.Offset + entity.Length + if entity.Offset >= len(runes) || end > len(runes) { + return "", false + } + return string(runes[entity.Offset:end]), true +} + +func isBotCommandEntityForThisBot(entityText, botUsername string) bool { + if !strings.HasPrefix(entityText, "/") { + return false + } + command := strings.TrimPrefix(entityText, "/") + if command == "" { + return false + } + + at := strings.IndexRune(command, '@') + if at == -1 { + // A bare /command delivered to this bot is intended for this bot. + return true + } + + mentionUsername := command[at+1:] + if mentionUsername == "" || botUsername == "" { + return false + } + return strings.EqualFold(mentionUsername, botUsername) +} + // stripBotMention removes the @bot mention from the content. func (c *TelegramChannel) stripBotMention(content string) string { botUsername := c.bot.Username() diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go deleted file mode 100644 index 496fc5e4f..000000000 --- a/pkg/channels/telegram/telegram_commands.go +++ /dev/null @@ -1,156 +0,0 @@ -package telegram - -import ( - "context" - "fmt" - "strings" - - "github.com/mymmrac/telego" - - "github.com/sipeed/picoclaw/pkg/config" -) - -type TelegramCommander interface { - Help(ctx context.Context, message telego.Message) error - Start(ctx context.Context, message telego.Message) error - Show(ctx context.Context, message telego.Message) error - List(ctx context.Context, message telego.Message) error -} - -type cmd struct { - bot *telego.Bot - config *config.Config -} - -func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { - return &cmd{ - bot: bot, - config: cfg, - } -} - -func commandArgs(text string) string { - parts := strings.SplitN(text, " ", 2) - if len(parts) < 2 { - return "" - } - return strings.TrimSpace(parts[1]) -} - -func (c *cmd) Help(ctx context.Context, message telego.Message) error { - msg := `/start - Start the bot -/help - Show this help message -/show [model|channel] - Show current configuration -/list [models|channels] - List available options - ` - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: msg, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Start(ctx context.Context, message telego.Message) error { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Hello! I am PicoClaw 🦞", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Show(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /show [model|channel]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "model": - response = fmt.Sprintf("Current Model: %s (Provider: %s)", - c.config.Agents.Defaults.GetModelName(), - c.config.Agents.Defaults.Provider) - case "channel": - response = "Current Channel: telegram" - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) List(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /list [models|channels]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "models": - provider := c.config.Agents.Defaults.Provider - if provider == "" { - provider = "configured default" - } - response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", - c.config.Agents.Defaults.GetModelName(), provider) - - case "channels": - var enabled []string - if c.config.Channels.Telegram.Enabled { - enabled = append(enabled, "telegram") - } - if c.config.Channels.WhatsApp.Enabled { - enabled = append(enabled, "whatsapp") - } - if c.config.Channels.Feishu.Enabled { - enabled = append(enabled, "feishu") - } - if c.config.Channels.Discord.Enabled { - enabled = append(enabled, "discord") - } - if c.config.Channels.Slack.Enabled { - enabled = append(enabled, "slack") - } - response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) - - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} diff --git a/pkg/channels/telegram/telegram_dispatch_test.go b/pkg/channels/telegram/telegram_dispatch_test.go new file mode 100644 index 000000000..1ea4a4824 --- /dev/null +++ b/pkg/channels/telegram/telegram_dispatch_test.go @@ -0,0 +1,52 @@ +package telegram + +import ( + "context" + "testing" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" +) + +func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + + msg := &telego.Message{ + Text: "/new", + MessageID: 9, + Chat: telego.Chat{ + ID: 123, + Type: "private", + }, + From: &telego.User{ + ID: 42, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "telegram" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/telegram/telegram_group_command_filter_test.go b/pkg/channels/telegram/telegram_group_command_filter_test.go new file mode 100644 index 000000000..0d5b985fe --- /dev/null +++ b/pkg/channels/telegram/telegram_group_command_filter_test.go @@ -0,0 +1,147 @@ +package telegram + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/mymmrac/telego" + ta "github.com/mymmrac/telego/telegoapi" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +type getMeCaller struct { + username string +} + +func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) { + if strings.HasSuffix(url, "/getMe") { + result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username) + return &ta.Response{Ok: true, Result: []byte(result)}, nil + } + return &ta.Response{Ok: true, Result: []byte("true")}, nil +} + +func newTestTelegramBot(t *testing.T, username string) *telego.Bot { + t.Helper() + + token := "123456:" + strings.Repeat("a", 35) + bot, err := telego.NewBot(token, + telego.WithAPICaller(getMeCaller{username: username}), + telego.WithDiscardLogger(), + ) + if err != nil { + t.Fatalf("NewBot error: %v", err) + } + return bot +} + +func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) { + t.Helper() + + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil, + channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}), + ), + bot: newTestTelegramBot(t, botUsername), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + return ch, messageBus +} + +func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) { + tests := []struct { + name string + text string + wantForwarded bool + wantContent string + }{ + { + name: "command with bot username", + text: "/new@testbot", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "bare command", + text: "/new", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "command for another bot", + text: "/new@otherbot", + wantForwarded: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ch, messageBus := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: tc.text, + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeBotCommand, + Offset: 0, + Length: len([]rune(tc.text)), + }}, + MessageID: 42, + Chat: telego.Chat{ + ID: 123, + Type: "group", + }, + From: &telego.User{ + ID: 7, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if tc.wantForwarded { + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Content != tc.wantContent { + t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) + } + return + } + + if ok { + t.Fatalf("expected message to be filtered, got content=%q", inbound.Content) + } + }) + } +} + +func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) { + ch, _ := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: "@testbot hello", + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeMention, + Offset: 0, + Length: len("@testbot"), + }}, + } + + if !ch.isBotMentioned(msg) { + t.Fatal("expected mention entity to be treated as bot mention") + } +} diff --git a/pkg/channels/whatsapp/whatsapp_command_test.go b/pkg/channels/whatsapp/whatsapp_command_test.go new file mode 100644 index 000000000..ee8aa4a52 --- /dev/null +++ b/pkg/channels/whatsapp/whatsapp_command_test.go @@ -0,0 +1,41 @@ +package whatsapp + +import ( + "context" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil), + ctx: context.Background(), + } + + ch.handleIncomingMessage(map[string]any{ + "type": "message", + "id": "mid1", + "from": "user1", + "chat": "chat1", + "content": "/help", + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/help" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/whatsapp_native/whatsapp_command_test.go b/pkg/channels/whatsapp_native/whatsapp_command_test.go new file mode 100644 index 000000000..cc2dcb619 --- /dev/null +++ b/pkg/channels/whatsapp_native/whatsapp_command_test.go @@ -0,0 +1,56 @@ +//go:build whatsapp_native + +package whatsapp + +import ( + "context" + "testing" + "time" + + "go.mau.fi/whatsmeow/proto/waE2E" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + "google.golang.org/protobuf/proto" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppNativeChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil), + runCtx: context.Background(), + } + + evt := &events.Message{ + Info: types.MessageInfo{ + MessageSource: types.MessageSource{ + Sender: types.NewJID("1001", types.DefaultUserServer), + Chat: types.NewJID("1001", types.DefaultUserServer), + }, + ID: "mid1", + PushName: "Alice", + }, + Message: &waE2E.Message{ + Conversation: proto.String("/new"), + }, + } + + ch.handleIncoming(evt) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp_native" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go new file mode 100644 index 000000000..a36dd3eba --- /dev/null +++ b/pkg/commands/builtin.go @@ -0,0 +1,16 @@ +package commands + +// BuiltinDefinitions returns all built-in command definitions. +// Each command group is defined in its own cmd_*.go file. +// Definitions are stateless — runtime dependencies are provided +// via the Runtime parameter passed to handlers at execution time. +func BuiltinDefinitions() []Definition { + return []Definition{ + startCommand(), + helpCommand(), + showCommand(), + listCommand(), + switchCommand(), + checkCommand(), + } +} diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go new file mode 100644 index 000000000..66a84825e --- /dev/null +++ b/pkg/commands/builtin_test.go @@ -0,0 +1,145 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition { + t.Helper() + for _, def := range defs { + if def.Name == name { + return def + } + } + t.Fatalf("missing /%s definition", name) + return Definition{} +} + +func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { + defs := BuiltinDefinitions() + helpDef := findDefinitionByName(t, defs, "help") + if helpDef.Handler == nil { + t.Fatalf("/help handler should not be nil") + } + + var reply string + err := helpDef.Handler(context.Background(), Request{ + Text: "/help", + Reply: func(text string) error { + reply = text + return nil + }, + }, nil) + if err != nil { + t.Fatalf("/help handler error: %v", err) + } + // Now uses auto-generated EffectiveUsage which includes agents + if !strings.Contains(reply, "/show [model|channel|agents]") { + t.Fatalf("/help reply missing /show usage, got %q", reply) + } + if !strings.Contains(reply, "/list [models|channels|agents]") { + t.Fatalf("/help reply missing /list usage, got %q", reply) + } +} + +func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), nil) + + cases := []string{"telegram", "whatsapp"} + for _, channel := range cases { + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: channel, + Text: "/show channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled) + } + want := "Current Channel: " + channel + if reply != want { + t.Fatalf("/show channel reply=%q, want=%q", reply, want) + } + } +} + +func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram", "slack"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") { + t.Fatalf("/list channels reply=%q, want telegram and slack", reply) + } +} + +func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/show agents reply=%q, want agent IDs", reply) + } +} + +func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/list agents reply=%q, want agent IDs", reply) + } +} diff --git a/pkg/commands/cmd_check.go b/pkg/commands/cmd_check.go new file mode 100644 index 000000000..f0193dc4f --- /dev/null +++ b/pkg/commands/cmd_check.go @@ -0,0 +1,33 @@ +package commands + +import ( + "context" + "fmt" +) + +func checkCommand() Definition { + return Definition{ + Name: "check", + Description: "Check channel availability", + SubCommands: []SubCommand{ + { + Name: "channel", + Description: "Check if a channel is available", + ArgsUsage: "", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchChannel == nil { + return req.Reply(unavailableMsg) + } + value := nthToken(req.Text, 2) + if value == "" { + return req.Reply("Usage: /check channel ") + } + if err := rt.SwitchChannel(value); err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value)) + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_help.go b/pkg/commands/cmd_help.go new file mode 100644 index 000000000..94f7f0101 --- /dev/null +++ b/pkg/commands/cmd_help.go @@ -0,0 +1,44 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func helpCommand() Definition { + return Definition{ + Name: "help", + Description: "Show this help message", + Usage: "/help", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + var defs []Definition + if rt != nil && rt.ListDefinitions != nil { + defs = rt.ListDefinitions() + } else { + defs = BuiltinDefinitions() + } + return req.Reply(formatHelpMessage(defs)) + }, + } +} + +func formatHelpMessage(defs []Definition) string { + if len(defs) == 0 { + return "No commands available." + } + + lines := make([]string, 0, len(defs)) + for _, def := range defs { + usage := def.EffectiveUsage() + if usage == "" { + usage = "/" + def.Name + } + desc := def.Description + if desc == "" { + desc = "No description" + } + lines = append(lines, fmt.Sprintf("%s - %s", usage, desc)) + } + return strings.Join(lines, "\n") +} diff --git a/pkg/commands/cmd_list.go b/pkg/commands/cmd_list.go new file mode 100644 index 000000000..bf47b6e9c --- /dev/null +++ b/pkg/commands/cmd_list.go @@ -0,0 +1,52 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func listCommand() Definition { + return Definition{ + Name: "list", + Description: "List available options", + SubCommands: []SubCommand{ + { + Name: "models", + Description: "Configured models", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + if provider == "" { + provider = "configured default" + } + return req.Reply(fmt.Sprintf( + "Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", + name, provider, + )) + }, + }, + { + Name: "channels", + Description: "Enabled channels", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetEnabledChannels == nil { + return req.Reply(unavailableMsg) + } + enabled := rt.GetEnabledChannels() + if len(enabled) == 0 { + return req.Reply("No channels enabled") + } + return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_show.go b/pkg/commands/cmd_show.go new file mode 100644 index 000000000..c655e6880 --- /dev/null +++ b/pkg/commands/cmd_show.go @@ -0,0 +1,38 @@ +package commands + +import ( + "context" + "fmt" +) + +func showCommand() Definition { + return Definition{ + Name: "show", + Description: "Show current configuration", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Current model and provider", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider)) + }, + }, + { + Name: "channel", + Description: "Current channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel)) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_start.go b/pkg/commands/cmd_start.go new file mode 100644 index 000000000..8b500aa10 --- /dev/null +++ b/pkg/commands/cmd_start.go @@ -0,0 +1,14 @@ +package commands + +import "context" + +func startCommand() Definition { + return Definition{ + Name: "start", + Description: "Start the bot", + Usage: "/start", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("Hello! I am PicoClaw 🦞") + }, + } +} diff --git a/pkg/commands/cmd_switch.go b/pkg/commands/cmd_switch.go new file mode 100644 index 000000000..fb8fc109e --- /dev/null +++ b/pkg/commands/cmd_switch.go @@ -0,0 +1,42 @@ +package commands + +import ( + "context" + "fmt" +) + +func switchCommand() Definition { + return Definition{ + Name: "switch", + Description: "Switch model", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Switch to a different model", + ArgsUsage: "to ", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchModel == nil { + return req.Reply(unavailableMsg) + } + // Parse: /switch model to + value := nthToken(req.Text, 3) // tokens: [/switch, model, to, ] + if nthToken(req.Text, 2) != "to" || value == "" { + return req.Reply("Usage: /switch model to ") + } + oldModel, err := rt.SwitchModel(value) + if err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value)) + }, + }, + { + Name: "channel", + Description: "Moved to /check channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("This command has moved. Please use: /check channel ") + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_switch_test.go b/pkg/commands/cmd_switch_test.go new file mode 100644 index 000000000..59ed305bb --- /dev/null +++ b/pkg/commands/cmd_switch_test.go @@ -0,0 +1,279 @@ +package commands + +import ( + "context" + "fmt" + "testing" +) + +func TestSwitchModel_Success(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old-model", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Switched model from old-model to gpt-4" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestSwitchModel_MissingToKeyword(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_Error(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "", fmt.Errorf("model not found") + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to bad-model", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "model not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestSwitchModel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." { + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestSwitchChannel_Redirect(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch channel to telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "This command has moved. Please use: /check channel " + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Success(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Channel 'telegram' is available and enabled" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Error(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return fmt.Errorf("channel '%s' not found", value) + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel unknown", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "channel 'unknown' not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestCheckChannel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." { + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestCheckChannel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /check channel " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitch_BangPrefix(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "!switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Switched model from old to gpt-4" { + t.Fatalf("! prefix: reply=%q, want success message", reply) + } +} + +func TestSwitch_NoSubCommand(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + // Should get usage message from executor's sub-command routing + if reply == "" { + t.Fatal("expected usage reply for bare /switch") + } +} diff --git a/pkg/commands/definition.go b/pkg/commands/definition.go new file mode 100644 index 000000000..7309df317 --- /dev/null +++ b/pkg/commands/definition.go @@ -0,0 +1,48 @@ +package commands + +import ( + "fmt" + "strings" +) + +// SubCommand defines a single sub-command within a parent command. +type SubCommand struct { + Name string + Description string + ArgsUsage string // optional, e.g. "" + Handler Handler +} + +// Definition is the single-source metadata and behavior contract for a slash command. +// +// Design notes (phase 1): +// - Every channel reads command shape from this type instead of keeping local copies. +// - Visibility is global: all definitions are considered available to all channels. +// - Platform menu registration (for example Telegram BotCommand) also derives from this +// same definition so UI labels and runtime behavior stay aligned. +type Definition struct { + Name string + Description string + Usage string // for simple commands; ignored when SubCommands is set + Aliases []string + SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers + Handler Handler // for simple commands without sub-commands +} + +// EffectiveUsage returns the usage string. When SubCommands are present, +// it is auto-generated from sub-command names so metadata and behavior +// cannot drift. +func (d Definition) EffectiveUsage() string { + if len(d.SubCommands) == 0 { + return d.Usage + } + names := make([]string, 0, len(d.SubCommands)) + for _, sc := range d.SubCommands { + name := sc.Name + if sc.ArgsUsage != "" { + name += " " + sc.ArgsUsage + } + names = append(names, name) + } + return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|")) +} diff --git a/pkg/commands/definition_test.go b/pkg/commands/definition_test.go new file mode 100644 index 000000000..27ad4a0a2 --- /dev/null +++ b/pkg/commands/definition_test.go @@ -0,0 +1,41 @@ +package commands + +import ( + "testing" +) + +func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) { + d := Definition{Name: "start", Usage: "/start"} + if got := d.EffectiveUsage(); got != "/start" { + t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start") + } +} + +func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) { + d := Definition{ + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + {Name: "agents"}, + }, + } + want := "/show [model|channel|agents]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} + +func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) { + d := Definition{ + Name: "session", + SubCommands: []SubCommand{ + {Name: "list"}, + {Name: "resume", ArgsUsage: ""}, + }, + } + want := "/session [list|resume ]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} diff --git a/pkg/commands/executor.go b/pkg/commands/executor.go new file mode 100644 index 000000000..78a50e6c2 --- /dev/null +++ b/pkg/commands/executor.go @@ -0,0 +1,89 @@ +package commands + +import ( + "context" + "fmt" +) + +type Outcome int + +const ( + // OutcomePassthrough means this input should continue through normal agent flow. + OutcomePassthrough Outcome = iota + // OutcomeHandled means a command handler executed (with or without handler error). + OutcomeHandled +) + +type ExecuteResult struct { + Outcome Outcome + Command string + Err error +} + +type Executor struct { + reg *Registry + rt *Runtime +} + +func NewExecutor(reg *Registry, rt *Runtime) *Executor { + return &Executor{reg: reg, rt: rt} +} + +// Execute implements a two-state command decision: +// 1) handled: execute command immediately; +// 2) passthrough: not a command or intentionally deferred to agent logic. +func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult { + cmdName, ok := parseCommandName(req.Text) + if !ok { + return ExecuteResult{Outcome: OutcomePassthrough} + } + + if e == nil || e.reg == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + def, found := e.reg.Lookup(cmdName) + if !found { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + return e.executeDefinition(ctx, req, def) +} + +func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult { + // Ensure Reply is always non-nil so handlers don't need to check. + if req.Reply == nil { + req.Reply = func(string) error { return nil } + } + + // Simple command — no sub-commands + if len(def.SubCommands) == 0 { + if def.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := def.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + // Sub-command routing + subName := nthToken(req.Text, 1) + if subName == "" { + err := req.Reply("Usage: " + def.EffectiveUsage()) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + normalized := normalizeCommandName(subName) + for _, sc := range def.SubCommands { + if normalizeCommandName(sc.Name) == normalized { + if sc.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := sc.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + } + + // Unknown sub-command + err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage())) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} +} diff --git a/pkg/commands/executor_test.go b/pkg/commands/executor_test.go new file mode 100644 index 000000000..09350f1b6 --- /dev/null +++ b/pkg/commands/executor_test.go @@ -0,0 +1,260 @@ +package commands + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + Aliases: []string{"display"}, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "show" { + t.Fatalf("command=%q, want=%q", res.Command, "show") + } +} + +func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "clear", + Aliases: []string{"reset"}, + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "clear" { + t.Fatalf("command=%q, want=%q", res.Command, "clear") + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) { + // With Lookup-based dispatch, the first registered definition for a name wins. + // A definition with nil Handler and no SubCommands returns Passthrough. + defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_HandlerErrorIsPropagated(t *testing.T) { + wantErr := errors.New("handler failed") + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + return wantErr + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !errors.Is(res.Err, wantErr) { + t.Fatalf("err=%v, want=%v", res.Err, wantErr) + } +} + +func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) { + modelCalled := false + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error { + modelCalled = true + return nil + }}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !modelCalled { + t.Fatal("model sub-command handler was not called") + } +} + +func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /show [model|channel]" { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show foobar", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "foobar") { + t.Fatalf("reply=%q, should mention unknown sub-command", reply) + } +} + +func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, // nil Handler + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} diff --git a/pkg/commands/handler_agents.go b/pkg/commands/handler_agents.go new file mode 100644 index 000000000..c459516eb --- /dev/null +++ b/pkg/commands/handler_agents.go @@ -0,0 +1,21 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +// agentsHandler returns a shared handler for both /show agents and /list agents. +func agentsHandler() Handler { + return func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.ListAgentIDs == nil { + return req.Reply(unavailableMsg) + } + ids := rt.ListAgentIDs() + if len(ids) == 0 { + return req.Reply("No agents registered") + } + return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", "))) + } +} diff --git a/pkg/commands/registry.go b/pkg/commands/registry.go new file mode 100644 index 000000000..e17d489a6 --- /dev/null +++ b/pkg/commands/registry.go @@ -0,0 +1,55 @@ +package commands + +type Registry struct { + defs []Definition + index map[string]int +} + +// NewRegistry stores the canonical command set used by both dispatch and +// optional platform registration adapters. +func NewRegistry(defs []Definition) *Registry { + stored := make([]Definition, len(defs)) + copy(stored, defs) + + index := make(map[string]int, len(stored)*2) + for i, def := range stored { + registerCommandName(index, def.Name, i) + for _, alias := range def.Aliases { + registerCommandName(index, alias, i) + } + } + + return &Registry{defs: stored, index: index} +} + +// Definitions returns all registered command definitions. +// Command availability is global and no longer channel-scoped. +func (r *Registry) Definitions() []Definition { + out := make([]Definition, len(r.defs)) + copy(out, r.defs) + return out +} + +// Lookup returns a command definition by normalized command name or alias. +func (r *Registry) Lookup(name string) (Definition, bool) { + key := normalizeCommandName(name) + if key == "" { + return Definition{}, false + } + idx, ok := r.index[key] + if !ok { + return Definition{}, false + } + return r.defs[idx], true +} + +func registerCommandName(index map[string]int, name string, defIndex int) { + key := normalizeCommandName(name) + if key == "" { + return + } + if _, exists := index[key]; exists { + return + } + index[key] = defIndex +} diff --git a/pkg/commands/registry_test.go b/pkg/commands/registry_test.go new file mode 100644 index 000000000..bfff76b7c --- /dev/null +++ b/pkg/commands/registry_test.go @@ -0,0 +1,49 @@ +package commands + +import "testing" + +func TestRegistry_Definitions_ReturnsCopy(t *testing.T) { + defs := []Definition{ + {Name: "help", Description: "Show help"}, + {Name: "admin", Description: "Admin command"}, + } + r := NewRegistry(defs) + + got := r.Definitions() + if len(got) != 2 { + t.Fatalf("definitions len = %d, want 2", len(got)) + } + + got[0].Name = "mutated" + again := r.Definitions() + if again[0].Name != "help" { + t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name) + } +} + +func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) { + r := NewRegistry([]Definition{ + {Name: "Help", Aliases: []string{"Assist"}}, + {Name: "List"}, + }) + + def, ok := r.Lookup("help") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("HELP") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("assist") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("ASSIST") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def) + } +} diff --git a/pkg/commands/request.go b/pkg/commands/request.go new file mode 100644 index 000000000..62ee600f2 --- /dev/null +++ b/pkg/commands/request.go @@ -0,0 +1,75 @@ +package commands + +import ( + "context" + "strings" +) + +type Handler func(ctx context.Context, req Request, rt *Runtime) error + +type Request struct { + Channel string + ChatID string + SenderID string + Text string + Reply func(text string) error +} + +const unavailableMsg = "Command unavailable in current context." + +var commandPrefixes = []string{"/", "!"} + +// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then +// normalizes to lowercase command names. +func parseCommandName(input string) (string, bool) { + token := nthToken(input, 0) + if token == "" { + return "", false + } + + name, ok := trimCommandPrefix(token) + if !ok { + return "", false + } + if i := strings.Index(name, "@"); i >= 0 { + name = name[:i] + } + name = normalizeCommandName(name) + if name == "" { + return "", false + } + return name, true +} + +func trimCommandPrefix(token string) (string, bool) { + for _, prefix := range commandPrefixes { + if strings.HasPrefix(token, prefix) { + return strings.TrimPrefix(token, prefix), true + } + } + return "", false +} + +// HasCommandPrefix returns true if the input starts with a recognized +// command prefix (e.g. "/" or "!"). +func HasCommandPrefix(input string) bool { + token := nthToken(input, 0) + if token == "" { + return false + } + _, ok := trimCommandPrefix(token) + return ok +} + +// nthToken returns the 0-indexed token from whitespace-split input. +func nthToken(input string, n int) string { + parts := strings.Fields(strings.TrimSpace(input)) + if n >= len(parts) { + return "" + } + return parts[n] +} + +func normalizeCommandName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) +} diff --git a/pkg/commands/request_test.go b/pkg/commands/request_test.go new file mode 100644 index 000000000..4389e453b --- /dev/null +++ b/pkg/commands/request_test.go @@ -0,0 +1,28 @@ +package commands + +import "testing" + +func TestHasCommandPrefix(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"/help", true}, + {"!help", true}, + {"/switch model to gpt-4", true}, + {"!switch model to gpt-4", true}, + {"hello", false}, + {"", false}, + {" ", false}, + {"hello /world", false}, + {"/", true}, + {"!", true}, + {" /help", true}, + } + for _, tt := range tests { + got := HasCommandPrefix(tt.input) + if got != tt.want { + t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go new file mode 100644 index 000000000..227d495f4 --- /dev/null +++ b/pkg/commands/runtime.go @@ -0,0 +1,16 @@ +package commands + +import "github.com/sipeed/picoclaw/pkg/config" + +// Runtime provides runtime dependencies to command handlers. It is constructed +// per-request by the agent loop so that per-request state (like session scope) +// can coexist with long-lived callbacks (like GetModelInfo). +type Runtime struct { + Config *config.Config + GetModelInfo func() (name, provider string) + ListAgentIDs func() []string + ListDefinitions func() []Definition + GetEnabledChannels func() []string + SwitchModel func(value string) (oldModel string, err error) + SwitchChannel func(value string) error +} diff --git a/pkg/commands/show_list_handlers_test.go b/pkg/commands/show_list_handlers_test.go new file mode 100644 index 000000000..047708f0f --- /dev/null +++ b/pkg/commands/show_list_handlers_test.go @@ -0,0 +1,85 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func TestShowListHandlers_ChannelPolicy(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil) + + var telegramReply string + handled := ex.Execute(context.Background(), Request{ + Channel: "telegram", + Text: "/show channel", + Reply: func(text string) error { + telegramReply = text + return nil + }, + }) + if handled.Outcome != OutcomeHandled { + t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled) + } + if telegramReply != "Current Channel: telegram" { + t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram") + } + + var whatsappReply string + handledWhatsApp := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/show channel", + Reply: func(text string) error { + whatsappReply = text + return nil + }, + }) + if handledWhatsApp.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled) + } + if handledWhatsApp.Command != "show" { + t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show") + } + if whatsappReply != "Current Channel: whatsapp" { + t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp") + } + + passthrough := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/foo", + }) + if passthrough.Outcome != OutcomePassthrough { + t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough) + } + if passthrough.Command != "foo" { + t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo") + } +} + +func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram"} + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "list" { + t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list") + } + if !strings.Contains(reply, "telegram") { + t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply) + } +}