diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 786c893ef..0edd29f22 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,6 +17,11 @@ on: required: false type: boolean default: false + upload_tos: + description: "Upload to Volcengine TOS" + required: false + type: boolean + default: true jobs: create-tag: @@ -100,3 +105,12 @@ jobs: gh release edit "${{ inputs.tag }}" \ --draft=${{ inputs.draft }} \ --prerelease=${{ inputs.prerelease }} + + upload-tos: + name: Upload to TOS + needs: release + if: ${{ inputs.upload_tos }} + uses: ./.github/workflows/upload-tos.yml + with: + tag: ${{ inputs.tag }} + secrets: inherit diff --git a/.github/workflows/upload-tos.yml b/.github/workflows/upload-tos.yml new file mode 100644 index 000000000..6d3916d53 --- /dev/null +++ b/.github/workflows/upload-tos.yml @@ -0,0 +1,49 @@ +name: Upload to Volcengine TOS + +on: + workflow_dispatch: + inputs: + tag: + description: "Release tag to download and upload (e.g. v0.2.0)" + required: true + type: string + workflow_call: + inputs: + tag: + description: "Release tag to download and upload" + required: true + type: string + +jobs: + upload-tos: + name: Upload to Volcengine TOS + runs-on: ubuntu-latest + steps: + - name: Download release assets + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir -p artifacts + gh release download "${{ inputs.tag }}" \ + --repo "${{ github.repository }}" \ + --dir artifacts \ + --pattern "*.tar.gz" \ + --pattern "*.zip" \ + --pattern "*.rpm" \ + --pattern "*.deb" + + - name: Upload to Volcengine TOS + env: + AWS_ACCESS_KEY_ID: ${{ secrets.VOLC_TOS_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.VOLC_TOS_SECRET_KEY }} + AWS_DEFAULT_REGION: cn-beijing + run: | + aws configure set default.s3.addressing_style virtual + TOS_ENDPOINT="https://tos-s3-cn-beijing.volces.com" + # Upload to versioned directory + aws s3 sync artifacts/ "s3://picoclaw-downloads/${{ inputs.tag }}/" \ + --endpoint-url "$TOS_ENDPOINT" + # Upload to latest (overwrite) + aws s3 sync artifacts/ "s3://picoclaw-downloads/latest/" \ + --endpoint-url "$TOS_ENDPOINT" \ + --delete diff --git a/.gitignore b/.gitignore index 02ef18d1f..a52b8d25a 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,9 @@ ralph/ .ralph/ tasks/ +# Plans +docs/plans/ + # Editors .vscode/ .idea/ diff --git a/LICENSE b/LICENSE index 410acae26..b38d9340d 100644 --- a/LICENSE +++ b/LICENSE @@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ---- - -PicoClaw is heavily inspired by and based on [nanobot](https://github.com/HKUDS/nanobot) by HKUDS. diff --git a/README.md b/README.md index c1ef72141..3774055b4 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d > [!TIP] > Set your API key in `~/.picoclaw/config.json`. > Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) -> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. +> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. **1. Initialize** @@ -265,6 +265,16 @@ picoclaw onboard "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://your-searxng-instance:8888", + "max_results": 5 } } } @@ -277,7 +287,12 @@ picoclaw onboard **3. Get API Keys** * **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) -* **Web Search** (optional): [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) · [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month) +* **Web Search** (optional): + * [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month) + * [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface + * [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed) + * [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) + * DuckDuckGo - Built-in fallback (no API key required) > **Note**: See `config.example.json` for a complete configuration template. @@ -338,6 +353,13 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, DingTalk, LINE, or We picoclaw gateway ``` +**4. Telegram command menu (auto-registered at startup)** + +PicoClaw now keeps command definitions in one shared registry. On startup, Telegram will automatically register supported bot commands (for example `/start`, `/help`, `/show`, `/list`) so command menu and runtime behavior stay in sync. +Telegram command menu registration remains channel-local discovery UX; generic command execution is handled centrally in the agent loop via the commands executor. + +If command registration fails (network/API transient errors), the channel still starts and PicoClaw retries registration in the background. +
@@ -735,6 +757,12 @@ For advanced/test setups, you can override the builtin skills root with: export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### Unified Command Execution Policy + +- Generic slash commands are executed through a single path in `pkg/agent/loop.go` via `commands.Executor`. +- Channel adapters no longer consume generic commands locally; they forward inbound text to the bus/agent path. Telegram still auto-registers supported commands at startup. +- Unknown slash command (for example `/foo`) passes through to normal LLM processing. +- Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing. ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. @@ -1190,6 +1218,10 @@ picoclaw agent -m "Hello" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" @@ -1241,6 +1273,16 @@ picoclaw agent -m "Hello" "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://localhost:8888", + "max_results": 5 } }, "cron": { @@ -1298,10 +1340,69 @@ discord: This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching. -To enable web search: +#### Search Provider Priority -1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results. -2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required). +PicoClaw automatically selects the best available search provider in this order: +1. **Perplexity** (if enabled and API key configured) - AI-powered search with citations +2. **Brave Search** (if enabled and API key configured) - Privacy-focused paid API ($5/1000 queries) +3. **SearXNG** (if enabled and base_url configured) - Self-hosted metasearch aggregating 70+ engines (free) +4. **DuckDuckGo** (if enabled, default fallback) - No API key required (free) + +#### Web Search Configuration Options + +**Option 1 (Best Results)**: Perplexity AI Search +```json +{ + "tools": { + "web": { + "perplexity": { + "enabled": true, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + } + } + } +} +``` + +**Option 2 (Paid API)**: Get an API key at [https://brave.com/search/api](https://brave.com/search/api) ($5/1000 queries, ~$5-6/month) +```json +{ + "tools": { + "web": { + "brave": { + "enabled": true, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} +``` + +**Option 3 (Self-Hosted)**: Deploy your own [SearXNG](https://github.com/searxng/searxng) instance +```json +{ + "tools": { + "web": { + "searxng": { + "enabled": true, + "base_url": "http://your-server:8888", + "max_results": 5 + } + } + } +} +``` + +Benefits of SearXNG: +- **Zero cost**: No API fees or rate limits +- **Privacy-focused**: Self-hosted, no tracking +- **Aggregate results**: Queries 70+ search engines simultaneously +- **Perfect for cloud VMs**: Solves datacenter IP blocking issues (Oracle Cloud, GCP, AWS, Azure) +- **No API key needed**: Just deploy and configure the base URL + +**Option 4 (No Setup Required)**: DuckDuckGo is enabled by default as fallback (no API key needed) Add the key to `~/.picoclaw/config.json` if using Brave: @@ -1317,6 +1418,16 @@ Add the key to `~/.picoclaw/config.json` if using Brave: "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://your-searxng-instance:8888", + "max_results": 5 } } } @@ -1335,10 +1446,11 @@ This happens when another instance of the bot is running. Make sure only one `pi ## 📝 API Key Comparison -| Service | Free Tier | Use Case | -| ---------------- | ------------------- | ------------------------------------- | -| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | -| **Zhipu** | 200K tokens/month | Best for Chinese users | -| **Brave Search** | 2000 queries/month | Web search functionality | -| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | -| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) | +| Service | Free Tier | Use Case | +| ---------------- | ------------------------ | ------------------------------------- | +| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/month | Best for Chinese users | +| **Brave Search** | Paid ($5/1000 queries) | Web search functionality | +| **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) | +| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) | diff --git a/README.zh.md b/README.zh.md index bd90173f9..dc32b67e0 100644 --- a/README.zh.md +++ b/README.zh.md @@ -307,6 +307,13 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方 | **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) | | **MaixCam** | ⭐ 简单 | 专为 AI 摄像头设计的硬件集成通道 | [查看文档](docs/channels/maixcam/README.zh.md) | +### Telegram 命令注册(启动时自动同步) + +PicoClaw 现在使用统一的命令定义来源。启动时会自动将 Telegram 支持的命令(例如 `/start`、`/help`、`/show`、`/list`)注册到 Bot 命令菜单,确保菜单展示与实际行为一致。 +Telegram 侧保留的是命令菜单注册能力;通用命令的实际执行统一走 Agent Loop 中的 commands executor。 + +如果注册因网络或 API 短暂异常失败,不会阻塞 channel 启动;系统会在后台自动重试。 + ## ClawdChat 加入 Agent 社交网络 只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 @@ -376,6 +383,12 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### 统一命令执行策略 + +- 通用斜杠命令通过 `pkg/agent/loop.go` 中的 `commands.Executor` 统一执行。 +- Channel 适配器不再在本地消费通用命令;它们只负责把入站文本转发到 bus/agent 路径。Telegram 仍会在启动时自动注册其支持的命令菜单。 +- 未注册的斜杠命令(例如 `/foo`)会透传给 LLM 按普通输入处理。 +- 已注册但当前 channel 不支持的命令(例如 WhatsApp 上的 `/show`)会返回明确的用户可见错误,并停止后续处理。 ### 心跳 / 周期性任务 (Heartbeat) PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: @@ -715,6 +728,10 @@ picoclaw agent -m "你好" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" diff --git a/cmd/picoclaw-launcher-tui/internal/ui/model.go b/cmd/picoclaw-launcher-tui/internal/ui/model.go index ba91f5b09..304b4efa7 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/model.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/model.go @@ -335,7 +335,11 @@ func (s *appState) testModel(model *picoclawconfig.ModelConfig) { s.showMessage("Test OK", resp.Status) return } - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 2048)) + if err != nil { + s.showMessage("Test failed", fmt.Sprintf("failed to read response: %v", err)) + return + } s.showMessage( "Test failed", fmt.Sprintf("%s: %s", resp.Status, strings.TrimSpace(string(body))), diff --git a/cmd/picoclaw-launcher/internal/server/auth_handlers.go b/cmd/picoclaw-launcher/internal/server/auth_handlers.go index 1e9b8be0a..3b48f9739 100644 --- a/cmd/picoclaw-launcher/internal/server/auth_handlers.go +++ b/cmd/picoclaw-launcher/internal/server/auth_handlers.go @@ -297,7 +297,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading userinfo response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("userinfo request failed: %s", string(body)) } diff --git a/cmd/picoclaw/internal/auth/helpers.go b/cmd/picoclaw/internal/auth/helpers.go index 633ce8740..4dfbc92e7 100644 --- a/cmd/picoclaw/internal/auth/helpers.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -177,7 +177,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading userinfo response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("userinfo request failed: %s", string(body)) } diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 5225340c7..174f5db62 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -230,19 +230,25 @@ func setupCronTool( // Create cron service cronService := cron.NewCronService(cronStorePath, nil) - // Create and register CronTool - cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) - if err != nil { - log.Fatalf("Critical error during CronTool initialization: %v", err) + // Create and register CronTool if enabled + var cronTool *tools.CronTool + if cfg.Tools.IsToolEnabled("cron") { + var err error + cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) + if err != nil { + log.Fatalf("Critical error during CronTool initialization: %v", err) + } + + agentLoop.RegisterTool(cronTool) } - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil - }) + // Set onJob handler + if cronTool != nil { + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + } return cronService } diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go index 9655d3c08..f81d7013d 100644 --- a/cmd/picoclaw/internal/helpers.go +++ b/cmd/picoclaw/internal/helpers.go @@ -18,12 +18,21 @@ var ( goVersion string ) +// GetPicoclawHome returns the picoclaw home directory. +// Priority: $PICOCLAW_HOME > ~/.picoclaw +func GetPicoclawHome() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return home + } + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw") +} + func GetConfigPath() string { if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" { return configPath } - home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw", "config.json") + return filepath.Join(GetPicoclawHome(), "config.json") } func LoadConfig() (*config.Config, error) { diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go index 47e2f8c07..646be1ba1 100644 --- a/cmd/picoclaw/internal/helpers_test.go +++ b/cmd/picoclaw/internal/helpers_test.go @@ -19,6 +19,27 @@ func TestGetConfigPath(t *testing.T) { assert.Equal(t, want, got) } +func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) { + t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv("HOME", "/tmp/home") + + got := GetConfigPath() + want := filepath.Join("/custom/picoclaw", "config.json") + + assert.Equal(t, want, got) +} + +func TestGetConfigPath_WithPICOCLAW_CONFIG(t *testing.T) { + t.Setenv("PICOCLAW_CONFIG", "/custom/config.json") + t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv("HOME", "/tmp/home") + + got := GetConfigPath() + want := "/custom/config.json" + + assert.Equal(t, want, got) +} + func TestFormatVersion_NoGitCommit(t *testing.T) { oldVersion, oldGit := version, gitCommit t.Cleanup(func() { version, gitCommit = oldVersion, oldGit }) diff --git a/cmd/picoclaw/internal/skills/install.go b/cmd/picoclaw/internal/skills/install.go index a30f68632..78bc421db 100644 --- a/cmd/picoclaw/internal/skills/install.go +++ b/cmd/picoclaw/internal/skills/install.go @@ -21,8 +21,8 @@ picoclaw skills install --registry clawhub github `, Args: func(cmd *cobra.Command, args []string) error { if registry != "" { - if len(args) != 2 { - return fmt.Errorf("when --registry is set, exactly 2 arguments are required: ") + if len(args) != 1 { + return fmt.Errorf("when --registry is set, exactly 1 argument is required: ") } return nil } @@ -45,7 +45,7 @@ picoclaw skills install --registry clawhub github return err } - return skillsInstallFromRegistry(cfg, args[0], args[1]) + return skillsInstallFromRegistry(cfg, registry, args[0]) } return skillsInstallCmd(installer, args[0]) diff --git a/cmd/picoclaw/internal/skills/install_test.go b/cmd/picoclaw/internal/skills/install_test.go index 97787a986..6b362822d 100644 --- a/cmd/picoclaw/internal/skills/install_test.go +++ b/cmd/picoclaw/internal/skills/install_test.go @@ -26,3 +26,72 @@ func TestNewInstallSubcommand(t *testing.T) { assert.Len(t, cmd.Aliases, 0) } + +func TestInstallCommandArgs(t *testing.T) { + tests := []struct { + name string + args []string + registry string + expectError bool + errorMsg string + }{ + { + name: "no registry, one arg", + args: []string{"sipeed/picoclaw-skills/weather"}, + registry: "", + expectError: false, + }, + { + name: "no registry, no args", + args: []string{}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "no registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "with registry, one arg", + args: []string{"weather-skill"}, + registry: "clawhub", + expectError: false, + }, + { + name: "with registry, no args", + args: []string{}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + { + name: "with registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newInstallCommand(nil) + + if tt.registry != "" { + require.NoError(t, cmd.Flags().Set("registry", tt.registry)) + } + + err := cmd.Args(cmd, tt.args) + if tt.expectError { + require.Error(t, err) + assert.Equal(t, tt.errorMsg, err.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/config/config.example.json b/config/config.example.json index c59a39885..2f643d41b 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -232,24 +232,46 @@ } }, "tools": { + "allow_read_paths": null, + "allow_write_paths": null, "web": { + "enabled": true, "brave": { "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 }, + "tavily": { + "enabled": false, + "api_key": "", + "base_url": "", + "max_results": 0 + }, "duckduckgo": { "enabled": true, "max_results": 5 }, "perplexity": { "enabled": false, - "api_key": "pplx-xxx", + "api_key": "", "max_results": 5 }, - "proxy": "" + "searxng": { + "enabled": false, + "base_url": "http://localhost:8888", + "max_results": 5 + }, + "glm_search": { + "enabled": false, + "api_key": "", + "base_url": "https://open.bigmodel.cn/api/paas/v4/web_search", + "search_engine": "search_std", + "max_results": 5 + }, + "fetch_limit_bytes": 10485760 }, "cron": { + "enabled": true, "exec_timeout_minutes": 5 }, "mcp": { @@ -318,19 +340,75 @@ } }, "exec": { - "enable_deny_patterns": false, - "custom_deny_patterns": [] + "enabled": true, + "enable_deny_patterns": true, + "custom_deny_patterns": null, + "custom_allow_patterns": null }, "skills": { + "enabled": true, "registries": { "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", - "search_path": "/api/v1/search", - "skills_path": "/api/v1/skills", - "download_path": "/api/v1/download" + "auth_token": "", + "search_path": "", + "skills_path": "", + "download_path": "", + "timeout": 0, + "max_zip_size": 0, + "max_response_size": 0 } + }, + "max_concurrent_searches": 2, + "search_cache": { + "max_size": 50, + "ttl_seconds": 300 } + }, + "media_cleanup": { + "enabled": true, + "max_age_minutes": 30, + "interval_minutes": 5 + }, + "append_file": { + "enabled": true + }, + "edit_file": { + "enabled": true + }, + "find_skills": { + "enabled": true + }, + "i2c": { + "enabled": false + }, + "install_skill": { + "enabled": true + }, + "list_dir": { + "enabled": true + }, + "message": { + "enabled": true + }, + "read_file": { + "enabled": true + }, + "spawn": { + "enabled": true + }, + "spi": { + "enabled": false + }, + "subagent": { + "enabled": true + }, + "web_fetch": { + "enabled": true + }, + "write_file": { + "enabled": true } }, "heartbeat": { diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index 6204fb0c8..e64a3a107 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -180,6 +180,7 @@ The skills tool configures skill discovery and installation via registries like | ---------------------------------- | ------ | -------------------- | ----------------------- | | `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | | `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits | | `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | | `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | | `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | @@ -194,6 +195,7 @@ The skills tool configures skill discovery and installation via registries like "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", + "auth_token": "", "search_path": "/api/v1/search", "skills_path": "/api/v1/skills", "download_path": "/api/v1/download" diff --git a/go.mod b/go.mod index 238bd405c..6fa3a900c 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 + github.com/h2non/filetype v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 github.com/modelcontextprotocol/go-sdk v1.3.0 @@ -37,7 +38,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 3aa903b3f..719b0cb6d 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -42,6 +42,9 @@ type ContextBuilder struct { } func getGlobalConfigDir() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return home + } home, err := os.UserHomeDir() if err != nil { return "" @@ -602,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message } } - return sanitized + // Second pass: ensure every assistant message with tool_calls has matching + // tool result messages following it. This is required by strict providers + // like DeepSeek that enforce: "An assistant message with 'tool_calls' must + // be followed by tool messages responding to each 'tool_call_id'." + final := make([]providers.Message, 0, len(sanitized)) + for i := 0; i < len(sanitized); i++ { + msg := sanitized[i] + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + // Collect expected tool_call IDs + expected := make(map[string]bool, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + expected[tc.ID] = false + } + + // Check following messages for matching tool results + toolMsgCount := 0 + for j := i + 1; j < len(sanitized); j++ { + if sanitized[j].Role != "tool" { + break + } + toolMsgCount++ + if _, exists := expected[sanitized[j].ToolCallID]; exists { + expected[sanitized[j].ToolCallID] = true + } + } + + // If any tool_call_id is missing, drop this assistant message and its partial tool messages + allFound := true + for toolCallID, found := range expected { + if !found { + allFound = false + logger.DebugCF( + "agent", + "Dropping assistant message with incomplete tool results", + map[string]any{ + "missing_tool_call_id": toolCallID, + "expected_count": len(expected), + "found_count": toolMsgCount, + }, + ) + break + } + } + + if !allFound { + // Skip this assistant message and its tool messages + i += toolMsgCount + continue + } + } + final = append(final, msg) + } + + return final } func (cb *ContextBuilder) AddToolResult( diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go index e023c9c30..5756ed911 100644 --- a/pkg/agent/context_test.go +++ b/pkg/agent/context_test.go @@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) { } } } + +// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation +// that ensures assistant messages with tool_calls have ALL matching tool results. +// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be +// followed by tool messages responding to each 'tool_call_id'." +func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) { + // Assistant expects tool results for both A and B, but only A is present + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + // toolResult("B") is missing - this would cause DeepSeek to fail + msg("user", "next question"), + msg("assistant", "answer"), + } + + result := sanitizeHistoryForProvider(history) + // The assistant message with incomplete tool results should be dropped, + // along with its partial tool result. The remaining messages are: + // user ("do two things"), user ("next question"), assistant ("answer") + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "user", "assistant") +} + +// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where +// an assistant message has tool_calls but no tool results follow at all. +func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) { + history := []providers.Message{ + msg("user", "do something"), + assistantWithTools("A"), + // No tool results at all + msg("user", "hello"), + msg("assistant", "hi"), + } + + result := sanitizeHistoryForProvider(history) + // The assistant message with no tool results should be dropped. + // Remaining: user ("do something"), user ("hello"), assistant ("hi") + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "user", "assistant") +} + +// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that +// incomplete tool results in the middle of a conversation are properly handled. +func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) { + history := []providers.Message{ + msg("user", "first"), + assistantWithTools("A"), + toolResult("A"), + msg("assistant", "done"), + msg("user", "second"), + assistantWithTools("B", "C"), + toolResult("B"), + // toolResult("C") is missing + msg("user", "third"), + assistantWithTools("D"), + toolResult("D"), + msg("assistant", "all done"), + } + + result := sanitizeHistoryForProvider(history) + // First round is complete (user, assistant+tools, tool, assistant), + // second round is incomplete and dropped (assistant+tools, partial tool), + // third round is complete (user, assistant+tools, tool, assistant). + // Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant + if len(result) != 9 { + t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant") +} diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 1e18b6f64..97cf0fa05 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -37,6 +37,14 @@ type AgentInstance struct { Subagents *config.SubagentsConfig SkillsFilter []string Candidates []providers.FallbackCandidate + + // Router is non-nil when model routing is configured and the light model + // was successfully resolved. It scores each incoming message and decides + // whether to route to LightCandidates or stay with Candidates. + Router *routing.Router + // LightCandidates holds the resolved provider candidates for the light model. + // Pre-computed at agent creation to avoid repeated model_list lookups at runtime. + LightCandidates []providers.FallbackCandidate } // NewAgentInstance creates an agent instance from config. @@ -60,17 +68,30 @@ func NewAgentInstance( allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths) toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) - execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) - if err != nil { - log.Fatalf("Critical error: unable to initialize exec tool: %v", err) - } - toolsRegistry.Register(execTool) - toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + if cfg.Tools.IsToolEnabled("read_file") { + toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("write_file") { + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("list_dir") { + toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("exec") { + execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) + if err != nil { + log.Fatalf("Critical error: unable to initialize exec tool: %v", err) + } + toolsRegistry.Register(execTool) + } + + if cfg.Tools.IsToolEnabled("edit_file") { + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("append_file") { + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + } sessionsDir := filepath.Join(workspace, "sessions") sessionsManager := session.NewSessionManager(sessionsDir) @@ -167,6 +188,25 @@ func NewAgentInstance( candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) + // Model routing setup: pre-resolve light model candidates at creation time + // to avoid repeated model_list lookups on every incoming message. + var router *routing.Router + var lightCandidates []providers.FallbackCandidate + if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" { + lightModelCfg := providers.ModelConfig{Primary: rc.LightModel} + resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList) + if len(resolved) > 0 { + router = routing.New(routing.RouterConfig{ + LightModel: rc.LightModel, + Threshold: rc.Threshold, + }) + lightCandidates = resolved + } else { + log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q", + rc.LightModel, agentID) + } + } + return &AgentInstance{ ID: agentID, Name: agentName, @@ -187,6 +227,8 @@ func NewAgentInstance( Subagents: subagents, SkillsFilter: skillsFilter, Candidates: candidates, + Router: router, + LightCandidates: lightCandidates, } } @@ -195,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { return expandHome(strings.TrimSpace(agentCfg.Workspace)) } + // Use the configured default workspace (respects PICOCLAW_HOME) if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" { return expandHome(defaults.Workspace) } - home, _ := os.UserHomeDir() + // For named agents without explicit workspace, use default workspace with agent ID suffix id := routing.NormalizeAgentID(agentCfg.ID) - return filepath.Join(home, ".picoclaw", "workspace-"+id) + return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id) } // resolveAgentModel resolves the primary model for an agent. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 263eeb4dd..966668227 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" @@ -46,6 +47,7 @@ type AgentLoop struct { channelManager *channels.Manager mediaStore media.MediaStore transcriber voice.Transcriber + cmdRegistry *commands.Registry } // processOptions configures how a message is processed @@ -61,7 +63,15 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." +const ( + defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." + sessionKeyAgentPrefix = "agent:" + metadataKeyAccountID = "account_id" + metadataKeyGuildID = "guild_id" + metadataKeyTeamID = "team_id" + metadataKeyParentPeerKind = "parent_peer_kind" + metadataKeyParentPeerID = "parent_peer_id" +) func NewAgentLoop( cfg *config.Config, @@ -84,14 +94,17 @@ func NewAgentLoop( stateManager = state.NewManager(defaultAgent.Workspace) } - return &AgentLoop{ + al := &AgentLoop{ bus: msgBus, cfg: cfg, registry: registry, state: stateManager, summarizing: sync.Map{}, fallback: fallbackChain, + cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), } + + return al } // registerSharedTools registers tools that are shared across all agents (web, message, spawn). @@ -108,76 +121,106 @@ func registerSharedTools( } // Web tools - searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, - TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, - TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, - TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, - GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, - GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, - GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, - GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, - Proxy: cfg.Tools.Web.Proxy, - }) - if err != nil { - logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) - } else if searchTool != nil { - agent.Tools.Register(searchTool) + if cfg.Tools.IsToolEnabled("web") { + searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, + TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, + TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, + TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL, + SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults, + SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled, + GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, + GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, + GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, + GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, + GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, + Proxy: cfg.Tools.Web.Proxy, + }) + if err != nil { + logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + } else if searchTool != nil { + agent.Tools.Register(searchTool) + } } - fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) - if err != nil { - logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) - } else { - agent.Tools.Register(fetchTool) + if cfg.Tools.IsToolEnabled("web_fetch") { + fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) + if err != nil { + logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + } else { + agent.Tools.Register(fetchTool) + } } // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - agent.Tools.Register(tools.NewI2CTool()) - agent.Tools.Register(tools.NewSPITool()) + if cfg.Tools.IsToolEnabled("i2c") { + agent.Tools.Register(tools.NewI2CTool()) + } + if cfg.Tools.IsToolEnabled("spi") { + agent.Tools.Register(tools.NewSPITool()) + } // Message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, + if cfg.Tools.IsToolEnabled("message") { + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) }) - }) - agent.Tools.Register(messageTool) + agent.Tools.Register(messageTool) + } // Skill discovery and installation tools - registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ - MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, - ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), - }) - searchCache := skills.NewSearchCache( - cfg.Tools.Skills.SearchCache.MaxSize, - time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, - ) - agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) - agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + skills_enabled := cfg.Tools.IsToolEnabled("skills") + find_skills_enable := cfg.Tools.IsToolEnabled("find_skills") + install_skills_enable := cfg.Tools.IsToolEnabled("install_skill") + if skills_enabled && (find_skills_enable || install_skills_enable) { + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + + if find_skills_enable { + searchCache := skills.NewSearchCache( + cfg.Tools.Skills.SearchCache.MaxSize, + time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, + ) + agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) + } + + if install_skills_enable { + agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + } + } // Spawn tool with allowlist checker - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) - subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) - spawnTool := tools.NewSpawnTool(subagentManager) - currentAgentID := agentID - spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { - return registry.CanSpawnSubagent(currentAgentID, targetAgentID) - }) - agent.Tools.Register(spawnTool) + if cfg.Tools.IsToolEnabled("spawn") { + if cfg.Tools.IsToolEnabled("subagent") { + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + } else { + logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil) + } + } } } @@ -185,7 +228,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) // Initialize MCP servers for all agents - if al.cfg.Tools.MCP.Enabled { + if al.cfg.Tools.IsToolEnabled("mcp") { mcpManager := mcp.NewManager() // Ensure MCP connections are cleaned up on exit, regardless of initialization success // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails @@ -227,6 +270,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { if !ok { continue } + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) agent.Tools.Register(mcpTool) totalRegistrations++ @@ -518,27 +562,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Check for commands - if response, handled := al.handleCommand(ctx, msg); handled { + route, agent, routeErr := al.resolveMessageRoute(msg) + + // Commands are checked before requiring a successful route. + // Global commands (/help, /show, /switch) work even when routing fails; + // context-dependent commands check their own Runtime fields and report + // "unavailable" when the required capability is nil. + if response, handled := al.handleCommand(ctx, msg, agent); handled { return response, nil } - // Route to determine agent and session key - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - AccountID: msg.Metadata["account_id"], - Peer: extractPeer(msg), - ParentPeer: extractParentPeer(msg), - GuildID: msg.Metadata["guild_id"], - TeamID: msg.Metadata["team_id"], - }) - - agent, ok := al.registry.GetAgent(route.AgentID) - if !ok { - agent = al.registry.GetDefaultAgent() - } - if agent == nil { - return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + if routeErr != nil { + return "", routeErr } // Reset message-tool state for this round so we don't skip publishing due to a previous round. @@ -548,17 +583,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } } - // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) - sessionKey := route.SessionKey - if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { - sessionKey = msg.SessionKey - } + // Resolve session key from route, while preserving explicit agent-scoped keys. + scopeKey := resolveScopeKey(route, msg.SessionKey) + sessionKey := scopeKey logger.InfoCF("agent", "Routed message", map[string]any{ - "agent_id": agent.ID, - "session_key": sessionKey, - "matched_by": route.MatchedBy, + "agent_id": agent.ID, + "scope_key": scopeKey, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + "route_agent": route.AgentID, + "route_channel": route.Channel, }) return al.runAgentLoop(ctx, agent, processOptions{ @@ -573,6 +609,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }) } +func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: inboundMetadata(msg, metadataKeyAccountID), + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: inboundMetadata(msg, metadataKeyGuildID), + TeamID: inboundMetadata(msg, metadataKeyTeamID), + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + if agent == nil { + return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + } + + return route, agent, nil +} + +func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { + if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { + return msgSessionKey + } + return route.SessionKey +} + func (al *AgentLoop) processSystemMessage( ctx context.Context, msg bus.InboundMessage, @@ -793,6 +857,12 @@ func (al *AgentLoop) runLLMIteration( iteration := 0 var finalContent string + // Determine effective model tier for this conversation turn. + // selectCandidates evaluates routing once and the decision is sticky for + // all tool-follow-up iterations within the same turn so that a multi-step + // tool chain doesn't switch models mid-way through. + activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) + for iteration < agent.MaxIterations { iteration++ @@ -811,7 +881,7 @@ func (al *AgentLoop) runLLMIteration( map[string]any{ "agent_id": agent.ID, "iteration": iteration, - "model": agent.Model, + "model": activeModel, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": agent.MaxTokens, @@ -827,7 +897,7 @@ func (al *AgentLoop) runLLMIteration( "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM with fallback chain if candidates are configured. + // Call LLM with fallback chain if multiple candidates are configured. var response *providers.LLMResponse var err error @@ -848,10 +918,10 @@ func (al *AgentLoop) runLLMIteration( } callLLM := func() (*providers.LLMResponse, error) { - if len(agent.Candidates) > 1 && al.fallback != nil { + if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, - agent.Candidates, + activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) }, @@ -869,7 +939,7 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts) + return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) } // Retry loop for context/token errors @@ -1138,6 +1208,44 @@ func (al *AgentLoop) runLLMIteration( return finalContent, iteration, nil } +// selectCandidates returns the model candidates and resolved model name to use +// for a conversation turn. When model routing is configured and the incoming +// message scores below the complexity threshold, it returns the light model +// candidates instead of the primary ones. +// +// The returned (candidates, model) pair is used for all LLM calls within one +// turn — tool follow-up iterations use the same tier as the initial call so +// that a multi-step tool chain doesn't switch models mid-way. +func (al *AgentLoop) selectCandidates( + agent *AgentInstance, + userMsg string, + history []providers.Message, +) (candidates []providers.FallbackCandidate, model string) { + if agent.Router == nil || len(agent.LightCandidates) == 0 { + return agent.Candidates, agent.Model + } + + _, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model) + if !usedLight { + logger.DebugCF("agent", "Model routing: primary model selected", + map[string]any{ + "agent_id": agent.ID, + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.Candidates, agent.Model + } + + logger.InfoCF("agent", "Model routing: light model selected", + map[string]any{ + "agent_id": agent.ID, + "light_model": agent.Router.LightModel(), + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.LightCandidates, agent.Router.LightModel() +} + // maybeSummarize triggers summarization if the session history exceeds thresholds. func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) @@ -1429,94 +1537,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int { return totalChars * 2 / 5 } -func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { - content := strings.TrimSpace(msg.Content) - if !strings.HasPrefix(content, "/") { +func (al *AgentLoop) handleCommand( + ctx context.Context, + msg bus.InboundMessage, + agent *AgentInstance, +) (string, bool) { + if !commands.HasCommandPrefix(msg.Content) { return "", false } - parts := strings.Fields(content) - if len(parts) == 0 { + if al.cmdRegistry == nil { return "", false } - cmd := parts[0] - args := parts[1:] + rt := al.buildCommandsRuntime(agent) + executor := commands.NewExecutor(al.cmdRegistry, rt) - switch cmd { - case "/show": - if len(args) < 1 { - return "Usage: /show [model|channel|agents]", true - } - switch args[0] { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - return fmt.Sprintf("Current model: %s", defaultAgent.Model), true - case "channel": - return fmt.Sprintf("Current channel: %s", msg.Channel), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown show target: %s", args[0]), true - } + var commandReply string + result := executor.Execute(ctx, commands.Request{ + Channel: msg.Channel, + ChatID: msg.ChatID, + SenderID: msg.SenderID, + Text: msg.Content, + Reply: func(text string) error { + commandReply = text + return nil + }, + }) - case "/list": - if len(args) < 1 { - return "Usage: /list [models|channels|agents]", true + switch result.Outcome { + case commands.OutcomeHandled: + if result.Err != nil { + return mapCommandError(result), true } - switch args[0] { - case "models": - return "Available models: configured in config.json per agent", true - case "channels": + if commandReply != "" { + return commandReply, true + } + return "", true + default: // OutcomePassthrough — let the message fall through to LLM + return "", false + } +} + +func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime { + rt := &commands.Runtime{ + Config: al.cfg, + ListAgentIDs: al.registry.ListAgentIDs, + ListDefinitions: al.cmdRegistry.Definitions, + GetEnabledChannels: func() []string { if al.channelManager == nil { - return "Channel manager not initialized", true + return nil } - channels := al.channelManager.GetEnabledChannels() - if len(channels) == 0 { - return "No channels enabled", true - } - return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown list target: %s", args[0]), true - } - - case "/switch": - if len(args) < 3 || args[1] != "to" { - return "Usage: /switch [model|channel] to ", true - } - target := args[0] - value := args[2] - - switch target { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - oldModel := defaultAgent.Model - defaultAgent.Model = value - return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true - case "channel": + return al.channelManager.GetEnabledChannels() + }, + SwitchChannel: func(value string) error { if al.channelManager == nil { - return "Channel manager not initialized", true + return fmt.Errorf("channel manager not initialized") } if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { - return fmt.Sprintf("Channel '%s' not found or not enabled", value), true + return fmt.Errorf("channel '%s' not found or not enabled", value) } - return fmt.Sprintf("Switched target channel to %s", value), true - default: - return fmt.Sprintf("Unknown switch target: %s", target), true + return nil + }, + } + if agent != nil { + rt.GetModelInfo = func() (string, string) { + return agent.Model, al.cfg.Agents.Defaults.Provider + } + rt.SwitchModel = func(value string) (string, error) { + oldModel := agent.Model + agent.Model = value + return oldModel, nil } } + return rt +} - return "", false +func mapCommandError(result commands.ExecuteResult) string { + if result.Command == "" { + return fmt.Sprintf("Failed to execute command: %v", result.Err) + } + return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err) } // extractPeer extracts the routing peer from the inbound message's structured Peer field. @@ -1535,10 +1636,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } +func inboundMetadata(msg bus.InboundMessage, key string) string { + if msg.Metadata == nil { + return "" + } + return msg.Metadata[key] +} + // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { - parentKind := msg.Metadata["parent_peer_kind"] - parentID := msg.Metadata["parent_peer_id"] + parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) + parentID := inboundMetadata(msg, metadataKeyParentPeerID) if parentKind == "" || parentID == "" { return nil } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4ab6b4542..2e456fa60 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -227,16 +228,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } defer os.RemoveAll(tmpDir) - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 msgBus := bus.NewMessageBus() provider := &mockProvider{} @@ -323,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string { return "mock-model" } +type countingMockProvider struct { + response string + calls int +} + +func (m *countingMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *countingMockProvider) GetDefaultModel() string { + return "counting-mock-model" +} + // mockCustomTool is a simple mock tool for registration testing type mockCustomTool struct{} @@ -364,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms const responseTimeout = 3 * time.Second +func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "ok"} + al := NewAgentLoop(cfg, msgBus, provider) + + msg := bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + Peer: extractPeer(msg), + }) + sessionKey := route.SessionKey + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + + helper := testHelper{al: al} + _ = helper.executeAndGetResponse(t, context.Background(), msg) + + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) != 2 { + t.Fatalf("expected session history len=2, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Fatalf("unexpected first message in session: %+v", history[0]) + } +} + +func TestProcessMessage_CommandOutcomes(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Session: config.SessionConfig{ + DMScope: "per-channel-peer", + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + baseMsg := bus.InboundMessage{ + Channel: "whatsapp", + SenderID: "user1", + ChatID: "chat1", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/show channel", + Peer: baseMsg.Peer, + }) + if showResp != "Current Channel: whatsapp" { + t.Fatalf("unexpected /show reply: %q", showResp) + } + if provider.calls != 0 { + t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls) + } + + fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/foo", + Peer: baseMsg.Peer, + }) + if fooResp != "LLM reply" { + t.Fatalf("unexpected /foo reply: %q", fooResp) + } + if provider.calls != 1 { + t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls) + } + + newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/new", + Peer: baseMsg.Peer, + }) + if newResp != "LLM reply" { + t.Fatalf("unexpected /new reply: %q", newResp) + } + if provider.calls != 2 { + t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls) + } +} + +func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Provider: "openai", + Model: "before-switch", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/switch model to after-switch", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") { + t.Fatalf("unexpected /switch reply: %q", switchResp) + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/show model", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") { + t.Fatalf("unexpected /show model reply after switch: %q", showResp) + } + + if provider.calls != 0 { + t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 91c9e25c5..4667e3d81 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device code response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("device code request failed: %s", string(body)) } @@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device code response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("device code request failed: %s", string(body)) } @@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au return nil, fmt.Errorf("pending") } - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device token response: %w", err) + } var tokenResp struct { AuthorizationCode string `json:"authorization_code"` @@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading token refresh response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token refresh failed: %s", string(body)) } @@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading token exchange response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token exchange failed: %s", string(body)) } diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 283dc6977..2e55d4877 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool { } func authFilePath() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return filepath.Join(home, "auth.json") + } home, _ := os.UserHomeDir() return filepath.Join(home, ".picoclaw", "auth.json") } diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 1de910c83..c3bcbff8d 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "os" + "regexp" "strings" "sync" "time" @@ -26,6 +27,12 @@ const ( sendTimeout = 10 * time.Second ) +var ( + // Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call) + channelRefRe = regexp.MustCompile(`<#(\d+)>`) + msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`) +) + type DiscordChannel struct { *channels.BaseChannel session *discordgo.Session @@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = c.stripBotMention(content) } + // Resolve Discord refs in main content before concatenation to avoid + // double-expanding links that appear in the referenced message. + content = c.resolveDiscordRefs(s, content, m.GuildID) + + // Prepend referenced (quoted) message content if this is a reply + if m.MessageReference != nil && m.ReferencedMessage != nil { + refContent := m.ReferencedMessage.Content + if refContent != "" { + refAuthor := "unknown" + if m.ReferencedMessage.Author != nil { + refAuthor = m.ReferencedMessage.Author.Username + } + refContent = c.resolveDiscordRefs(s, refContent, m.GuildID) + content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s", + refAuthor, refContent, content) + } + } + senderID := m.Author.ID mediaPaths := make([]string, 0, len(m.Attachments)) @@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { return nil } +// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and +// expands Discord message links to show the linked message content. +// Only links pointing to the same guild are expanded to prevent cross-guild leakage. +func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string { + // 1. Resolve channel references: <#id> → #channel-name + text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string { + parts := channelRefRe.FindStringSubmatch(match) + if len(parts) < 2 { + return match + } + // Prefer session state cache to avoid API calls + if ch, err := s.State.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + if ch, err := s.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + return match + }) + + // 2. Expand Discord message links (max 3, same guild only) + matches := msgLinkRe.FindAllStringSubmatch(text, 3) + for _, m := range matches { + if len(m) < 4 { + continue + } + linkGuildID, channelID, messageID := m[1], m[2], m[3] + // Security: only expand links from the same guild + if linkGuildID != guildID { + continue + } + msg, err := s.ChannelMessage(channelID, messageID) + if err != nil || msg == nil || msg.Content == "" { + continue + } + author := "unknown" + if msg.Author != nil { + author = msg.Author.Username + } + text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content) + } + + return text +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_resolve_test.go b/pkg/channels/discord/discord_resolve_test.go new file mode 100644 index 000000000..4bc65cc18 --- /dev/null +++ b/pkg/channels/discord/discord_resolve_test.go @@ -0,0 +1,98 @@ +package discord + +import ( + "testing" +) + +func TestChannelRefRegex(t *testing.T) { + tests := []struct { + name string + input string + wantID string + wantOK bool + }{ + {"basic channel ref", "<#123456789>", "123456789", true}, + {"long id", "<#9876543210123456>", "9876543210123456", true}, + {"no match plain text", "hello world", "", false}, + {"no match partial", "<#>", "", false}, + {"no match letters", "<#abc>", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := channelRefRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 2 || matches[1] != tt.wantID { + t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID) + } + } else { + if len(matches) >= 2 { + t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex(t *testing.T) { + tests := []struct { + name string + input string + wantGuild string + wantChan string + wantMsg string + wantOK bool + }{ + { + "discord.com link", + "https://discord.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "discordapp.com link", + "https://discordapp.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "real world ids", + "check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please", + "9000000000000001", "9000000000000002", "9000000000000003", true, + }, + {"no match http", "http://discord.com/channels/1/2/3", "", "", "", false}, + {"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false}, + {"no match plain text", "hello world", "", "", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := msgLinkRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 4 { + t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s", + tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg) + } + if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg { + t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s", + tt.input, matches[1], matches[2], matches[3], + tt.wantGuild, tt.wantChan, tt.wantMsg) + } + } else { + if len(matches) >= 4 { + t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex_MultipleMatches(t *testing.T) { + input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12" + matches := msgLinkRe.FindAllStringSubmatch(input, 3) + if len(matches) != 3 { + t.Fatalf("expected 3 matches (capped), got %d", len(matches)) + } + // Verify the 3rd match is 7/8/9 (not 10/11/12) + if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" { + t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2]) + } +} diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go index 74caeeac5..b3a493761 100644 --- a/pkg/channels/interfaces.go +++ b/pkg/channels/interfaces.go @@ -1,6 +1,10 @@ package channels -import "context" +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/commands" +) // TypingCapable — channels that can show a typing/thinking indicator. // StartTyping begins the indicator and returns a stop function. @@ -39,3 +43,10 @@ type PlaceholderRecorder interface { RecordTypingStop(channel, chatID string, stop func()) RecordReactionUndo(channel, chatID string, undo func()) } + +// CommandRegistrarCapable is implemented by channels that can register +// command menus with their upstream platform (e.g. Telegram BotCommand). +// Channels that do not support platform-level command menus can ignore it. +type CommandRegistrarCapable interface { + RegisterCommands(ctx context.Context, defs []commands.Definition) error +} diff --git a/pkg/channels/interfaces_command_test.go b/pkg/channels/interfaces_command_test.go new file mode 100644 index 000000000..de5502644 --- /dev/null +++ b/pkg/channels/interfaces_command_test.go @@ -0,0 +1,16 @@ +package channels + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/commands" +) + +type mockRegistrar struct{} + +func (mockRegistrar) RegisterCommands(context.Context, []commands.Definition) error { return nil } + +func TestCommandRegistrarCapable_Compiles(t *testing.T) { + var _ CommandRegistrarCapable = mockRegistrar{} +} diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 398f12e6b..b36350a06 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err)) + } return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } diff --git a/pkg/channels/telegram/command_registration.go b/pkg/channels/telegram/command_registration.go new file mode 100644 index 000000000..d3152ec3d --- /dev/null +++ b/pkg/channels/telegram/command_registration.go @@ -0,0 +1,116 @@ +package telegram + +import ( + "context" + "math/rand" + "slices" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/commands" + "github.com/sipeed/picoclaw/pkg/logger" +) + +var commandRegistrationBackoff = []time.Duration{ + 5 * time.Second, + 15 * time.Second, + 60 * time.Second, + 5 * time.Minute, + 10 * time.Minute, +} + +func commandRegistrationDelay(attempt int) time.Duration { + if len(commandRegistrationBackoff) == 0 { + return 0 + } + base := commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)] + // Full jitter in [0.5, 1.0) to avoid synchronized retries across instances. + return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5)) +} + +// RegisterCommands registers bot commands on Telegram platform. +func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error { + botCommands := make([]telego.BotCommand, 0, len(defs)) + for _, def := range defs { + if def.Name == "" || def.Description == "" { + continue + } + botCommands = append(botCommands, telego.BotCommand{ + Command: def.Name, + Description: def.Description, + }) + } + + current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{}) + if err != nil { + // If we can't read current commands, fall through to set them. + logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally", + map[string]any{"error": err.Error()}) + } else if slices.Equal(current, botCommands) { + logger.DebugCF("telegram", "Bot commands are up to date", nil) + return nil + } + + return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ + Commands: botCommands, + }) +} + +func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) { + if len(defs) == 0 { + return + } + + register := c.registerFunc + if register == nil { + register = c.RegisterCommands + } + + regCtx, cancel := context.WithCancel(ctx) + c.commandRegCancel = cancel + + // Registration runs asynchronously so Telegram message intake is never blocked + // by temporary upstream API failures. Retry stops on success or channel shutdown. + go func() { + attempt := 0 + timer := time.NewTimer(0) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + defer timer.Stop() + for { + err := register(regCtx, defs) + if err == nil { + logger.InfoCF("telegram", "Telegram commands registered", map[string]any{ + "count": len(defs), + }) + return + } + + delay := commandRegistrationDelay(attempt) + logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{ + "error": err.Error(), + "retry_after": delay.String(), + }) + attempt++ + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(delay) + + select { + case <-regCtx.Done(): + return + case <-timer.C: + } + } + }() +} diff --git a/pkg/channels/telegram/command_registration_test.go b/pkg/channels/telegram/command_registration_test.go new file mode 100644 index 000000000..26f891b2e --- /dev/null +++ b/pkg/channels/telegram/command_registration_test.go @@ -0,0 +1,96 @@ +package telegram + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/commands" +) + +func TestStartCommandRegistration_DoesNotBlock(t *testing.T) { + ch := &TelegramChannel{} + started := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch.registerFunc = func(context.Context, []commands.Definition) error { + started <- struct{}{} + return errors.New("temporary failure") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}}) + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("registration did not start asynchronously") + } +} + +func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = origBackoff }() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + n := attempts.Add(1) + if n < 3 { + return errors.New("temporary failure") + } + return nil + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + deadline := time.Now().Add(250 * time.Millisecond) + for time.Now().Before(deadline) { + if attempts.Load() >= 3 { + break + } + time.Sleep(5 * time.Millisecond) + } + if attempts.Load() < 3 { + t.Fatalf("expected at least 3 attempts, got %d", attempts.Load()) + } + + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load()) + } +} + +func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = origBackoff }() + defer cancel() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + attempts.Add(1) + return errors.New("always fail") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + time.Sleep(20 * time.Millisecond) + cancel() + time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load()) + } +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index f328f32b8..a2035853c 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "regexp" - "slices" "strconv" "strings" "time" @@ -18,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" @@ -40,13 +40,15 @@ var ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *th.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc + bot *telego.Bot + bh *th.BotHandler + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc + + registerFunc func(context.Context, []commands.Definition) error + commandRegCancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -93,7 +95,6 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return &TelegramChannel{ BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), bot: bot, config: cfg, chatIDs: make(map[string]int64), @@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) - if err := c.initBotCommands(c.ctx); err != nil { - logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{ - "error": err.Error(), - }) - } - updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) @@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } c.bh = bh - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Start(ctx, message) - }, th.CommandEqual("start")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Help(ctx, message) - }, th.CommandEqual("help")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Show(ctx, message) - }, th.CommandEqual("show")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.List(ctx, message) - }, th.CommandEqual("list")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) @@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { "username": c.bot.Username(), }) + c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions()) + go func() { if err = bh.Start(); err != nil { logger.ErrorCF("telegram", "Bot handler failed", map[string]any{ @@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } - - return nil -} - -func (c *TelegramChannel) initBotCommands(ctx context.Context) error { - currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{ - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("get commands: %w", err) - } - - commands := []telego.BotCommand{ - { - Command: "start", - Description: "Start the bot", - }, - { - Command: "help", - Description: "Show a help message", - }, - { - Command: "show", - Description: "Show current configuration", - }, - { - Command: "list", - Description: "List available options", - }, - } - - // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed - if !slices.Equal(currentCommands, commands) { - logger.InfoC("telegram", "Updating bot commands") - - err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ - Commands: commands, - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("set commands: %w", err) - } - } else { - logger.DebugC("telegram", "Bot commands are up to date") + if c.commandRegCancel != nil { + c.commandRegCancel() } return nil @@ -721,34 +661,34 @@ func escapeHTML(text string) string { // isBotMentioned checks if the bot is mentioned in the message via entities. func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { - botUsername := c.bot.Username() - if botUsername == "" { + text, entities := telegramEntityTextAndList(message) + if text == "" || len(entities) == 0 { return false } - entities := message.Entities - if entities == nil { - entities = message.CaptionEntities + botUsername := "" + if c.bot != nil { + botUsername = c.bot.Username() } + runes := []rune(text) for _, entity := range entities { - if entity.Type == "mention" { - // Extract the mention text from the message - text := message.Text - if text == "" { - text = message.Caption - } - runes := []rune(text) - end := entity.Offset + entity.Length - if end <= len(runes) { - mention := string(runes[entity.Offset:end]) - if strings.EqualFold(mention, "@"+botUsername) { - return true - } - } + entityText, ok := telegramEntityText(runes, entity) + if !ok { + continue } - if entity.Type == "text_mention" && entity.User != nil { - if entity.User.Username == botUsername { + + switch entity.Type { + case telego.EntityTypeMention: + if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) { + return true + } + case telego.EntityTypeTextMention: + if botUsername != "" && entity.User != nil && strings.EqualFold(entity.User.Username, botUsername) { + return true + } + case telego.EntityTypeBotCommand: + if isBotCommandEntityForThisBot(entityText, botUsername) { return true } } @@ -756,6 +696,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { return false } +func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) { + if message.Text != "" { + return message.Text, message.Entities + } + return message.Caption, message.CaptionEntities +} + +func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) { + if entity.Offset < 0 || entity.Length <= 0 { + return "", false + } + end := entity.Offset + entity.Length + if entity.Offset >= len(runes) || end > len(runes) { + return "", false + } + return string(runes[entity.Offset:end]), true +} + +func isBotCommandEntityForThisBot(entityText, botUsername string) bool { + if !strings.HasPrefix(entityText, "/") { + return false + } + command := strings.TrimPrefix(entityText, "/") + if command == "" { + return false + } + + at := strings.IndexRune(command, '@') + if at == -1 { + // A bare /command delivered to this bot is intended for this bot. + return true + } + + mentionUsername := command[at+1:] + if mentionUsername == "" || botUsername == "" { + return false + } + return strings.EqualFold(mentionUsername, botUsername) +} + // stripBotMention removes the @bot mention from the content. func (c *TelegramChannel) stripBotMention(content string) string { botUsername := c.bot.Username() diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go deleted file mode 100644 index 496fc5e4f..000000000 --- a/pkg/channels/telegram/telegram_commands.go +++ /dev/null @@ -1,156 +0,0 @@ -package telegram - -import ( - "context" - "fmt" - "strings" - - "github.com/mymmrac/telego" - - "github.com/sipeed/picoclaw/pkg/config" -) - -type TelegramCommander interface { - Help(ctx context.Context, message telego.Message) error - Start(ctx context.Context, message telego.Message) error - Show(ctx context.Context, message telego.Message) error - List(ctx context.Context, message telego.Message) error -} - -type cmd struct { - bot *telego.Bot - config *config.Config -} - -func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { - return &cmd{ - bot: bot, - config: cfg, - } -} - -func commandArgs(text string) string { - parts := strings.SplitN(text, " ", 2) - if len(parts) < 2 { - return "" - } - return strings.TrimSpace(parts[1]) -} - -func (c *cmd) Help(ctx context.Context, message telego.Message) error { - msg := `/start - Start the bot -/help - Show this help message -/show [model|channel] - Show current configuration -/list [models|channels] - List available options - ` - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: msg, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Start(ctx context.Context, message telego.Message) error { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Hello! I am PicoClaw 🦞", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Show(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /show [model|channel]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "model": - response = fmt.Sprintf("Current Model: %s (Provider: %s)", - c.config.Agents.Defaults.GetModelName(), - c.config.Agents.Defaults.Provider) - case "channel": - response = "Current Channel: telegram" - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) List(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /list [models|channels]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "models": - provider := c.config.Agents.Defaults.Provider - if provider == "" { - provider = "configured default" - } - response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", - c.config.Agents.Defaults.GetModelName(), provider) - - case "channels": - var enabled []string - if c.config.Channels.Telegram.Enabled { - enabled = append(enabled, "telegram") - } - if c.config.Channels.WhatsApp.Enabled { - enabled = append(enabled, "whatsapp") - } - if c.config.Channels.Feishu.Enabled { - enabled = append(enabled, "feishu") - } - if c.config.Channels.Discord.Enabled { - enabled = append(enabled, "discord") - } - if c.config.Channels.Slack.Enabled { - enabled = append(enabled, "slack") - } - response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) - - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} diff --git a/pkg/channels/telegram/telegram_dispatch_test.go b/pkg/channels/telegram/telegram_dispatch_test.go new file mode 100644 index 000000000..1ea4a4824 --- /dev/null +++ b/pkg/channels/telegram/telegram_dispatch_test.go @@ -0,0 +1,52 @@ +package telegram + +import ( + "context" + "testing" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" +) + +func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + + msg := &telego.Message{ + Text: "/new", + MessageID: 9, + Chat: telego.Chat{ + ID: 123, + Type: "private", + }, + From: &telego.User{ + ID: 42, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "telegram" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/telegram/telegram_group_command_filter_test.go b/pkg/channels/telegram/telegram_group_command_filter_test.go new file mode 100644 index 000000000..0d5b985fe --- /dev/null +++ b/pkg/channels/telegram/telegram_group_command_filter_test.go @@ -0,0 +1,147 @@ +package telegram + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/mymmrac/telego" + ta "github.com/mymmrac/telego/telegoapi" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +type getMeCaller struct { + username string +} + +func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) { + if strings.HasSuffix(url, "/getMe") { + result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username) + return &ta.Response{Ok: true, Result: []byte(result)}, nil + } + return &ta.Response{Ok: true, Result: []byte("true")}, nil +} + +func newTestTelegramBot(t *testing.T, username string) *telego.Bot { + t.Helper() + + token := "123456:" + strings.Repeat("a", 35) + bot, err := telego.NewBot(token, + telego.WithAPICaller(getMeCaller{username: username}), + telego.WithDiscardLogger(), + ) + if err != nil { + t.Fatalf("NewBot error: %v", err) + } + return bot +} + +func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) { + t.Helper() + + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil, + channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}), + ), + bot: newTestTelegramBot(t, botUsername), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + return ch, messageBus +} + +func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) { + tests := []struct { + name string + text string + wantForwarded bool + wantContent string + }{ + { + name: "command with bot username", + text: "/new@testbot", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "bare command", + text: "/new", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "command for another bot", + text: "/new@otherbot", + wantForwarded: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ch, messageBus := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: tc.text, + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeBotCommand, + Offset: 0, + Length: len([]rune(tc.text)), + }}, + MessageID: 42, + Chat: telego.Chat{ + ID: 123, + Type: "group", + }, + From: &telego.User{ + ID: 7, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if tc.wantForwarded { + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Content != tc.wantContent { + t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) + } + return + } + + if ok { + t.Fatalf("expected message to be filtered, got content=%q", inbound.Content) + } + }) + } +} + +func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) { + ch, _ := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: "@testbot hello", + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeMention, + Offset: 0, + Length: len("@testbot"), + }}, + } + + if !ch.isBotMentioned(msg) { + t.Fatal("expected mention entity to be treated as bot mention") + } +} diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go index 6c5aca40b..93fe8c36d 100644 --- a/pkg/channels/wecom/aibot.go +++ b/pkg/channels/wecom/aibot.go @@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro return nil } - respBody, _ := io.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err) + } switch { case resp.StatusCode == http.StatusTooManyRequests: return fmt.Errorf("response_url rate limited (%d): %s: %w", diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 717815b9f..2098fcd4e 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody))) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return "", channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading wecom upload error response: %w", readErr), + ) + } + return "", channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("wecom upload error: %s", string(respBody)), + ) } var result struct { @@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody))) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading wecom_app error response: %w", readErr), + ) + } + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("wecom_app API error: %s", string(respBody)), + ) } respBody, err := io.ReadAll(resp.Body) diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 9126a847d..96d5a961f 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading webhook error response: %w", readErr), + ) + } + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("webhook API error: %s", string(body)), + ) } body, err := io.ReadAll(resp.Body) diff --git a/pkg/channels/whatsapp/whatsapp_command_test.go b/pkg/channels/whatsapp/whatsapp_command_test.go new file mode 100644 index 000000000..ee8aa4a52 --- /dev/null +++ b/pkg/channels/whatsapp/whatsapp_command_test.go @@ -0,0 +1,41 @@ +package whatsapp + +import ( + "context" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil), + ctx: context.Background(), + } + + ch.handleIncomingMessage(map[string]any{ + "type": "message", + "id": "mid1", + "from": "user1", + "chat": "chat1", + "content": "/help", + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/help" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/whatsapp_native/whatsapp_command_test.go b/pkg/channels/whatsapp_native/whatsapp_command_test.go new file mode 100644 index 000000000..cc2dcb619 --- /dev/null +++ b/pkg/channels/whatsapp_native/whatsapp_command_test.go @@ -0,0 +1,56 @@ +//go:build whatsapp_native + +package whatsapp + +import ( + "context" + "testing" + "time" + + "go.mau.fi/whatsmeow/proto/waE2E" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + "google.golang.org/protobuf/proto" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppNativeChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil), + runCtx: context.Background(), + } + + evt := &events.Message{ + Info: types.MessageInfo{ + MessageSource: types.MessageSource{ + Sender: types.NewJID("1001", types.DefaultUserServer), + Chat: types.NewJID("1001", types.DefaultUserServer), + }, + ID: "mid1", + PushName: "Alice", + }, + Message: &waE2E.Message{ + Conversation: proto.String("/new"), + }, + } + + ch.handleIncoming(evt) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp_native" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go new file mode 100644 index 000000000..a36dd3eba --- /dev/null +++ b/pkg/commands/builtin.go @@ -0,0 +1,16 @@ +package commands + +// BuiltinDefinitions returns all built-in command definitions. +// Each command group is defined in its own cmd_*.go file. +// Definitions are stateless — runtime dependencies are provided +// via the Runtime parameter passed to handlers at execution time. +func BuiltinDefinitions() []Definition { + return []Definition{ + startCommand(), + helpCommand(), + showCommand(), + listCommand(), + switchCommand(), + checkCommand(), + } +} diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go new file mode 100644 index 000000000..66a84825e --- /dev/null +++ b/pkg/commands/builtin_test.go @@ -0,0 +1,145 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition { + t.Helper() + for _, def := range defs { + if def.Name == name { + return def + } + } + t.Fatalf("missing /%s definition", name) + return Definition{} +} + +func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { + defs := BuiltinDefinitions() + helpDef := findDefinitionByName(t, defs, "help") + if helpDef.Handler == nil { + t.Fatalf("/help handler should not be nil") + } + + var reply string + err := helpDef.Handler(context.Background(), Request{ + Text: "/help", + Reply: func(text string) error { + reply = text + return nil + }, + }, nil) + if err != nil { + t.Fatalf("/help handler error: %v", err) + } + // Now uses auto-generated EffectiveUsage which includes agents + if !strings.Contains(reply, "/show [model|channel|agents]") { + t.Fatalf("/help reply missing /show usage, got %q", reply) + } + if !strings.Contains(reply, "/list [models|channels|agents]") { + t.Fatalf("/help reply missing /list usage, got %q", reply) + } +} + +func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), nil) + + cases := []string{"telegram", "whatsapp"} + for _, channel := range cases { + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: channel, + Text: "/show channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, res.Outcome, OutcomeHandled) + } + want := "Current Channel: " + channel + if reply != want { + t.Fatalf("/show channel reply=%q, want=%q", reply, want) + } + } +} + +func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram", "slack"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") { + t.Fatalf("/list channels reply=%q, want telegram and slack", reply) + } +} + +func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/show agents reply=%q, want agent IDs", reply) + } +} + +func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/list agents reply=%q, want agent IDs", reply) + } +} diff --git a/pkg/commands/cmd_check.go b/pkg/commands/cmd_check.go new file mode 100644 index 000000000..f0193dc4f --- /dev/null +++ b/pkg/commands/cmd_check.go @@ -0,0 +1,33 @@ +package commands + +import ( + "context" + "fmt" +) + +func checkCommand() Definition { + return Definition{ + Name: "check", + Description: "Check channel availability", + SubCommands: []SubCommand{ + { + Name: "channel", + Description: "Check if a channel is available", + ArgsUsage: "", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchChannel == nil { + return req.Reply(unavailableMsg) + } + value := nthToken(req.Text, 2) + if value == "" { + return req.Reply("Usage: /check channel ") + } + if err := rt.SwitchChannel(value); err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value)) + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_help.go b/pkg/commands/cmd_help.go new file mode 100644 index 000000000..94f7f0101 --- /dev/null +++ b/pkg/commands/cmd_help.go @@ -0,0 +1,44 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func helpCommand() Definition { + return Definition{ + Name: "help", + Description: "Show this help message", + Usage: "/help", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + var defs []Definition + if rt != nil && rt.ListDefinitions != nil { + defs = rt.ListDefinitions() + } else { + defs = BuiltinDefinitions() + } + return req.Reply(formatHelpMessage(defs)) + }, + } +} + +func formatHelpMessage(defs []Definition) string { + if len(defs) == 0 { + return "No commands available." + } + + lines := make([]string, 0, len(defs)) + for _, def := range defs { + usage := def.EffectiveUsage() + if usage == "" { + usage = "/" + def.Name + } + desc := def.Description + if desc == "" { + desc = "No description" + } + lines = append(lines, fmt.Sprintf("%s - %s", usage, desc)) + } + return strings.Join(lines, "\n") +} diff --git a/pkg/commands/cmd_list.go b/pkg/commands/cmd_list.go new file mode 100644 index 000000000..bf47b6e9c --- /dev/null +++ b/pkg/commands/cmd_list.go @@ -0,0 +1,52 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func listCommand() Definition { + return Definition{ + Name: "list", + Description: "List available options", + SubCommands: []SubCommand{ + { + Name: "models", + Description: "Configured models", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + if provider == "" { + provider = "configured default" + } + return req.Reply(fmt.Sprintf( + "Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", + name, provider, + )) + }, + }, + { + Name: "channels", + Description: "Enabled channels", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetEnabledChannels == nil { + return req.Reply(unavailableMsg) + } + enabled := rt.GetEnabledChannels() + if len(enabled) == 0 { + return req.Reply("No channels enabled") + } + return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_show.go b/pkg/commands/cmd_show.go new file mode 100644 index 000000000..c655e6880 --- /dev/null +++ b/pkg/commands/cmd_show.go @@ -0,0 +1,38 @@ +package commands + +import ( + "context" + "fmt" +) + +func showCommand() Definition { + return Definition{ + Name: "show", + Description: "Show current configuration", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Current model and provider", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider)) + }, + }, + { + Name: "channel", + Description: "Current channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel)) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_start.go b/pkg/commands/cmd_start.go new file mode 100644 index 000000000..8b500aa10 --- /dev/null +++ b/pkg/commands/cmd_start.go @@ -0,0 +1,14 @@ +package commands + +import "context" + +func startCommand() Definition { + return Definition{ + Name: "start", + Description: "Start the bot", + Usage: "/start", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("Hello! I am PicoClaw 🦞") + }, + } +} diff --git a/pkg/commands/cmd_switch.go b/pkg/commands/cmd_switch.go new file mode 100644 index 000000000..fb8fc109e --- /dev/null +++ b/pkg/commands/cmd_switch.go @@ -0,0 +1,42 @@ +package commands + +import ( + "context" + "fmt" +) + +func switchCommand() Definition { + return Definition{ + Name: "switch", + Description: "Switch model", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Switch to a different model", + ArgsUsage: "to ", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchModel == nil { + return req.Reply(unavailableMsg) + } + // Parse: /switch model to + value := nthToken(req.Text, 3) // tokens: [/switch, model, to, ] + if nthToken(req.Text, 2) != "to" || value == "" { + return req.Reply("Usage: /switch model to ") + } + oldModel, err := rt.SwitchModel(value) + if err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value)) + }, + }, + { + Name: "channel", + Description: "Moved to /check channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("This command has moved. Please use: /check channel ") + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_switch_test.go b/pkg/commands/cmd_switch_test.go new file mode 100644 index 000000000..59ed305bb --- /dev/null +++ b/pkg/commands/cmd_switch_test.go @@ -0,0 +1,279 @@ +package commands + +import ( + "context" + "fmt" + "testing" +) + +func TestSwitchModel_Success(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old-model", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Switched model from old-model to gpt-4" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestSwitchModel_MissingToKeyword(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_Error(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "", fmt.Errorf("model not found") + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to bad-model", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "model not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestSwitchModel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." { + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestSwitchChannel_Redirect(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch channel to telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "This command has moved. Please use: /check channel " + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Success(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Channel 'telegram' is available and enabled" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Error(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return fmt.Errorf("channel '%s' not found", value) + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel unknown", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "channel 'unknown' not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestCheckChannel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." { + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestCheckChannel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /check channel " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitch_BangPrefix(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "!switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Switched model from old to gpt-4" { + t.Fatalf("! prefix: reply=%q, want success message", reply) + } +} + +func TestSwitch_NoSubCommand(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + // Should get usage message from executor's sub-command routing + if reply == "" { + t.Fatal("expected usage reply for bare /switch") + } +} diff --git a/pkg/commands/definition.go b/pkg/commands/definition.go new file mode 100644 index 000000000..7309df317 --- /dev/null +++ b/pkg/commands/definition.go @@ -0,0 +1,48 @@ +package commands + +import ( + "fmt" + "strings" +) + +// SubCommand defines a single sub-command within a parent command. +type SubCommand struct { + Name string + Description string + ArgsUsage string // optional, e.g. "" + Handler Handler +} + +// Definition is the single-source metadata and behavior contract for a slash command. +// +// Design notes (phase 1): +// - Every channel reads command shape from this type instead of keeping local copies. +// - Visibility is global: all definitions are considered available to all channels. +// - Platform menu registration (for example Telegram BotCommand) also derives from this +// same definition so UI labels and runtime behavior stay aligned. +type Definition struct { + Name string + Description string + Usage string // for simple commands; ignored when SubCommands is set + Aliases []string + SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers + Handler Handler // for simple commands without sub-commands +} + +// EffectiveUsage returns the usage string. When SubCommands are present, +// it is auto-generated from sub-command names so metadata and behavior +// cannot drift. +func (d Definition) EffectiveUsage() string { + if len(d.SubCommands) == 0 { + return d.Usage + } + names := make([]string, 0, len(d.SubCommands)) + for _, sc := range d.SubCommands { + name := sc.Name + if sc.ArgsUsage != "" { + name += " " + sc.ArgsUsage + } + names = append(names, name) + } + return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|")) +} diff --git a/pkg/commands/definition_test.go b/pkg/commands/definition_test.go new file mode 100644 index 000000000..27ad4a0a2 --- /dev/null +++ b/pkg/commands/definition_test.go @@ -0,0 +1,41 @@ +package commands + +import ( + "testing" +) + +func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) { + d := Definition{Name: "start", Usage: "/start"} + if got := d.EffectiveUsage(); got != "/start" { + t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start") + } +} + +func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) { + d := Definition{ + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + {Name: "agents"}, + }, + } + want := "/show [model|channel|agents]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} + +func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) { + d := Definition{ + Name: "session", + SubCommands: []SubCommand{ + {Name: "list"}, + {Name: "resume", ArgsUsage: ""}, + }, + } + want := "/session [list|resume ]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} diff --git a/pkg/commands/executor.go b/pkg/commands/executor.go new file mode 100644 index 000000000..78a50e6c2 --- /dev/null +++ b/pkg/commands/executor.go @@ -0,0 +1,89 @@ +package commands + +import ( + "context" + "fmt" +) + +type Outcome int + +const ( + // OutcomePassthrough means this input should continue through normal agent flow. + OutcomePassthrough Outcome = iota + // OutcomeHandled means a command handler executed (with or without handler error). + OutcomeHandled +) + +type ExecuteResult struct { + Outcome Outcome + Command string + Err error +} + +type Executor struct { + reg *Registry + rt *Runtime +} + +func NewExecutor(reg *Registry, rt *Runtime) *Executor { + return &Executor{reg: reg, rt: rt} +} + +// Execute implements a two-state command decision: +// 1) handled: execute command immediately; +// 2) passthrough: not a command or intentionally deferred to agent logic. +func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult { + cmdName, ok := parseCommandName(req.Text) + if !ok { + return ExecuteResult{Outcome: OutcomePassthrough} + } + + if e == nil || e.reg == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + def, found := e.reg.Lookup(cmdName) + if !found { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + return e.executeDefinition(ctx, req, def) +} + +func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult { + // Ensure Reply is always non-nil so handlers don't need to check. + if req.Reply == nil { + req.Reply = func(string) error { return nil } + } + + // Simple command — no sub-commands + if len(def.SubCommands) == 0 { + if def.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := def.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + // Sub-command routing + subName := nthToken(req.Text, 1) + if subName == "" { + err := req.Reply("Usage: " + def.EffectiveUsage()) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + normalized := normalizeCommandName(subName) + for _, sc := range def.SubCommands { + if normalizeCommandName(sc.Name) == normalized { + if sc.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := sc.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + } + + // Unknown sub-command + err := req.Reply(fmt.Sprintf("Unknown option: %s. Usage: %s", subName, def.EffectiveUsage())) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} +} diff --git a/pkg/commands/executor_test.go b/pkg/commands/executor_test.go new file mode 100644 index 000000000..09350f1b6 --- /dev/null +++ b/pkg/commands/executor_test.go @@ -0,0 +1,260 @@ +package commands + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + Aliases: []string{"display"}, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "show" { + t.Fatalf("command=%q, want=%q", res.Command, "show") + } +} + +func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "clear", + Aliases: []string{"reset"}, + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "clear" { + t.Fatalf("command=%q, want=%q", res.Command, "clear") + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) { + // With Lookup-based dispatch, the first registered definition for a name wins. + // A definition with nil Handler and no SubCommands returns Passthrough. + defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_HandlerErrorIsPropagated(t *testing.T) { + wantErr := errors.New("handler failed") + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + return wantErr + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !errors.Is(res.Err, wantErr) { + t.Fatalf("err=%v, want=%v", res.Err, wantErr) + } +} + +func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) { + modelCalled := false + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error { + modelCalled = true + return nil + }}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !modelCalled { + t.Fatal("model sub-command handler was not called") + } +} + +func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /show [model|channel]" { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show foobar", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "foobar") { + t.Fatalf("reply=%q, should mention unknown sub-command", reply) + } +} + +func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, // nil Handler + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} diff --git a/pkg/commands/handler_agents.go b/pkg/commands/handler_agents.go new file mode 100644 index 000000000..c459516eb --- /dev/null +++ b/pkg/commands/handler_agents.go @@ -0,0 +1,21 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +// agentsHandler returns a shared handler for both /show agents and /list agents. +func agentsHandler() Handler { + return func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.ListAgentIDs == nil { + return req.Reply(unavailableMsg) + } + ids := rt.ListAgentIDs() + if len(ids) == 0 { + return req.Reply("No agents registered") + } + return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", "))) + } +} diff --git a/pkg/commands/registry.go b/pkg/commands/registry.go new file mode 100644 index 000000000..e17d489a6 --- /dev/null +++ b/pkg/commands/registry.go @@ -0,0 +1,55 @@ +package commands + +type Registry struct { + defs []Definition + index map[string]int +} + +// NewRegistry stores the canonical command set used by both dispatch and +// optional platform registration adapters. +func NewRegistry(defs []Definition) *Registry { + stored := make([]Definition, len(defs)) + copy(stored, defs) + + index := make(map[string]int, len(stored)*2) + for i, def := range stored { + registerCommandName(index, def.Name, i) + for _, alias := range def.Aliases { + registerCommandName(index, alias, i) + } + } + + return &Registry{defs: stored, index: index} +} + +// Definitions returns all registered command definitions. +// Command availability is global and no longer channel-scoped. +func (r *Registry) Definitions() []Definition { + out := make([]Definition, len(r.defs)) + copy(out, r.defs) + return out +} + +// Lookup returns a command definition by normalized command name or alias. +func (r *Registry) Lookup(name string) (Definition, bool) { + key := normalizeCommandName(name) + if key == "" { + return Definition{}, false + } + idx, ok := r.index[key] + if !ok { + return Definition{}, false + } + return r.defs[idx], true +} + +func registerCommandName(index map[string]int, name string, defIndex int) { + key := normalizeCommandName(name) + if key == "" { + return + } + if _, exists := index[key]; exists { + return + } + index[key] = defIndex +} diff --git a/pkg/commands/registry_test.go b/pkg/commands/registry_test.go new file mode 100644 index 000000000..bfff76b7c --- /dev/null +++ b/pkg/commands/registry_test.go @@ -0,0 +1,49 @@ +package commands + +import "testing" + +func TestRegistry_Definitions_ReturnsCopy(t *testing.T) { + defs := []Definition{ + {Name: "help", Description: "Show help"}, + {Name: "admin", Description: "Admin command"}, + } + r := NewRegistry(defs) + + got := r.Definitions() + if len(got) != 2 { + t.Fatalf("definitions len = %d, want 2", len(got)) + } + + got[0].Name = "mutated" + again := r.Definitions() + if again[0].Name != "help" { + t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name) + } +} + +func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) { + r := NewRegistry([]Definition{ + {Name: "Help", Aliases: []string{"Assist"}}, + {Name: "List"}, + }) + + def, ok := r.Lookup("help") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("HELP") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("assist") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("ASSIST") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def) + } +} diff --git a/pkg/commands/request.go b/pkg/commands/request.go new file mode 100644 index 000000000..62ee600f2 --- /dev/null +++ b/pkg/commands/request.go @@ -0,0 +1,75 @@ +package commands + +import ( + "context" + "strings" +) + +type Handler func(ctx context.Context, req Request, rt *Runtime) error + +type Request struct { + Channel string + ChatID string + SenderID string + Text string + Reply func(text string) error +} + +const unavailableMsg = "Command unavailable in current context." + +var commandPrefixes = []string{"/", "!"} + +// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then +// normalizes to lowercase command names. +func parseCommandName(input string) (string, bool) { + token := nthToken(input, 0) + if token == "" { + return "", false + } + + name, ok := trimCommandPrefix(token) + if !ok { + return "", false + } + if i := strings.Index(name, "@"); i >= 0 { + name = name[:i] + } + name = normalizeCommandName(name) + if name == "" { + return "", false + } + return name, true +} + +func trimCommandPrefix(token string) (string, bool) { + for _, prefix := range commandPrefixes { + if strings.HasPrefix(token, prefix) { + return strings.TrimPrefix(token, prefix), true + } + } + return "", false +} + +// HasCommandPrefix returns true if the input starts with a recognized +// command prefix (e.g. "/" or "!"). +func HasCommandPrefix(input string) bool { + token := nthToken(input, 0) + if token == "" { + return false + } + _, ok := trimCommandPrefix(token) + return ok +} + +// nthToken returns the 0-indexed token from whitespace-split input. +func nthToken(input string, n int) string { + parts := strings.Fields(strings.TrimSpace(input)) + if n >= len(parts) { + return "" + } + return parts[n] +} + +func normalizeCommandName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) +} diff --git a/pkg/commands/request_test.go b/pkg/commands/request_test.go new file mode 100644 index 000000000..4389e453b --- /dev/null +++ b/pkg/commands/request_test.go @@ -0,0 +1,28 @@ +package commands + +import "testing" + +func TestHasCommandPrefix(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"/help", true}, + {"!help", true}, + {"/switch model to gpt-4", true}, + {"!switch model to gpt-4", true}, + {"hello", false}, + {"", false}, + {" ", false}, + {"hello /world", false}, + {"/", true}, + {"!", true}, + {" /help", true}, + } + for _, tt := range tests { + got := HasCommandPrefix(tt.input) + if got != tt.want { + t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go new file mode 100644 index 000000000..227d495f4 --- /dev/null +++ b/pkg/commands/runtime.go @@ -0,0 +1,16 @@ +package commands + +import "github.com/sipeed/picoclaw/pkg/config" + +// Runtime provides runtime dependencies to command handlers. It is constructed +// per-request by the agent loop so that per-request state (like session scope) +// can coexist with long-lived callbacks (like GetModelInfo). +type Runtime struct { + Config *config.Config + GetModelInfo func() (name, provider string) + ListAgentIDs func() []string + ListDefinitions func() []Definition + GetEnabledChannels func() []string + SwitchModel func(value string) (oldModel string, err error) + SwitchChannel func(value string) error +} diff --git a/pkg/commands/show_list_handlers_test.go b/pkg/commands/show_list_handlers_test.go new file mode 100644 index 000000000..047708f0f --- /dev/null +++ b/pkg/commands/show_list_handlers_test.go @@ -0,0 +1,85 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func TestShowListHandlers_ChannelPolicy(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil) + + var telegramReply string + handled := ex.Execute(context.Background(), Request{ + Channel: "telegram", + Text: "/show channel", + Reply: func(text string) error { + telegramReply = text + return nil + }, + }) + if handled.Outcome != OutcomeHandled { + t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled) + } + if telegramReply != "Current Channel: telegram" { + t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram") + } + + var whatsappReply string + handledWhatsApp := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/show channel", + Reply: func(text string) error { + whatsappReply = text + return nil + }, + }) + if handledWhatsApp.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled) + } + if handledWhatsApp.Command != "show" { + t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show") + } + if whatsappReply != "Current Channel: whatsapp" { + t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp") + } + + passthrough := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/foo", + }) + if passthrough.Outcome != OutcomePassthrough { + t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough) + } + if passthrough.Command != "foo" { + t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo") + } +} + +func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram"} + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "list" { + t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list") + } + if !strings.Contains(reply, "telegram") { + t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 3cfebf5e8..cff81a3a7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -167,22 +167,35 @@ type SessionConfig struct { IdentityLinks map[string][]string `json:"identity_links,omitempty"` } +// RoutingConfig controls the intelligent model routing feature. +// When enabled, each incoming message is scored against structural features +// (message length, code blocks, tool call history, conversation depth, attachments). +// Messages scoring below Threshold are sent to LightModel; all others use the +// agent's primary model. This reduces cost and latency for simple tasks without +// requiring any keyword matching — all scoring is language-agnostic. +type RoutingConfig struct { + Enabled bool `json:"enabled"` + LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks + Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model +} + type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead - ModelFallbacks []string `json:"model_fallbacks,omitempty"` - ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` - ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` - SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` - SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` - MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Routing *RoutingConfig `json:"routing,omitempty"` } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB @@ -526,6 +539,10 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } +type ToolConfig struct { + Enabled bool `json:"enabled" env:"ENABLED"` +} + type BraveConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` @@ -550,6 +567,12 @@ type PerplexityConfig struct { MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } +type SearXNGConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"` +} + type GLMSearchConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"` APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"` @@ -561,11 +584,13 @@ type GLMSearchConfig struct { } type WebToolsConfig struct { - Brave BraveConfig `json:"brave"` - Tavily TavilyConfig `json:"tavily"` - DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` - Perplexity PerplexityConfig `json:"perplexity"` - GLMSearch GLMSearchConfig `json:"glm_search"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"` + Brave BraveConfig ` json:"brave"` + Tavily TavilyConfig ` json:"tavily"` + DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"` + Perplexity PerplexityConfig ` json:"perplexity"` + SearXNG SearXNGConfig ` json:"searxng"` + GLMSearch GLMSearchConfig ` json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` @@ -573,19 +598,29 @@ type WebToolsConfig struct { } type CronToolsConfig struct { - ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"` + ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout } type ExecConfig struct { - EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` - CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` - CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"` + EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` + CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` + CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` + TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s) +} + +type SkillsToolsConfig struct { + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"` + Registries SkillsRegistriesConfig ` json:"registries"` + MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"` + SearchCache SearchCacheConfig ` json:"search_cache"` } type MediaCleanupConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"` - MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"` - Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"` + ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"` + MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"` + Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"` } type ToolsConfig struct { @@ -597,12 +632,19 @@ type ToolsConfig struct { Skills SkillsToolsConfig `json:"skills"` MediaCleanup MediaCleanupConfig `json:"media_cleanup"` MCP MCPConfig `json:"mcp"` -} - -type SkillsToolsConfig struct { - Registries SkillsRegistriesConfig `json:"registries"` - MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"` - SearchCache SearchCacheConfig `json:"search_cache"` + AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"` + EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"` + FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"` + I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"` + InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"` + ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` + Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` + ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` + Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` + SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"` + Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"` + WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"` + WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"` } type SearchCacheConfig struct { @@ -648,8 +690,7 @@ type MCPServerConfig struct { // MCPConfig defines configuration for all MCP servers type MCPConfig struct { - // Enabled globally enables/disables MCP integration - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"` // Servers is a map of server name to server configuration Servers map[string]MCPServerConfig `json:"servers,omitempty"` } @@ -835,3 +876,48 @@ func (c *Config) ValidateModelList() error { } return nil } + +func (t *ToolsConfig) IsToolEnabled(name string) bool { + switch name { + case "web": + return t.Web.Enabled + case "cron": + return t.Cron.Enabled + case "exec": + return t.Exec.Enabled + case "skills": + return t.Skills.Enabled + case "media_cleanup": + return t.MediaCleanup.Enabled + case "append_file": + return t.AppendFile.Enabled + case "edit_file": + return t.EditFile.Enabled + case "find_skills": + return t.FindSkills.Enabled + case "i2c": + return t.I2C.Enabled + case "install_skill": + return t.InstallSkill.Enabled + case "list_dir": + return t.ListDir.Enabled + case "message": + return t.Message.Enabled + case "read_file": + return t.ReadFile.Enabled + case "spawn": + return t.Spawn.Enabled + case "spi": + return t.SPI.Enabled + case "subagent": + return t.Subagent.Enabled + case "web_fetch": + return t.WebFetch.Enabled + case "write_file": + return t.WriteFile.Enabled + case "mcp": + return t.MCP.Enabled + default: + return true + } +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 84fc60435..c4c04d41a 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -336,11 +336,16 @@ func DefaultConfig() *Config { }, Tools: ToolsConfig{ MediaCleanup: MediaCleanupConfig{ - Enabled: true, + ToolConfig: ToolConfig{ + Enabled: true, + }, MaxAge: 30, Interval: 5, }, Web: WebToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Proxy: "", FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default Brave: BraveConfig{ @@ -357,6 +362,11 @@ func DefaultConfig() *Config { APIKey: "", MaxResults: 5, }, + SearXNG: SearXNGConfig{ + Enabled: false, + BaseURL: "", + MaxResults: 5, + }, GLMSearch: GLMSearchConfig{ Enabled: false, APIKey: "", @@ -366,12 +376,22 @@ func DefaultConfig() *Config { }, }, Cron: CronToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, ExecTimeoutMinutes: 5, }, Exec: ExecConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, EnableDenyPatterns: true, + TimeoutSeconds: 60, }, Skills: SkillsToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Registries: SkillsRegistriesConfig{ ClawHub: ClawHubRegistryConfig{ Enabled: true, @@ -385,9 +405,50 @@ func DefaultConfig() *Config { }, }, MCP: MCPConfig{ - Enabled: false, + ToolConfig: ToolConfig{ + Enabled: false, + }, Servers: map[string]MCPServerConfig{}, }, + AppendFile: ToolConfig{ + Enabled: true, + }, + EditFile: ToolConfig{ + Enabled: true, + }, + FindSkills: ToolConfig{ + Enabled: true, + }, + I2C: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + InstallSkill: ToolConfig{ + Enabled: true, + }, + ListDir: ToolConfig{ + Enabled: true, + }, + Message: ToolConfig{ + Enabled: true, + }, + ReadFile: ToolConfig{ + Enabled: true, + }, + Spawn: ToolConfig{ + Enabled: true, + }, + SPI: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + Subagent: ToolConfig{ + Enabled: true, + }, + WebFetch: ToolConfig{ + Enabled: true, + }, + WriteFile: ToolConfig{ + Enabled: true, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index 8ce81d09e..f353942ab 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { mgr := NewManager() mcpCfg := config.MCPConfig{ - Enabled: true, + ToolConfig: config.ToolConfig{ + Enabled: true, + }, Servers: map[string]config.MCPServerConfig{ "test-server": { Enabled: true, @@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) { func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { mgr := NewManager() - err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp") + err := mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error when MCP disabled, got: %v", err) } - err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp") + err = mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error when no servers configured, got: %v", err) } diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go index d4ee528b7..8a1890212 100644 --- a/pkg/providers/antigravity_provider.go +++ b/pkg/providers/antigravity_provider.go @@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading loadCodeAssist response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("loadCodeAssist failed: %s", string(body)) } @@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf( "fetchAvailableModels failed (HTTP %d): %s", diff --git a/pkg/routing/classifier.go b/pkg/routing/classifier.go new file mode 100644 index 000000000..8cddaf069 --- /dev/null +++ b/pkg/routing/classifier.go @@ -0,0 +1,80 @@ +package routing + +// Classifier evaluates a feature set and returns a complexity score in [0, 1]. +// A higher score indicates a more complex task that benefits from a heavy model. +// The score is compared against the configured threshold: score >= threshold selects +// the primary (heavy) model; score < threshold selects the light model. +// +// Classifier is an interface so that future implementations (ML-based, embedding-based, +// or any other approach) can be swapped in without changing routing infrastructure. +type Classifier interface { + Score(f Features) float64 +} + +// RuleClassifier is the v1 implementation. +// It uses a weighted sum of structural signals with no external dependencies, +// no API calls, and sub-microsecond latency. The raw sum is capped at 1.0 so +// that the returned score always falls within the [0, 1] contract. +// +// Individual weights (multiple signals can fire simultaneously): +// +// token > 200 (≈600 chars): 0.35 — very long prompts are almost always complex +// token 50-200: 0.15 — medium length; may or may not be complex +// code block present: 0.40 — coding tasks need the heavy model +// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow +// tool calls 1-3 (recent): 0.10 — some tool activity +// conversation depth > 10: 0.10 — long sessions carry implicit complexity +// attachments present: 1.00 — hard gate; multi-modal always needs heavy model +// +// Default threshold is 0.35, so: +// - Pure greetings / trivial Q&A: 0.00 → light ✓ +// - Medium prose message (50–200 tokens): 0.15 → light ✓ +// - Message with code block: 0.40 → heavy ✓ +// - Long message (>200 tokens): 0.35 → heavy ✓ +// - Active tool session + medium message: 0.25 → light (acceptable) +// - Any message with an image/audio attachment: 1.00 → heavy ✓ +type RuleClassifier struct{} + +// Score computes the complexity score for the given feature set. +// The returned value is in [0, 1]. Attachments short-circuit to 1.0. +func (c *RuleClassifier) Score(f Features) float64 { + // Hard gate: multi-modal inputs always require the heavy model. + if f.HasAttachments { + return 1.0 + } + + var score float64 + + // Token estimate — primary verbosity signal + switch { + case f.TokenEstimate > 200: + score += 0.35 + case f.TokenEstimate > 50: + score += 0.15 + } + + // Fenced code blocks — strongest indicator of a coding/technical task + if f.CodeBlockCount > 0 { + score += 0.40 + } + + // Recent tool call density — indicates an ongoing agentic workflow + switch { + case f.RecentToolCalls > 3: + score += 0.25 + case f.RecentToolCalls > 0: + score += 0.10 + } + + // Conversation depth — accumulated context implies compound task + if f.ConversationDepth > 10 { + score += 0.10 + } + + // Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire + // simultaneously (e.g., long message + code block + tool chain = 1.10 raw). + if score > 1.0 { + score = 1.0 + } + return score +} diff --git a/pkg/routing/features.go b/pkg/routing/features.go new file mode 100644 index 000000000..c371e21aa --- /dev/null +++ b/pkg/routing/features.go @@ -0,0 +1,127 @@ +package routing + +import ( + "strings" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// lookbackWindow is the number of recent history entries scanned for tool calls. +// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant). +const lookbackWindow = 6 + +// Features holds the structural signals extracted from a message and its session context. +// Every dimension is language-agnostic by construction — no keyword or pattern matching +// against natural-language content. This ensures consistent routing for all locales. +type Features struct { + // TokenEstimate is a proxy for token count. + // CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each. + // This avoids API calls while giving accurate estimates for all scripts. + TokenEstimate int + + // CodeBlockCount is the number of fenced code blocks (``` pairs) in the message. + // Coding tasks almost always require the heavy model. + CodeBlockCount int + + // RecentToolCalls is the count of tool_call messages in the last lookbackWindow + // history entries. A high density indicates an active agentic workflow. + RecentToolCalls int + + // ConversationDepth is the total number of messages in the session history. + // Deep sessions tend to carry implicit complexity built up over many turns. + ConversationDepth int + + // HasAttachments is true when the message appears to contain media (images, + // audio, video). Multi-modal inputs require vision-capable heavy models. + HasAttachments bool +} + +// ExtractFeatures computes the structural feature vector for a message. +// It is a pure function with no side effects and zero allocations beyond +// the returned struct. +func ExtractFeatures(msg string, history []providers.Message) Features { + return Features{ + TokenEstimate: estimateTokens(msg), + CodeBlockCount: countCodeBlocks(msg), + RecentToolCalls: countRecentToolCalls(history), + ConversationDepth: len(history), + HasAttachments: hasAttachments(msg), + } +} + +// estimateTokens returns a token count proxy that handles both CJK and Latin text. +// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one +// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token +// for English). Splitting the count this way avoids the 3x underestimation that a +// flat rune_count/3 would produce for Chinese, Japanese, and Korean text. +func estimateTokens(msg string) int { + total := utf8.RuneCountInString(msg) + if total == 0 { + return 0 + } + cjk := 0 + for _, r := range msg { + if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF { + cjk++ + } + } + return cjk + (total-cjk)/4 +} + +// countCodeBlocks counts the number of complete fenced code blocks. +// Each ``` delimiter increments a counter; pairs of delimiters form one block. +// An unclosed opening fence (odd count) is treated as zero complete blocks +// since it may just be an inline code span or a typo. +func countCodeBlocks(msg string) int { + n := strings.Count(msg, "```") + return n / 2 +} + +// countRecentToolCalls counts messages with tool calls in the last lookbackWindow +// entries of history. It examines the ToolCalls field rather than parsing +// the content string, so it is robust to any message format. +func countRecentToolCalls(history []providers.Message) int { + start := len(history) - lookbackWindow + if start < 0 { + start = 0 + } + + count := 0 + for _, msg := range history[start:] { + if len(msg.ToolCalls) > 0 { + count += len(msg.ToolCalls) + } + } + return count +} + +// hasAttachments returns true when the message content contains embedded media. +// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and +// common image/audio URL extensions. This is intentionally conservative — +// false negatives (missing an attachment) just mean the routing falls back to +// the primary model anyway. +func hasAttachments(msg string) bool { + lower := strings.ToLower(msg) + + // Base64 data URIs embedded directly in the message + if strings.Contains(lower, "data:image/") || + strings.Contains(lower, "data:audio/") || + strings.Contains(lower, "data:video/") { + return true + } + + // Common image/audio extensions in URLs or file references + mediaExts := []string{ + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", + ".mp3", ".wav", ".ogg", ".m4a", ".flac", + ".mp4", ".avi", ".mov", ".webm", + } + for _, ext := range mediaExts { + if strings.Contains(lower, ext) { + return true + } + } + + return false +} diff --git a/pkg/routing/router.go b/pkg/routing/router.go new file mode 100644 index 000000000..b1fa347e9 --- /dev/null +++ b/pkg/routing/router.go @@ -0,0 +1,82 @@ +package routing + +import ( + "github.com/sipeed/picoclaw/pkg/providers" +) + +// defaultThreshold is used when the config threshold is zero or negative. +// At 0.35 a message needs at least one strong signal (code block, long text, +// or an attachment) before the heavy model is chosen. +const defaultThreshold = 0.35 + +// RouterConfig holds the validated model routing settings. +// It mirrors config.RoutingConfig but lives in pkg/routing to keep the +// dependency graph simple: pkg/agent resolves config → routing, not the reverse. +type RouterConfig struct { + // LightModel is the model_name (from model_list) used for simple tasks. + LightModel string + + // Threshold is the complexity score cutoff in [0, 1]. + // score >= Threshold → primary (heavy) model. + // score < Threshold → light model. + Threshold float64 +} + +// Router selects the appropriate model tier for each incoming message. +// It is safe for concurrent use from multiple goroutines. +type Router struct { + cfg RouterConfig + classifier Classifier +} + +// New creates a Router with the given config and the default RuleClassifier. +// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used. +func New(cfg RouterConfig) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{ + cfg: cfg, + classifier: &RuleClassifier{}, + } +} + +// newWithClassifier creates a Router with a custom Classifier. +// Intended for unit tests that need to inject a deterministic scorer. +func newWithClassifier(cfg RouterConfig, c Classifier) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{cfg: cfg, classifier: c} +} + +// SelectModel returns the model to use for this conversation turn along with +// the computed complexity score (for logging and debugging). +// +// - If score < cfg.Threshold: returns (cfg.LightModel, true, score) +// - Otherwise: returns (primaryModel, false, score) +// +// The caller is responsible for resolving the returned model name into +// provider candidates (see AgentInstance.LightCandidates). +func (r *Router) SelectModel( + msg string, + history []providers.Message, + primaryModel string, +) (model string, usedLight bool, score float64) { + features := ExtractFeatures(msg, history) + score = r.classifier.Score(features) + if score < r.cfg.Threshold { + return r.cfg.LightModel, true, score + } + return primaryModel, false, score +} + +// LightModel returns the configured light model name. +func (r *Router) LightModel() string { + return r.cfg.LightModel +} + +// Threshold returns the complexity threshold in use. +func (r *Router) Threshold() float64 { + return r.cfg.Threshold +} diff --git a/pkg/routing/router_test.go b/pkg/routing/router_test.go new file mode 100644 index 000000000..2824d10ab --- /dev/null +++ b/pkg/routing/router_test.go @@ -0,0 +1,414 @@ +package routing + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// ── ExtractFeatures ────────────────────────────────────────────────────────── + +func TestExtractFeatures_EmptyMessage(t *testing.T) { + f := ExtractFeatures("", nil) + if f.TokenEstimate != 0 { + t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate) + } + if f.CodeBlockCount != 0 { + t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount) + } + if f.RecentToolCalls != 0 { + t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls) + } + if f.ConversationDepth != 0 { + t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth) + } + if f.HasAttachments { + t.Error("HasAttachments: got true, want false") + } +} + +func TestExtractFeatures_TokenEstimate(t *testing.T) { + // 30 ASCII runes: 0 CJK + 30/4 = 7 tokens + msg := strings.Repeat("a", 30) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 7 { + t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) { + // 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token). + // Using a rune slice literal avoids CJK string literals in source. + msg := string([]rune{ + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, + }) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 9 { + t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) { + // Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens. + msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok" + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 6 { + t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate) + } +} + +func TestExtractFeatures_CodeBlocks(t *testing.T) { + cases := []struct { + msg string + want int + }{ + {"no code here", 0}, + {"```go\nfmt.Println()\n```", 1}, + {"```python\npass\n```\n```js\nconsole.log()\n```", 2}, + {"```unclosed", 0}, // odd number of fences = 0 complete blocks + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.CodeBlockCount != tc.want { + t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want) + } + } +} + +func TestExtractFeatures_RecentToolCalls(t *testing.T) { + // History longer than lookbackWindow — only last lookbackWindow entries count. + history := make([]providers.Message, 10) + // Put 2 tool calls at positions 8 and 9 (within the last 6) + history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}} + history[9] = providers.Message{ + Role: "assistant", + ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}, + } + // Position 3 is outside the lookback window and must NOT be counted + history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}} + + f := ExtractFeatures("test", history) + // 1 (position 8) + 2 (position 9) = 3 + if f.RecentToolCalls != 3 { + t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls) + } +} + +func TestExtractFeatures_ConversationDepth(t *testing.T) { + history := make([]providers.Message, 7) + f := ExtractFeatures("msg", history) + if f.ConversationDepth != 7 { + t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth) + } +} + +func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"plain text", false}, + {"here is an image: data:image/png;base64,abc123", true}, + {"audio: data:audio/mp3;base64,xyz", true}, + {"video: data:video/mp4;base64,xyz", true}, + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +func TestExtractFeatures_HasAttachments_Extension(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"check out photo.jpg", true}, + {"see screenshot.png", true}, + {"listen to audio.mp3", true}, + {"watch clip.mp4", true}, + {"just a .go file", false}, + {"document.pdf", false}, // pdf is not in the media list + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +// ── RuleClassifier ─────────────────────────────────────────────────────────── + +func TestRuleClassifier_ZeroFeatures(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{}) + if score != 0.0 { + t.Errorf("zero features: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_AttachmentsHardGate(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{HasAttachments: true}) + if score != 1.0 { + t.Errorf("attachments: got %f, want 1.0", score) + } +} + +func TestRuleClassifier_CodeBlockAlone(t *testing.T) { + c := &RuleClassifier{} + // Code block alone = 0.40, above default threshold 0.35 + score := c.Score(Features{CodeBlockCount: 1}) + if score < 0.35 { + t.Errorf("code block: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_LongMessage(t *testing.T) { + c := &RuleClassifier{} + // >200 tokens = 0.35, exactly at default threshold → heavy + score := c.Score(Features{TokenEstimate: 250}) + if score < 0.35 { + t.Errorf("long message: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_MediumMessage(t *testing.T) { + c := &RuleClassifier{} + // 50-200 tokens = 0.15, below threshold → light + score := c.Score(Features{TokenEstimate: 100}) + if score >= 0.35 { + t.Errorf("medium message: score %f should be below default threshold 0.35", score) + } +} + +func TestRuleClassifier_ShortMessage(t *testing.T) { + c := &RuleClassifier{} + // <50 tokens, no other signals = 0.0 → light + score := c.Score(Features{TokenEstimate: 10}) + if score != 0.0 { + t.Errorf("short message: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_ToolCallDensity(t *testing.T) { + c := &RuleClassifier{} + + scoreNone := c.Score(Features{RecentToolCalls: 0}) + scoreLow := c.Score(Features{RecentToolCalls: 2}) + scoreHigh := c.Score(Features{RecentToolCalls: 5}) + + if scoreNone != 0.0 { + t.Errorf("no tools: got %f, want 0.0", scoreNone) + } + if scoreLow <= scoreNone { + t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone) + } + if scoreHigh <= scoreLow { + t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow) + } +} + +func TestRuleClassifier_DeepConversation(t *testing.T) { + c := &RuleClassifier{} + shallow := c.Score(Features{ConversationDepth: 5}) + deep := c.Score(Features{ConversationDepth: 15}) + if deep <= shallow { + t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow) + } +} + +func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) { + c := &RuleClassifier{} + // Max all signals simultaneously + f := Features{ + TokenEstimate: 500, + CodeBlockCount: 3, + RecentToolCalls: 10, + ConversationDepth: 20, + } + score := c.Score(f) + if score > 1.0 { + t.Errorf("score %f exceeds 1.0", score) + } +} + +// ── Router ─────────────────────────────────────────────────────────────────── + +func TestRouter_DefaultThreshold(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash"}) + if r.Threshold() != defaultThreshold { + t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1}) + if r.Threshold() != defaultThreshold { + t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "hi" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if !usedLight { + t.Error("simple message: expected light model to be selected") + } + if model != "gemini-flash" { + t.Errorf("simple message: model got %q, want %q", model, "gemini-flash") + } +} + +func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "```go\nfmt.Println(\"hello\")\n```" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("code block: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "can you analyze this? data:image/png;base64,abc123" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("attachment: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + // >200 token estimate: 210 * 3 = 630 chars + msg := strings.Repeat("word ", 210) + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("long message: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) { + // Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior. + // Routing is conservative: only promote to heavy when the signal is unambiguous. + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + history := []providers.Message{ + {Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}}, + {Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}}, + } + msg := "ok" + _, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6") + if !usedLight { + t.Error("short message + moderate tool calls: expected light model (score 0.20 < 0.35)") + } +} + +func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) { + // Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + history := []providers.Message{ + {Role: "assistant", ToolCalls: []providers.ToolCall{ + {Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"}, + }}, + } + // ~55 tokens * 3 = 165 chars + msg := strings.Repeat("word ", 55) + _, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6") + if usedLight { + t.Error("tool chain + medium message: expected primary model (score >= 0.35)") + } +} + +func TestRouter_SelectModel_CustomThreshold(t *testing.T) { + // Very low threshold: even a short message triggers heavy model + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05}) + msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05 + _, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("low threshold: medium message should use primary model") + } +} + +func TestRouter_SelectModel_HighThreshold(t *testing.T) { + // Very high threshold: even code blocks route to light + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99}) + msg := "```go\nfmt.Println()\n```" + _, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if !usedLight { + t.Error("very high threshold: code block (0.40) should route to light model") + } +} + +func TestRouter_LightModel(t *testing.T) { + r := New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35}) + if r.LightModel() != "my-fast-model" { + t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model") + } +} + +// ── newWithClassifier (internal testing hook) ───────────────────────────────── + +type fixedScoreClassifier struct{ score float64 } + +func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score } + +func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.2}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if !usedLight { + t.Error("low score with custom classifier: expected light model") + } +} + +func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.8}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("high score with custom classifier: expected primary model") + } +} + +func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) { + // score == threshold → primary (uses >= comparison) + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.5}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("score == threshold: expected primary model (>= threshold → primary)") + } +} + +func TestRouter_SelectModel_ReturnsScore(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.42}, + ) + _, _, score := r.SelectModel("anything", nil, "heavy") + if score != 0.42 { + t.Errorf("score: got %f, want 0.42", score) + } +} diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go index f78197bbe..bd4bed8fb 100644 --- a/pkg/skills/clawhub_registry.go +++ b/pkg/skills/clawhub_registry.go @@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall( } u.RawQuery = q.Encode() - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize)) + tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String()) if err != nil { return nil, fmt.Errorf("download failed: %w", err) } @@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall( // --- HTTP helper --- func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) + req, err := c.newGetRequest(ctx, urlStr, "application/json") if err != nil { return nil, err } - req.Header.Set("Accept", "application/json") - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - resp, err := c.client.Do(req) + resp, err := utils.DoRequestWithRetry(c.client, req) if err != nil { return nil, err } @@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err return body, nil } + +func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", accept) + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + return req, nil +} + +func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) { + req, err := c.newGetRequest(ctx, urlStr, "application/zip") + if err != nil { + return "", err + } + + resp, err := utils.DoRequestWithRetry(c.client, req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody := make([]byte, 512) + n, _ := io.ReadFull(resp.Body, errBody) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n])) + } + + tmpFile, err := os.CreateTemp("", "picoclaw-dl-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + cleanup := func() { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + } + + src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1) + written, err := io.Copy(tmpFile, src) + if err != nil { + cleanup() + return "", fmt.Errorf("download write failed: %w", err) + } + + if written > int64(c.maxZipSize) { + cleanup() + return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize) + } + + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("failed to close temp file: %w", err) + } + + return tmpPath, nil +} diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go index 65ee638da..055da22dc 100644 --- a/pkg/skills/clawhub_registry_test.go +++ b/pkg/skills/clawhub_registry_test.go @@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) { assert.Equal(t, "clawhub", results[0].RegistryName) } +func TestClawHubRegistrySearchRetries429(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + + slug := "github" + name := "GitHub Integration" + summary := "Interact with GitHub repos" + version := "1.0.0" + + json.NewEncoder(w).Encode(clawhubSearchResponse{ + Results: []clawhubSearchResult{ + {Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version}, + }, + }) + })) + defer srv.Close() + + reg := newTestRegistry(srv.URL, "") + results, err := reg.Search(context.Background(), "github", 5) + + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, 2, attempts) + assert.Equal(t, "github", results[0].Slug) +} + func TestClawHubRegistryGetSkillMeta(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/api/v1/skills/github", r.URL.Path) @@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) { assert.Contains(t, string(readmeContent), "# Test Skill") } +func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) { + zipBuf := createTestZip(t, map[string]string{ + "SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill", + }) + + downloadAttempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/skills/retry-skill": + json.NewEncoder(w).Encode(clawhubSkillResponse{ + Slug: "retry-skill", + DisplayName: "Retry Skill", + Summary: "A retry test skill", + LatestVersion: &clawhubVersionInfo{Version: "1.0.0"}, + }) + case "/api/v1/download": + downloadAttempts++ + if downloadAttempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + assert.Equal(t, "retry-skill", r.URL.Query().Get("slug")) + w.Header().Set("Content-Type", "application/zip") + w.Write(zipBuf) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + tmpDir := t.TempDir() + targetDir := filepath.Join(tmpDir, "retry-skill") + + reg := newTestRegistry(srv.URL, "") + result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "1.0.0", result.Version) + assert.Equal(t, 2, downloadAttempts) + + skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md")) + require.NoError(t, err) + assert.Contains(t, string(skillContent), "Hello skill") +} + func TestClawHubRegistryAuthToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 1c8cff99c..a41279280 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -132,9 +132,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf denyPatterns = append(denyPatterns, defaultDenyPatterns...) } + timeout := 60 * time.Second + if config != nil && config.Tools.Exec.TimeoutSeconds > 0 { + timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second + } + return &ExecTool{ workingDir: workingDir, - timeout: 60 * time.Second, + timeout: timeout, denyPatterns: denyPatterns, allowPatterns: nil, customAllowPatterns: customAllowPatterns, diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 7b14686c9..eeceabd98 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -395,6 +395,68 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil } +type SearXNGSearchProvider struct { + baseURL string +} + +func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general", + strings.TrimSuffix(p.baseURL, "/"), + url.QueryEscape(query)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode) + } + + var result struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + Engine string `json:"engine"` + Score float64 `json:"score"` + } `json:"results"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + // Limit results to requested count + if len(result.Results) > count { + result.Results = result.Results[:count] + } + + // Format results in standard PicoClaw format + var b strings.Builder + b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query)) + for i, r := range result.Results { + b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title)) + b.WriteString(fmt.Sprintf(" %s\n", r.URL)) + if r.Content != "" { + b.WriteString(fmt.Sprintf(" %s\n", r.Content)) + } + } + + return b.String(), nil +} + type GLMSearchProvider struct { apiKey string baseURL string @@ -495,6 +557,9 @@ type WebSearchToolOptions struct { PerplexityAPIKey string PerplexityMaxResults int PerplexityEnabled bool + SearXNGBaseURL string + SearXNGMaxResults int + SearXNGEnabled bool GLMSearchAPIKey string GLMSearchBaseURL string GLMSearchEngine string @@ -507,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 - // Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search + // Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { client, err := createHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { @@ -526,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } + } else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" { + provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL} + if opts.SearXNGMaxResults > 0 { + maxResults = opts.SearXNGMaxResults + } } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" { client, err := createHTTPClient(opts.Proxy, searchTimeout) if err != nil {