diff --git a/.env.example b/.env.example index 98fa7f868..798815a7a 100644 --- a/.env.example +++ b/.env.example @@ -5,13 +5,10 @@ # ANTHROPIC_API_KEY=sk-ant-xxx # OPENAI_API_KEY=sk-xxx # GEMINI_API_KEY=xxx -# CEREBRAS_API_KEY=xxx - +# CLAUDE_CODE_OAUTH=xxx # ── Chat Channel ────────────────────────── # TELEGRAM_BOT_TOKEN=123456:ABC... # DISCORD_BOT_TOKEN=xxx -# LINE_CHANNEL_SECRET=xxx -# LINE_CHANNEL_ACCESS_TOKEN=xxx # Feishu (飞书) # PICOCLAW_CHANNELS_FEISHU_APP_ID=cli_xxx # PICOCLAW_CHANNELS_FEISHU_APP_SECRET=xxx diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 786c893ef..0edd29f22 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,6 +17,11 @@ on: required: false type: boolean default: false + upload_tos: + description: "Upload to Volcengine TOS" + required: false + type: boolean + default: true jobs: create-tag: @@ -100,3 +105,12 @@ jobs: gh release edit "${{ inputs.tag }}" \ --draft=${{ inputs.draft }} \ --prerelease=${{ inputs.prerelease }} + + upload-tos: + name: Upload to TOS + needs: release + if: ${{ inputs.upload_tos }} + uses: ./.github/workflows/upload-tos.yml + with: + tag: ${{ inputs.tag }} + secrets: inherit diff --git a/.github/workflows/upload-tos.yml b/.github/workflows/upload-tos.yml new file mode 100644 index 000000000..6d3916d53 --- /dev/null +++ b/.github/workflows/upload-tos.yml @@ -0,0 +1,49 @@ +name: Upload to Volcengine TOS + +on: + workflow_dispatch: + inputs: + tag: + description: "Release tag to download and upload (e.g. 
v0.2.0)" + required: true + type: string + workflow_call: + inputs: + tag: + description: "Release tag to download and upload" + required: true + type: string + +jobs: + upload-tos: + name: Upload to Volcengine TOS + runs-on: ubuntu-latest + steps: + - name: Download release assets + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir -p artifacts + gh release download "${{ inputs.tag }}" \ + --repo "${{ github.repository }}" \ + --dir artifacts \ + --pattern "*.tar.gz" \ + --pattern "*.zip" \ + --pattern "*.rpm" \ + --pattern "*.deb" + + - name: Upload to Volcengine TOS + env: + AWS_ACCESS_KEY_ID: ${{ secrets.VOLC_TOS_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.VOLC_TOS_SECRET_KEY }} + AWS_DEFAULT_REGION: cn-beijing + run: | + aws configure set default.s3.addressing_style virtual + TOS_ENDPOINT="https://tos-s3-cn-beijing.volces.com" + # Upload to versioned directory + aws s3 sync artifacts/ "s3://picoclaw-downloads/${{ inputs.tag }}/" \ + --endpoint-url "$TOS_ENDPOINT" + # Upload to latest (overwrite) + aws s3 sync artifacts/ "s3://picoclaw-downloads/latest/" \ + --endpoint-url "$TOS_ENDPOINT" \ + --delete diff --git a/.gitignore b/.gitignore index 02ef18d1f..a52b8d25a 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,9 @@ ralph/ .ralph/ tasks/ +# Plans +docs/plans/ + # Editors .vscode/ .idea/ diff --git a/Makefile b/Makefile index afc76a6ad..8de98e984 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,28 @@ LDFLAGS=-ldflags "-X $(INTERNAL).version=$(VERSION) -X $(INTERNAL).gitCommit=$(G GO?=CGO_ENABLED=0 go GOFLAGS?=-v -tags stdjson +# Patch MIPS LE ELF e_flags (offset 36) for NaN2008-only kernels (e.g. Ingenic X2600). 
+# +# Bytes (octal): \004 \024 \000 \160 → little-endian 0x70001404 +# 0x70000000 EF_MIPS_ARCH_32R2 MIPS32 Release 2 +# 0x00001000 EF_MIPS_ABI_O32 O32 ABI +# 0x00000400 EF_MIPS_NAN2008 IEEE 754-2008 NaN encoding +# 0x00000004 EF_MIPS_CPIC PIC calling sequence +# +# Go's GOMIPS=softfloat emits no FP instructions, so the NaN mode is irrelevant +# at runtime — this is purely an ELF metadata fix to satisfy the kernel's check. +# patchelf cannot modify e_flags; dd at a fixed offset is the most portable way. +# +# Ref: https://codebrowser.dev/linux/linux/arch/mips/include/asm/elf.h.html +define PATCH_MIPS_FLAGS + @if [ -f "$(1)" ]; then \ + printf '\004\024\000\160' | dd of=$(1) bs=1 seek=36 count=4 conv=notrunc 2>/dev/null || \ + { echo "Error: failed to patch MIPS e_flags for $(1)"; exit 1; }; \ + else \ + echo "Error: $(1) not found, cannot patch MIPS e_flags"; exit 1; \ + fi +endef + # Golangci-lint GOLANGCI_LINT?=golangci-lint @@ -50,6 +72,8 @@ ifeq ($(UNAME_S),Linux) ARCH=loong64 else ifeq ($(UNAME_M),riscv64) ARCH=riscv64 + else ifeq ($(UNAME_M),mipsel) + ARCH=mipsle else ARCH=$(UNAME_M) endif @@ -97,6 +121,8 @@ build-whatsapp-native: generate GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) + GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o 
$(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) ## @$(GO) build $(GOFLAGS) -tags whatsapp_native $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR) @@ -117,6 +143,14 @@ build-linux-arm64: generate GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64" +## build-linux-mipsle: Build for Linux MIPS32 LE +build-linux-mipsle: generate + @echo "Building for linux/mipsle (softfloat)..." + @mkdir -p $(BUILD_DIR) + GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) + @echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle" + ## build-pi-zero: Build for Raspberry Pi Zero 2 W (32-bit and 64-bit) build-pi-zero: build-linux-arm build-linux-arm64 @echo "Pi Zero 2 W builds: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm (32-bit), $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 (64-bit)" @@ -130,6 +164,8 @@ build-all: generate GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) + GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR) + $(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle) GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR) GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) diff --git a/README.fr.md b/README.fr.md index 320aa9e22..08a1926b6 100644 --- a/README.fr.md +++ 
b/README.fr.md @@ -7,7 +7,7 @@

Go - Hardware + Hardware License
Website @@ -65,7 +65,7 @@ ⚡️ **Démarrage Éclair** : Temps de démarrage 400X plus rapide, boot en 1 seconde même sur un cœur unique à 0,6 GHz. -🌍 **Véritable Portabilité** : Un seul binaire autonome pour RISC-V, ARM et x86. Un clic et c'est parti ! +🌍 **Véritable Portabilité** : Un seul binaire autonome pour RISC-V, ARM, MIPS et x86. Un clic et c'est parti ! 🤖 **Auto-Construit par l'IA** : Implémentation native en Go de manière autonome — 95% du cœur généré par l'Agent avec affinement humain dans la boucle. diff --git a/README.ja.md b/README.ja.md index ea6bc7e72..c4c5b27a0 100644 --- a/README.ja.md +++ b/README.ja.md @@ -8,7 +8,7 @@

Go -Hardware +Hardware License

@@ -49,7 +49,7 @@ ⚡️ **超高速**: 起動時間 400 倍高速、0.6GHz シングルコアでも 1 秒で起動。 -🌍 **真のポータビリティ**: RISC-V、ARM、x86 対応の単一バイナリ。ワンクリックで Go! +🌍 **真のポータビリティ**: RISC-V、ARM、MIPS、x86 対応の単一バイナリ。ワンクリックで Go! 🤖 **AI ブートストラップ**: 自律的な Go ネイティブ実装 — コアの 95% が AI 生成、人間によるレビュー付き。 diff --git a/README.md b/README.md index 7a31f9364..c9cc28f58 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

Go - Hardware + Hardware License
Website @@ -69,7 +69,7 @@ ⚡️ **Lightning Fast**: 400X Faster startup time, boot in 1 second even in 0.6GHz single core. -🌍 **True Portability**: Single self-contained binary across RISC-V, ARM, and x86, One-click to Go! +🌍 **True Portability**: Single self-contained binary across RISC-V, ARM, MIPS, and x86, One-click to Go! 🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement. @@ -353,6 +353,13 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, DingTalk, LINE, or We picoclaw gateway ``` +**4. Telegram command menu (auto-registered at startup)** + +PicoClaw now keeps command definitions in one shared registry. On startup, Telegram will automatically register supported bot commands (for example `/start`, `/help`, `/show`, `/list`) so command menu and runtime behavior stay in sync. +Telegram command menu registration remains channel-local discovery UX; generic command execution is handled centrally in the agent loop via the commands executor. + +If command registration fails (network/API transient errors), the channel still starts and PicoClaw retries registration in the background. +

@@ -750,6 +757,12 @@ For advanced/test setups, you can override the builtin skills root with: export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### Unified Command Execution Policy + +- Generic slash commands are executed through a single path in `pkg/agent/loop.go` via `commands.Executor`. +- Channel adapters no longer consume generic commands locally; they forward inbound text to the bus/agent path. Telegram still auto-registers supported commands at startup. +- Unknown slash command (for example `/foo`) passes through to normal LLM processing. +- Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing. ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. @@ -939,6 +952,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate | `qwen` | LLM (Qwen direct) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) | +| `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) | ### Model Configuration (model_list) @@ -966,11 +980,12 @@ This design also enables **multi-agent support** with flexible provider selectio | **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) | | **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) | | **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) | -| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1 | OpenAI | Your LiteLLM proxy key | +| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1` | OpenAI | Your LiteLLM proxy key | | 
**VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) | | **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) | | **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | @@ -1205,6 +1220,10 @@ picoclaw agent -m "Hello" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" diff --git a/README.pt-br.md b/README.pt-br.md index 67ce9e0d3..5f37ba457 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -7,7 +7,7 @@

Go - Hardware + Hardware License
Website @@ -66,7 +66,7 @@ ⚡️ **Inicialização Relámpago**: Tempo de inicialização 400X mais rápido, boot em 1 segundo mesmo em CPU single-core de 0.6GHz. -🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM e x86. Um clique e já era! +🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM, MIPS e x86. Um clique e já era! 🤖 **Auto-Construído por IA**: Implementação nativa em Go de forma autônoma — 95% do núcleo gerado pelo Agente com refinamento humano no loop. diff --git a/README.vi.md b/README.vi.md index 5755896ed..92c6ecbae 100644 --- a/README.vi.md +++ b/README.vi.md @@ -7,7 +7,7 @@

Go - Hardware + Hardware License
Website @@ -65,7 +65,7 @@ ⚡️ **Khởi động siêu nhanh**: Nhanh gấp 400 lần, khởi động trong 1 giây ngay cả trên CPU đơn nhân 0.6GHz. -🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM và x86. Một click là chạy! +🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM, MIPS và x86. Một click là chạy! 🤖 **AI tự xây dựng**: Triển khai Go-native tự động — 95% mã nguồn cốt lõi được Agent tạo ra, với sự tinh chỉnh của con người. diff --git a/README.zh.md b/README.zh.md index bd90173f9..d42b3cbb8 100644 --- a/README.zh.md +++ b/README.zh.md @@ -7,7 +7,7 @@

Go - Hardware + Hardware License
Website @@ -67,7 +67,7 @@ ⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。 -🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行! +🌍 **真正可移植**: 跨 RISC-V、ARM、MIPS 和 x86 架构的单二进制文件,一键运行! 🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。 @@ -307,6 +307,13 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方 | **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) | | **MaixCam** | ⭐ 简单 | 专为 AI 摄像头设计的硬件集成通道 | [查看文档](docs/channels/maixcam/README.zh.md) | +### Telegram 命令注册(启动时自动同步) + +PicoClaw 现在使用统一的命令定义来源。启动时会自动将 Telegram 支持的命令(例如 `/start`、`/help`、`/show`、`/list`)注册到 Bot 命令菜单,确保菜单展示与实际行为一致。 +Telegram 侧保留的是命令菜单注册能力;通用命令的实际执行统一走 Agent Loop 中的 commands executor。 + +如果注册因网络或 API 短暂异常失败,不会阻塞 channel 启动;系统会在后台自动重试。 + ## ClawdChat 加入 Agent 社交网络 只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 @@ -376,6 +383,12 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work export PICOCLAW_BUILTIN_SKILLS=/path/to/skills ``` +### 统一命令执行策略 + +- 通用斜杠命令通过 `pkg/agent/loop.go` 中的 `commands.Executor` 统一执行。 +- Channel 适配器不再在本地消费通用命令;它们只负责把入站文本转发到 bus/agent 路径。Telegram 仍会在启动时自动注册其支持的命令菜单。 +- 未注册的斜杠命令(例如 `/foo`)会透传给 LLM 按普通输入处理。 +- 已注册但当前 channel 不支持的命令(例如 WhatsApp 上的 `/show`)会返回明确的用户可见错误,并停止后续处理。 ### 心跳 / 周期性任务 (Heartbeat) PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: @@ -715,6 +728,10 @@ picoclaw agent -m "你好" "model": "anthropic/claude-opus-4-5" } }, + "session": { + "dm_scope": "per-channel-peer", + "backlog_limit": 20 + }, "providers": { "openrouter": { "api_key": "sk-or-v1-xxx" diff --git a/assets/wechat.png b/assets/wechat.png index 32998c122..cc88186a8 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw-launcher/internal/ui/index.html b/cmd/picoclaw-launcher/internal/ui/index.html index 93893fd75..d84fd4e6e 100644 --- a/cmd/picoclaw-launcher/internal/ui/index.html +++ b/cmd/picoclaw-launcher/internal/ui/index.html @@ -1392,9 +1392,7 @@ function saveModelFromModal() { 
saveConfig().then(renderModels); } -document.getElementById('modelModal').addEventListener('click', function(e) { - if (e.target === this) closeModelModal(); -}); + // ── Channel Forms ─────────────────────────────────── function renderChannelForm(chKey) { diff --git a/cmd/picoclaw/internal/auth/helpers.go b/cmd/picoclaw/internal/auth/helpers.go index 4dfbc92e7..a0a229167 100644 --- a/cmd/picoclaw/internal/auth/helpers.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -1,6 +1,7 @@ package auth import ( + "bufio" "encoding/json" "fmt" "io" @@ -15,14 +16,17 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" ) -const supportedProvidersMsg = "supported providers: openai, anthropic, google-antigravity" +const ( + supportedProvidersMsg = "supported providers: openai, anthropic, google-antigravity" + defaultAnthropicModel = "claude-sonnet-4.6" +) -func authLoginCmd(provider string, useDeviceCode bool) error { +func authLoginCmd(provider string, useDeviceCode bool, useOauth bool) error { switch provider { case "openai": return authLoginOpenAI(useDeviceCode) case "anthropic": - return authLoginPasteToken(provider) + return authLoginAnthropic(useOauth) case "google-antigravity", "antigravity": return authLoginGoogleAntigravity() default: @@ -163,6 +167,81 @@ func authLoginGoogleAntigravity() error { return nil } +func authLoginAnthropic(useOauth bool) error { + if useOauth { + return authLoginAnthropicSetupToken() + } + + fmt.Println("Anthropic login method:") + fmt.Println(" 1) Setup token (from `claude setup-token`) (Recommended)") + fmt.Println(" 2) API key (from console.anthropic.com)") + + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("Choose [1]: ") + choice := "1" + if scanner.Scan() { + text := strings.TrimSpace(scanner.Text()) + if text != "" { + choice = text + } + } + + switch choice { + case "1": + return authLoginAnthropicSetupToken() + case "2": + return authLoginPasteToken("anthropic") + default: + fmt.Printf("Invalid choice: %s. 
Please enter 1 or 2.\n", choice) + } + } +} + +func authLoginAnthropicSetupToken() error { + cred, err := auth.LoginSetupToken(os.Stdin) + if err != nil { + return fmt.Errorf("login failed: %w", err) + } + + if err = auth.SetCredential("anthropic", cred); err != nil { + return fmt.Errorf("failed to save credentials: %w", err) + } + + appCfg, err := internal.LoadConfig() + if err == nil { + appCfg.Providers.Anthropic.AuthMethod = "oauth" + + found := false + for i := range appCfg.ModelList { + if isAnthropicModel(appCfg.ModelList[i].Model) { + appCfg.ModelList[i].AuthMethod = "oauth" + found = true + break + } + } + if !found { + appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ + ModelName: defaultAnthropicModel, + Model: "anthropic/" + defaultAnthropicModel, + AuthMethod: "oauth", + }) + // Only set default model if user has no default configured yet + if appCfg.Agents.Defaults.GetModelName() == "" { + appCfg.Agents.Defaults.ModelName = defaultAnthropicModel + } + } + + if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil { + return fmt.Errorf("could not update config: %w", err) + } + } + + fmt.Println("Setup token saved for Anthropic!") + + return nil +} + func fetchGoogleUserEmail(accessToken string) (string, error) { req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil) if err != nil { @@ -220,13 +299,12 @@ func authLoginPasteToken(provider string) error { } if !found { appCfg.ModelList = append(appCfg.ModelList, config.ModelConfig{ - ModelName: "claude-sonnet-4.6", - Model: "anthropic/claude-sonnet-4.6", + ModelName: defaultAnthropicModel, + Model: "anthropic/" + defaultAnthropicModel, AuthMethod: "token", }) + appCfg.Agents.Defaults.ModelName = defaultAnthropicModel } - // Update default model - appCfg.Agents.Defaults.ModelName = "claude-sonnet-4.6" case "openai": appCfg.Providers.OpenAI.AuthMethod = "token" // Update ModelList @@ -363,6 +441,16 @@ func authStatusCmd() error { if 
!cred.ExpiresAt.IsZero() { fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) } + + if provider == "anthropic" && cred.AuthMethod == "oauth" { + usage, err := auth.FetchAnthropicUsage(cred.AccessToken) + if err != nil { + fmt.Printf(" Usage: unavailable (%v)\n", err) + } else { + fmt.Printf(" Usage (5h): %.1f%%\n", usage.FiveHourUtilization*100) + fmt.Printf(" Usage (7d): %.1f%%\n", usage.SevenDayUtilization*100) + } + } } return nil diff --git a/cmd/picoclaw/internal/auth/login.go b/cmd/picoclaw/internal/auth/login.go index 9a6d28d2f..afbe098aa 100644 --- a/cmd/picoclaw/internal/auth/login.go +++ b/cmd/picoclaw/internal/auth/login.go @@ -6,6 +6,7 @@ func newLoginCommand() *cobra.Command { var ( provider string useDeviceCode bool + useOauth bool ) cmd := &cobra.Command{ @@ -13,12 +14,16 @@ func newLoginCommand() *cobra.Command { Short: "Login via OAuth or paste token", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, _ []string) error { - return authLoginCmd(provider, useDeviceCode) + return authLoginCmd(provider, useDeviceCode, useOauth) }, } cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to login with (openai, anthropic)") cmd.Flags().BoolVar(&useDeviceCode, "device-code", false, "Use device code flow (for headless environments)") + cmd.Flags().BoolVar( + &useOauth, "setup-token", false, + "Use setup-token flow for Anthropic (from `claude setup-token`)", + ) _ = cmd.MarkFlagRequired("provider") return cmd diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 174f5db62..00b53e62c 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -16,6 +16,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" _ "github.com/sipeed/picoclaw/pkg/channels/discord" _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/irc" _ "github.com/sipeed/picoclaw/pkg/channels/line" _ 
"github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" diff --git a/config/config.example.json b/config/config.example.json index f0e6e0c4b..c42926120 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -165,6 +165,28 @@ "max_steps": 10, "welcome_message": "Hello! I'm your AI assistant. How can I help you today?", "reasoning_channel_id": "" + }, + "irc": { + "enabled": false, + "server": "irc.libera.chat:6697", + "tls": true, + "nick": "mybot", + "user": "", + "real_name": "", + "password": "", + "nickserv_password": "", + "sasl_user": "", + "sasl_password": "", + "channels": ["#mychannel"], + "request_caps": ["server-time", "message-tags"], + "allow_from": [], + "group_trigger": { + "mention_only": true + }, + "typing": { + "enabled": false + }, + "reasoning_channel_id": "" } }, "providers": { diff --git a/docs/agent-refactor/README.md b/docs/agent-refactor/README.md new file mode 100644 index 000000000..db8575fc9 --- /dev/null +++ b/docs/agent-refactor/README.md @@ -0,0 +1,145 @@ +# Agent Refactor + +## What this directory is for + +This directory is the working area for the current Agent refactor. + +The purpose of this refactor is simple: + +the project needs a smaller, clearer, and more stable Agent model before more Agent-related behavior is added. + +The codebase already contains meaningful Agent behavior. What it still lacks is a sufficiently explicit and stable semantic boundary around that behavior. + +This refactor exists to fix that first. + +--- + +## Refactor stance + +This is a maintenance-led consolidation effort. + +It is not a general invitation to expand Agent behavior in parallel. + +During this refactor window, Agent-related work should converge on the current refactor track instead of branching into new semantics. 
+ +That means: + +- concept clarification before feature expansion +- boundary tightening before abstraction growth +- semantic consolidation before new behavior + +--- + +## Core rule: minimum concepts only + +This refactor follows one hard rule: + +**do not introduce a new concept unless it is strictly necessary** + +More explicitly: + +- if an existing concept can be clarified, reuse it +- if an existing boundary can be made explicit, do that first +- if a behavior can be expressed without a new abstraction, do not add one +- "future flexibility" is not enough justification on its own + +The goal of this refactor is not to grow the model. + +The goal is to reduce ambiguity. + +--- + +## What is being clarified + +This refactor is currently concerned with the following questions: + +1. what an `Agent` is +2. what an `AgentLoop` is +3. what the lifecycle of `AgentLoop` is +4. what the event surface around `AgentLoop` is +5. how persona / identity is assembled +6. how capabilities are represented +7. how context boundaries and compression work +8. how subagent coordination works + +These are the current working boundaries. + +If they need to be adjusted, they should be adjusted explicitly rather than drift implicitly in code. + +--- + +## Status of this directory + +The documents here are working materials. + +They are not final or immutable. + +If current notes are incomplete, incorrectly split, or too broad, they should be revised. This directory should evolve with the refactor rather than pretending the first draft is complete. 
+ +--- + +## Suggested document split + +This directory may eventually contain notes such as: + +- `agent-overview.md` + - what an Agent is +- `agent-loop.md` + - AgentLoop contract, lifecycle, event surface +- `persona.md` + - persona and identity assembly +- `capability.md` + - tools / skills / MCP capability semantics +- `context.md` + - context scope, history, summary, compression +- `subagent.md` + - subagent coordination rules + +These files should be added only when they help clarify the current refactor work. + +This directory should not turn into a generic architecture dump. + +--- + +## What this directory is not for + +This directory is not intended for: + +- broad speculative architecture +- future multi-node protocol design not required by the current refactor +- parallel feature planning unrelated to Agent consolidation +- adding new concepts before current ones are made clear + +If a topic does not directly help reduce ambiguity in the current Agent model, it probably does not belong here yet. + +--- + +## Relationship to implementation + +Implementation changes should not keep redefining Agent semantics implicitly. + +If a PR changes or depends on Agent semantics, those semantics should either already exist here or be clarified in a linked issue first. + +This directory is here to make implementation narrower and more disciplined. + +--- + +## Relationship to GitHub tracking + +The umbrella issue for this refactor should point here. + +The issue is the coordination surface. + +This directory is the repository-local working surface. 
+ +--- + +## Summary + +The main question of this refactor is not: + +- what more can Agent do + +The main question is: + +- what is the smallest stable model that current Agent behavior can be organized around diff --git a/go.mod b/go.mod index 6fa3a900c..a8e08deb8 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,14 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/ergochat/irc-go v0.5.0 github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/h2non/filetype v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 - github.com/modelcontextprotocol/go-sdk v1.3.0 + github.com/modelcontextprotocol/go-sdk v1.3.1 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 @@ -31,7 +32,7 @@ require ( ) require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/beeper/argo-go v1.1.2 // indirect github.com/coder/websocket v1.8.14 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -48,6 +49,8 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.3 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/vektah/gqlparser/v2 v2.5.27 // indirect go.mau.fi/libsignal v0.2.1 // indirect diff --git a/go.sum b/go.sum index 060594d06..81a1cdd1e 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 
v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= @@ -48,6 +48,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg= github.com/elliotchance/orderedmap/v3 v3.1.0/go.mod h1:G+Hc2RwaZvJMcS4JpGCOyViCnGeKf0bTYCGTO4uhjSo= +github.com/ergochat/irc-go v0.5.0 h1:woQ1RS9YbfgqPgSpPBBQeczXGIGzR0aC7dEgk469fTw= +github.com/ergochat/irc-go v0.5.0/go.mod h1:2vi7KNpIPWnReB5hmLpl92eMywQvuIeIIGdt/FQCph0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw= @@ -134,8 +136,8 @@ github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= -github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs= -github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE= +github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI= +github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw= 
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= @@ -171,6 +173,10 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.3 h1:OjMgICtcSFuNvQCdwqMCv9Tg7lEOXGwm1J5RPQccx6w= +github.com/segmentio/encoding v0.5.3/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index d84aea627..719b0cb6d 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -605,7 +605,60 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message } } - return sanitized + // Second pass: ensure every assistant message with tool_calls has matching + // tool result messages following it. This is required by strict providers + // like DeepSeek that enforce: "An assistant message with 'tool_calls' must + // be followed by tool messages responding to each 'tool_call_id'." 
+ final := make([]providers.Message, 0, len(sanitized)) + for i := 0; i < len(sanitized); i++ { + msg := sanitized[i] + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + // Collect expected tool_call IDs + expected := make(map[string]bool, len(msg.ToolCalls)) + for _, tc := range msg.ToolCalls { + expected[tc.ID] = false + } + + // Check following messages for matching tool results + toolMsgCount := 0 + for j := i + 1; j < len(sanitized); j++ { + if sanitized[j].Role != "tool" { + break + } + toolMsgCount++ + if _, exists := expected[sanitized[j].ToolCallID]; exists { + expected[sanitized[j].ToolCallID] = true + } + } + + // If any tool_call_id is missing, drop this assistant message and its partial tool messages + allFound := true + for toolCallID, found := range expected { + if !found { + allFound = false + logger.DebugCF( + "agent", + "Dropping assistant message with incomplete tool results", + map[string]any{ + "missing_tool_call_id": toolCallID, + "expected_count": len(expected), + "found_count": toolMsgCount, + }, + ) + break + } + } + + if !allFound { + // Skip this assistant message and its tool messages + i += toolMsgCount + continue + } + } + final = append(final, msg) + } + + return final } func (cb *ContextBuilder) AddToolResult( diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go index e023c9c30..5756ed911 100644 --- a/pkg/agent/context_test.go +++ b/pkg/agent/context_test.go @@ -207,3 +207,77 @@ func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) { } } } + +// TestSanitizeHistoryForProvider_IncompleteToolResults tests the forward validation +// that ensures assistant messages with tool_calls have ALL matching tool results. +// This fixes the DeepSeek error: "An assistant message with 'tool_calls' must be +// followed by tool messages responding to each 'tool_call_id'." 
+func TestSanitizeHistoryForProvider_IncompleteToolResults(t *testing.T) { + // Assistant expects tool results for both A and B, but only A is present + history := []providers.Message{ + msg("user", "do two things"), + assistantWithTools("A", "B"), + toolResult("A"), + // toolResult("B") is missing - this would cause DeepSeek to fail + msg("user", "next question"), + msg("assistant", "answer"), + } + + result := sanitizeHistoryForProvider(history) + // The assistant message with incomplete tool results should be dropped, + // along with its partial tool result. The remaining messages are: + // user ("do two things"), user ("next question"), assistant ("answer") + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "user", "assistant") +} + +// TestSanitizeHistoryForProvider_MissingAllToolResults tests the case where +// an assistant message has tool_calls but no tool results follow at all. +func TestSanitizeHistoryForProvider_MissingAllToolResults(t *testing.T) { + history := []providers.Message{ + msg("user", "do something"), + assistantWithTools("A"), + // No tool results at all + msg("user", "hello"), + msg("assistant", "hi"), + } + + result := sanitizeHistoryForProvider(history) + // The assistant message with no tool results should be dropped. + // Remaining: user ("do something"), user ("hello"), assistant ("hi") + if len(result) != 3 { + t.Fatalf("expected 3 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "user", "assistant") +} + +// TestSanitizeHistoryForProvider_PartialToolResultsInMiddle tests that +// incomplete tool results in the middle of a conversation are properly handled. 
+func TestSanitizeHistoryForProvider_PartialToolResultsInMiddle(t *testing.T) { + history := []providers.Message{ + msg("user", "first"), + assistantWithTools("A"), + toolResult("A"), + msg("assistant", "done"), + msg("user", "second"), + assistantWithTools("B", "C"), + toolResult("B"), + // toolResult("C") is missing + msg("user", "third"), + assistantWithTools("D"), + toolResult("D"), + msg("assistant", "all done"), + } + + result := sanitizeHistoryForProvider(history) + // First round is complete (user, assistant+tools, tool, assistant), + // second round is incomplete and dropped (assistant+tools, partial tool), + // third round is complete (user, assistant+tools, tool, assistant). + // Remaining: user, assistant, tool, assistant, user, user, assistant, tool, assistant + if len(result) != 9 { + t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result)) + } + assertRoles(t, result, "user", "assistant", "tool", "assistant", "user", "user", "assistant", "tool", "assistant") +} diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 599ea57fc..97cf0fa05 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -37,6 +37,14 @@ type AgentInstance struct { Subagents *config.SubagentsConfig SkillsFilter []string Candidates []providers.FallbackCandidate + + // Router is non-nil when model routing is configured and the light model + // was successfully resolved. It scores each incoming message and decides + // whether to route to LightCandidates or stay with Candidates. + Router *routing.Router + // LightCandidates holds the resolved provider candidates for the light model. + // Pre-computed at agent creation to avoid repeated model_list lookups at runtime. + LightCandidates []providers.FallbackCandidate } // NewAgentInstance creates an agent instance from config. 
@@ -180,6 +188,25 @@ func NewAgentInstance( candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) + // Model routing setup: pre-resolve light model candidates at creation time + // to avoid repeated model_list lookups on every incoming message. + var router *routing.Router + var lightCandidates []providers.FallbackCandidate + if rc := defaults.Routing; rc != nil && rc.Enabled && rc.LightModel != "" { + lightModelCfg := providers.ModelConfig{Primary: rc.LightModel} + resolved := providers.ResolveCandidatesWithLookup(lightModelCfg, defaults.Provider, resolveFromModelList) + if len(resolved) > 0 { + router = routing.New(routing.RouterConfig{ + LightModel: rc.LightModel, + Threshold: rc.Threshold, + }) + lightCandidates = resolved + } else { + log.Printf("routing: light_model %q not found in model_list — routing disabled for agent %q", + rc.LightModel, agentID) + } + } + return &AgentInstance{ ID: agentID, Name: agentName, @@ -200,6 +227,8 @@ func NewAgentInstance( Subagents: subagents, SkillsFilter: skillsFilter, Candidates: candidates, + Router: router, + LightCandidates: lightCandidates, } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 685b346e6..9a54f5077 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" @@ -46,6 +47,7 @@ type AgentLoop struct { channelManager *channels.Manager mediaStore media.MediaStore transcriber voice.Transcriber + cmdRegistry *commands.Registry } // processOptions configures how a message is processed @@ -61,7 +63,15 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -const defaultResponse = "I've completed processing but have no response to give. 
Increase `max_tool_iterations` in config.json." +const ( + defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." + sessionKeyAgentPrefix = "agent:" + metadataKeyAccountID = "account_id" + metadataKeyGuildID = "guild_id" + metadataKeyTeamID = "team_id" + metadataKeyParentPeerKind = "parent_peer_kind" + metadataKeyParentPeerID = "parent_peer_id" +) func NewAgentLoop( cfg *config.Config, @@ -84,14 +94,17 @@ func NewAgentLoop( stateManager = state.NewManager(defaultAgent.Workspace) } - return &AgentLoop{ + al := &AgentLoop{ bus: msgBus, cfg: cfg, registry: registry, state: stateManager, summarizing: sync.Map{}, fallback: fallbackChain, + cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), } + + return al } // registerSharedTools registers tools that are shared across all agents (web, message, spawn). @@ -170,6 +183,17 @@ func registerSharedTools( agent.Tools.Register(messageTool) } + // Send file tool (outbound media via MediaStore — store injected later by SetMediaStore) + if cfg.Tools.IsToolEnabled("send_file") { + sendFileTool := tools.NewSendFileTool( + agent.Workspace, + cfg.Agents.Defaults.RestrictToWorkspace, + cfg.Agents.Defaults.GetMaxMediaSize(), + nil, + ) + agent.Tools.Register(sendFileTool) + } + // Skill discovery and installation tools skills_enabled := cfg.Tools.IsToolEnabled("skills") find_skills_enable := cfg.Tools.IsToolEnabled("find_skills") @@ -196,7 +220,7 @@ func registerSharedTools( // Spawn tool with allowlist checker if cfg.Tools.IsToolEnabled("spawn") { if cfg.Tools.IsToolEnabled("subagent") { - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace) subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) spawnTool := tools.NewSpawnTool(subagentManager) currentAgentID := agentID @@ -371,6 +395,13 @@ func (al *AgentLoop) 
SetChannelManager(cm *channels.Manager) { // SetMediaStore injects a MediaStore for media lifecycle management. func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s + + // Propagate store to send_file tools in all agents. + al.registry.ForEachTool("send_file", func(t tools.Tool) { + if sf, ok := t.(*tools.SendFileTool); ok { + sf.SetMediaStore(s) + } + }) } // SetTranscriber injects a voice transcriber for agent-level audio transcription. @@ -549,27 +580,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Check for commands - if response, handled := al.handleCommand(ctx, msg); handled { + route, agent, routeErr := al.resolveMessageRoute(msg) + + // Commands are checked before requiring a successful route. + // Global commands (/help, /show, /switch) work even when routing fails; + // context-dependent commands check their own Runtime fields and report + // "unavailable" when the required capability is nil. + if response, handled := al.handleCommand(ctx, msg, agent); handled { return response, nil } - // Route to determine agent and session key - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - AccountID: msg.Metadata["account_id"], - Peer: extractPeer(msg), - ParentPeer: extractParentPeer(msg), - GuildID: msg.Metadata["guild_id"], - TeamID: msg.Metadata["team_id"], - }) - - agent, ok := al.registry.GetAgent(route.AgentID) - if !ok { - agent = al.registry.GetDefaultAgent() - } - if agent == nil { - return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + if routeErr != nil { + return "", routeErr } // Reset message-tool state for this round so we don't skip publishing due to a previous round. 
@@ -579,17 +601,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } } - // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) - sessionKey := route.SessionKey - if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { - sessionKey = msg.SessionKey - } + // Resolve session key from route, while preserving explicit agent-scoped keys. + scopeKey := resolveScopeKey(route, msg.SessionKey) + sessionKey := scopeKey logger.InfoCF("agent", "Routed message", map[string]any{ - "agent_id": agent.ID, - "session_key": sessionKey, - "matched_by": route.MatchedBy, + "agent_id": agent.ID, + "scope_key": scopeKey, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + "route_agent": route.AgentID, + "route_channel": route.Channel, }) return al.runAgentLoop(ctx, agent, processOptions{ @@ -604,6 +627,34 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }) } +func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: inboundMetadata(msg, metadataKeyAccountID), + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: inboundMetadata(msg, metadataKeyGuildID), + TeamID: inboundMetadata(msg, metadataKeyTeamID), + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + if agent == nil { + return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID) + } + + return route, agent, nil +} + +func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { + if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { + return msgSessionKey + } + return route.SessionKey +} + func (al *AgentLoop) processSystemMessage( ctx context.Context, msg bus.InboundMessage, @@ 
-675,9 +726,8 @@ func (al *AgentLoop) runAgentLoop( agent *AgentInstance, opts processOptions, ) (string, error) { - // 0. Record last channel for heartbeat notifications (skip internal channels) + // 0. Record last channel for heartbeat notifications (skip internal channels and cli) if opts.Channel != "" && opts.ChatID != "" { - // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { @@ -824,6 +874,12 @@ func (al *AgentLoop) runLLMIteration( iteration := 0 var finalContent string + // Determine effective model tier for this conversation turn. + // selectCandidates evaluates routing once and the decision is sticky for + // all tool-follow-up iterations within the same turn so that a multi-step + // tool chain doesn't switch models mid-way through. + activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages) + for iteration < agent.MaxIterations { iteration++ @@ -842,7 +898,7 @@ func (al *AgentLoop) runLLMIteration( map[string]any{ "agent_id": agent.ID, "iteration": iteration, - "model": agent.Model, + "model": activeModel, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": agent.MaxTokens, @@ -858,7 +914,7 @@ func (al *AgentLoop) runLLMIteration( "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM with fallback chain if candidates are configured. + // Call LLM with fallback chain if multiple candidates are configured. 
var response *providers.LLMResponse var err error @@ -879,10 +935,10 @@ func (al *AgentLoop) runLLMIteration( } callLLM := func() (*providers.LLMResponse, error) { - if len(agent.Candidates) > 1 && al.fallback != nil { + if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, - agent.Candidates, + activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) }, @@ -900,7 +956,7 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, llmOpts) + return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) } // Retry loop for context/token errors @@ -999,9 +1055,12 @@ func (al *AgentLoop) runLLMIteration( "target_channel": al.targetReasoningChannelID(opts.Channel), "channel": opts.Channel, }) - // Check if no tool calls - we're done + // Check if no tool calls - then check reasoning content if any if len(response.ToolCalls) == 0 { finalContent = response.Content + if finalContent == "" && response.ReasoningContent != "" { + finalContent = response.ReasoningContent + } logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]any{ "agent_id": agent.ID, @@ -1087,15 +1146,47 @@ func (al *AgentLoop) runLLMIteration( "iteration": iteration, }) - // Create async callback for tools that implement AsyncExecutor - asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + // Create async callback for tools that implement AsyncExecutor. + // When the background work completes, this publishes the result + // as an inbound system message so processSystemMessage routes it + // back to the user via the normal agent loop. 
+ asyncCallback := func(_ context.Context, result *tools.ToolResult) { + // Send ForUser content directly to the user (immediate feedback), + // mirroring the synchronous tool execution path. if !result.Silent && result.ForUser != "" { - logger.InfoCF("agent", "Async tool completed, agent will handle notification", - map[string]any{ - "tool": tc.Name, - "content_len": len(result.ForUser), - }) + outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer outCancel() + _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: result.ForUser, + }) } + + // Determine content for the agent loop (ForLLM or error). + content := result.ForLLM + if content == "" && result.Err != nil { + content = result.Err.Error() + } + if content == "" { + return + } + + logger.InfoCF("agent", "Async tool completed, publishing result", + map[string]any{ + "tool": tc.Name, + "content_len": len(content), + "channel": opts.Channel, + }) + + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ + Channel: "system", + SenderID: fmt.Sprintf("async:%s", tc.Name), + ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID), + Content: content, + }) } toolResult := agent.Tools.ExecuteWithContext( @@ -1128,7 +1219,7 @@ func (al *AgentLoop) runLLMIteration( } // If tool returned media refs, publish them as outbound media - if len(r.result.Media) > 0 && opts.SendResponse { + if len(r.result.Media) > 0 { parts := make([]bus.MediaPart, 0, len(r.result.Media)) for _, ref := range r.result.Media { part := bus.MediaPart{Ref: ref} @@ -1169,6 +1260,44 @@ func (al *AgentLoop) runLLMIteration( return finalContent, iteration, nil } +// selectCandidates returns the model candidates and resolved model name to use +// for a conversation turn. 
When model routing is configured and the incoming +// message scores below the complexity threshold, it returns the light model +// candidates instead of the primary ones. +// +// The returned (candidates, model) pair is used for all LLM calls within one +// turn — tool follow-up iterations use the same tier as the initial call so +// that a multi-step tool chain doesn't switch models mid-way. +func (al *AgentLoop) selectCandidates( + agent *AgentInstance, + userMsg string, + history []providers.Message, +) (candidates []providers.FallbackCandidate, model string) { + if agent.Router == nil || len(agent.LightCandidates) == 0 { + return agent.Candidates, agent.Model + } + + _, usedLight, score := agent.Router.SelectModel(userMsg, history, agent.Model) + if !usedLight { + logger.DebugCF("agent", "Model routing: primary model selected", + map[string]any{ + "agent_id": agent.ID, + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.Candidates, agent.Model + } + + logger.InfoCF("agent", "Model routing: light model selected", + map[string]any{ + "agent_id": agent.ID, + "light_model": agent.Router.LightModel(), + "score": score, + "threshold": agent.Router.Threshold(), + }) + return agent.LightCandidates, agent.Router.LightModel() +} + // maybeSummarize triggers summarization if the session history exceeds thresholds. 
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) @@ -1460,94 +1589,87 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int { return totalChars * 2 / 5 } -func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { - content := strings.TrimSpace(msg.Content) - if !strings.HasPrefix(content, "/") { +func (al *AgentLoop) handleCommand( + ctx context.Context, + msg bus.InboundMessage, + agent *AgentInstance, +) (string, bool) { + if !commands.HasCommandPrefix(msg.Content) { return "", false } - parts := strings.Fields(content) - if len(parts) == 0 { + if al.cmdRegistry == nil { return "", false } - cmd := parts[0] - args := parts[1:] + rt := al.buildCommandsRuntime(agent) + executor := commands.NewExecutor(al.cmdRegistry, rt) - switch cmd { - case "/show": - if len(args) < 1 { - return "Usage: /show [model|channel|agents]", true - } - switch args[0] { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - return fmt.Sprintf("Current model: %s", defaultAgent.Model), true - case "channel": - return fmt.Sprintf("Current channel: %s", msg.Channel), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown show target: %s", args[0]), true - } + var commandReply string + result := executor.Execute(ctx, commands.Request{ + Channel: msg.Channel, + ChatID: msg.ChatID, + SenderID: msg.SenderID, + Text: msg.Content, + Reply: func(text string) error { + commandReply = text + return nil + }, + }) - case "/list": - if len(args) < 1 { - return "Usage: /list [models|channels|agents]", true + switch result.Outcome { + case commands.OutcomeHandled: + if result.Err != nil { + return mapCommandError(result), true } - switch 
args[0] { - case "models": - return "Available models: configured in config.json per agent", true - case "channels": + if commandReply != "" { + return commandReply, true + } + return "", true + default: // OutcomePassthrough — let the message fall through to LLM + return "", false + } +} + +func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance) *commands.Runtime { + rt := &commands.Runtime{ + Config: al.cfg, + ListAgentIDs: al.registry.ListAgentIDs, + ListDefinitions: al.cmdRegistry.Definitions, + GetEnabledChannels: func() []string { if al.channelManager == nil { - return "Channel manager not initialized", true + return nil } - channels := al.channelManager.GetEnabledChannels() - if len(channels) == 0 { - return "No channels enabled", true - } - return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true - case "agents": - agentIDs := al.registry.ListAgentIDs() - return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true - default: - return fmt.Sprintf("Unknown list target: %s", args[0]), true - } - - case "/switch": - if len(args) < 3 || args[1] != "to" { - return "Usage: /switch [model|channel] to ", true - } - target := args[0] - value := args[2] - - switch target { - case "model": - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent == nil { - return "No default agent configured", true - } - oldModel := defaultAgent.Model - defaultAgent.Model = value - return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true - case "channel": + return al.channelManager.GetEnabledChannels() + }, + SwitchChannel: func(value string) error { if al.channelManager == nil { - return "Channel manager not initialized", true + return fmt.Errorf("channel manager not initialized") } if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { - return fmt.Sprintf("Channel '%s' not found or not enabled", value), true + return fmt.Errorf("channel '%s' not found or not enabled", value) } - 
return fmt.Sprintf("Switched target channel to %s", value), true - default: - return fmt.Sprintf("Unknown switch target: %s", target), true + return nil + }, + } + if agent != nil { + rt.GetModelInfo = func() (string, string) { + return agent.Model, al.cfg.Agents.Defaults.Provider + } + rt.SwitchModel = func(value string) (string, error) { + oldModel := agent.Model + agent.Model = value + return oldModel, nil } } + return rt +} - return "", false +func mapCommandError(result commands.ExecuteResult) string { + if result.Command == "" { + return fmt.Sprintf("Failed to execute command: %v", result.Err) + } + return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err) } // extractPeer extracts the routing peer from the inbound message's structured Peer field. @@ -1566,10 +1688,17 @@ func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } +func inboundMetadata(msg bus.InboundMessage, key string) string { + if msg.Metadata == nil { + return "" + } + return msg.Metadata[key] +} + // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. 
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { - parentKind := msg.Metadata["parent_peer_kind"] - parentID := msg.Metadata["parent_peer_id"] + parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) + parentID := inboundMetadata(msg, metadataKeyParentPeerID) if parentKind == "" || parentID == "" { return nil } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index aa7d59b5a..2e456fa60 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -318,6 +319,29 @@ func (m *simpleMockProvider) GetDefaultModel() string { return "mock-model" } +type countingMockProvider struct { + response string + calls int +} + +func (m *countingMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *countingMockProvider) GetDefaultModel() string { + return "counting-mock-model" +} + // mockCustomTool is a simple mock tool for registration testing type mockCustomTool struct{} @@ -359,6 +383,198 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms const responseTimeout = 3 * time.Second +func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + 
provider := &simpleMockProvider{response: "ok"} + al := NewAgentLoop(cfg, msgBus, provider) + + msg := bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + Peer: extractPeer(msg), + }) + sessionKey := route.SessionKey + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + + helper := testHelper{al: al} + _ = helper.executeAndGetResponse(t, context.Background(), msg) + + history := defaultAgent.Sessions.GetHistory(sessionKey) + if len(history) != 2 { + t.Fatalf("expected session history len=2, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Fatalf("unexpected first message in session: %+v", history[0]) + } +} + +func TestProcessMessage_CommandOutcomes(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Session: config.SessionConfig{ + DMScope: "per-channel-peer", + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + baseMsg := bus.InboundMessage{ + Channel: "whatsapp", + SenderID: "user1", + ChatID: "chat1", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/show channel", + Peer: baseMsg.Peer, + }) + if showResp != "Current Channel: whatsapp" { + 
t.Fatalf("unexpected /show reply: %q", showResp) + } + if provider.calls != 0 { + t.Fatalf("LLM should not be called for handled command, calls=%d", provider.calls) + } + + fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/foo", + Peer: baseMsg.Peer, + }) + if fooResp != "LLM reply" { + t.Fatalf("unexpected /foo reply: %q", fooResp) + } + if provider.calls != 1 { + t.Fatalf("LLM should be called exactly once after /foo passthrough, calls=%d", provider.calls) + } + + newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: baseMsg.Channel, + SenderID: baseMsg.SenderID, + ChatID: baseMsg.ChatID, + Content: "/new", + Peer: baseMsg.Peer, + }) + if newResp != "LLM reply" { + t.Fatalf("unexpected /new reply: %q", newResp) + } + if provider.calls != 2 { + t.Fatalf("LLM should be called for passthrough /new command, calls=%d", provider.calls) + } +} + +func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Provider: "openai", + Model: "before-switch", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &countingMockProvider{response: "LLM reply"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + switchResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/switch model to after-switch", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(switchResp, "Switched model from before-switch to after-switch") { + 
t.Fatalf("unexpected /switch reply: %q", switchResp) + } + + showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ + Channel: "telegram", + SenderID: "user1", + ChatID: "chat1", + Content: "/show model", + Peer: bus.Peer{ + Kind: "direct", + ID: "user1", + }, + }) + if !strings.Contains(showResp, "Current Model: after-switch (Provider: openai)") { + t.Fatalf("unexpected /show model reply after switch: %q", showResp) + } + + if provider.calls != 0 { + t.Fatalf("LLM should not be called for /switch and /show, calls=%d", provider.calls) + } +} + // TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go index 77b846832..0e7973dc3 100644 --- a/pkg/agent/registry.go +++ b/pkg/agent/registry.go @@ -7,6 +7,7 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/tools" ) // AgentRegistry manages multiple agent instances and routes messages to them. @@ -100,6 +101,19 @@ func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bo return false } +// ForEachTool calls fn for every tool registered under the given name +// across all agents. This is useful for propagating dependencies (e.g. +// MediaStore) to tools after registry construction. +func (r *AgentRegistry) ForEachTool(name string, fn func(tools.Tool)) { + r.mu.RLock() + defer r.mu.RUnlock() + for _, agent := range r.agents { + if t, ok := agent.Tools.Get(name); ok { + fn(t) + } + } +} + // GetDefaultAgent returns the default agent instance. 
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance { r.mu.RLock() diff --git a/pkg/auth/anthropic_usage.go b/pkg/auth/anthropic_usage.go new file mode 100644 index 000000000..716b2908e --- /dev/null +++ b/pkg/auth/anthropic_usage.go @@ -0,0 +1,71 @@ +package auth + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +const ( + anthropicBetaHeader = "oauth-2025-04-20" + anthropicAPIVersion = "2023-06-01" +) + +// anthropicUsageURL is the endpoint for fetching OAuth usage stats. +// It is a var (not const) to allow overriding in tests. +var anthropicUsageURL = "https://api.anthropic.com/api/oauth/usage" + +func setAnthropicUsageURL(url string) { anthropicUsageURL = url } + +type AnthropicUsage struct { + FiveHourUtilization float64 + SevenDayUtilization float64 +} + +func FetchAnthropicUsage(token string) (*AnthropicUsage, error) { + req, err := http.NewRequest("GET", anthropicUsageURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Anthropic-Version", anthropicAPIVersion) + req.Header.Set("Anthropic-Beta", anthropicBetaHeader) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading usage response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusForbidden { + return nil, fmt.Errorf("insufficient scope: usage endpoint requires oauth scope") + } + return nil, fmt.Errorf("usage request failed (%d): %s", resp.StatusCode, string(body)) + } + + var result struct { + FiveHour struct { + Utilization float64 `json:"utilization"` + } `json:"five_hour"` + SevenDay struct { + Utilization float64 `json:"utilization"` + } `json:"seven_day"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parsing usage response: %w", err) + } + + 
return &AnthropicUsage{ + FiveHourUtilization: result.FiveHour.Utilization, + SevenDayUtilization: result.SevenDay.Utilization, + }, nil +} diff --git a/pkg/auth/anthropic_usage_test.go b/pkg/auth/anthropic_usage_test.go new file mode 100644 index 000000000..ef4a35364 --- /dev/null +++ b/pkg/auth/anthropic_usage_test.go @@ -0,0 +1,98 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestFetchAnthropicUsage_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer test-token") + } + if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader { + t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"five_hour":{"utilization":0.42},"seven_day":{"utilization":0.85}}`)) + })) + defer srv.Close() + + // Temporarily override the URL by using the test server + origURL := anthropicUsageURL + defer func() { setAnthropicUsageURL(origURL) }() + setAnthropicUsageURL(srv.URL) + + usage, err := FetchAnthropicUsage("test-token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if usage.FiveHourUtilization != 0.42 { + t.Errorf("FiveHourUtilization = %v, want 0.42", usage.FiveHourUtilization) + } + if usage.SevenDayUtilization != 0.85 { + t.Errorf("SevenDayUtilization = %v, want 0.85", usage.SevenDayUtilization) + } +} + +func TestFetchAnthropicUsage_Forbidden(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer srv.Close() + + origURL := anthropicUsageURL + defer func() { setAnthropicUsageURL(origURL) }() + setAnthropicUsageURL(srv.URL) + + _, err := FetchAnthropicUsage("test-token") + if err == nil { + 
t.Fatal("expected error for 403, got nil") + } + if !strings.Contains(err.Error(), "insufficient scope") { + t.Errorf("expected 'insufficient scope' error, got %q", err.Error()) + } +} + +func TestFetchAnthropicUsage_ServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`internal error`)) + })) + defer srv.Close() + + origURL := anthropicUsageURL + defer func() { setAnthropicUsageURL(origURL) }() + setAnthropicUsageURL(srv.URL) + + _, err := FetchAnthropicUsage("test-token") + if err == nil { + t.Fatal("expected error for 500, got nil") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error containing '500', got %q", err.Error()) + } +} + +func TestFetchAnthropicUsage_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + origURL := anthropicUsageURL + defer func() { setAnthropicUsageURL(origURL) }() + setAnthropicUsageURL(srv.URL) + + _, err := FetchAnthropicUsage("test-token") + if err == nil { + t.Fatal("expected error for malformed JSON, got nil") + } + if !strings.Contains(err.Error(), "parsing usage response") { + t.Errorf("expected 'parsing usage response' error, got %q", err.Error()) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go index a5a13ff03..0e69e60ac 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -31,6 +31,35 @@ func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { }, nil } +func LoginSetupToken(r io.Reader) (*AuthCredential, error) { + fmt.Println("Paste your setup token from `claude setup-token`:") + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input 
received") + } + + token := strings.TrimSpace(scanner.Text()) + + if !strings.HasPrefix(token, "sk-ant-oat01-") { + return nil, fmt.Errorf("invalid setup token: expected prefix sk-ant-oat01-") + } + + if len(token) < 80 { + return nil, fmt.Errorf("invalid setup token: too short (expected at least 80 characters)") + } + + return &AuthCredential{ + AccessToken: token, + Provider: "anthropic", + AuthMethod: "oauth", + }, nil +} + func providerDisplayName(provider string) string { switch provider { case "anthropic": diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go new file mode 100644 index 000000000..673cd9d5d --- /dev/null +++ b/pkg/auth/token_test.go @@ -0,0 +1,61 @@ +package auth + +import ( + "strings" + "testing" +) + +func TestLoginSetupToken(t *testing.T) { + // A valid token: correct prefix + at least 80 chars + validToken := "sk-ant-oat01-" + strings.Repeat("a", 80) + + tests := []struct { + name string + input string + wantErr string + }{ + {"valid token", validToken, ""}, + {"empty input", "", "expected prefix sk-ant-oat01-"}, + {"wrong prefix", "sk-ant-api-" + strings.Repeat("a", 80), "expected prefix sk-ant-oat01-"}, + {"too short", "sk-ant-oat01-short", "too short"}, + {"whitespace only", " ", "expected prefix sk-ant-oat01-"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input + "\n") + cred, err := LoginSetupToken(r) + + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cred.AccessToken != validToken { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, validToken) + } + if cred.Provider != "anthropic" { + t.Errorf("Provider = %q, want %q", cred.Provider, "anthropic") + } + if cred.AuthMethod != "oauth" { + 
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + }) + } +} + +func TestLoginSetupToken_EmptyReader(t *testing.T) { + r := strings.NewReader("") + _, err := LoginSetupToken(r) + if err == nil { + t.Fatal("expected error for empty reader, got nil") + } +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 1de910c83..c3bcbff8d 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "os" + "regexp" "strings" "sync" "time" @@ -26,6 +27,12 @@ const ( sendTimeout = 10 * time.Second ) +var ( + // Pre-compiled regexes for resolveDiscordRefs (avoid re-compiling per call) + channelRefRe = regexp.MustCompile(`<#(\d+)>`) + msgLinkRe = regexp.MustCompile(`https://(?:discord\.com|discordapp\.com)/channels/(\d+)/(\d+)/(\d+)`) +) + type DiscordChannel struct { *channels.BaseChannel session *discordgo.Session @@ -338,6 +345,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = c.stripBotMention(content) } + // Resolve Discord refs in main content before concatenation to avoid + // double-expanding links that appear in the referenced message. 
+ content = c.resolveDiscordRefs(s, content, m.GuildID) + + // Prepend referenced (quoted) message content if this is a reply + if m.MessageReference != nil && m.ReferencedMessage != nil { + refContent := m.ReferencedMessage.Content + if refContent != "" { + refAuthor := "unknown" + if m.ReferencedMessage.Author != nil { + refAuthor = m.ReferencedMessage.Author.Username + } + refContent = c.resolveDiscordRefs(s, refContent, m.GuildID) + content = fmt.Sprintf("[quoted message from %s]: %s\n\n%s", + refAuthor, refContent, content) + } + } + senderID := m.Author.ID mediaPaths := make([]string, 0, len(m.Attachments)) @@ -508,6 +533,51 @@ func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { return nil } +// resolveDiscordRefs resolves channel references (<#id> → #channel-name) and +// expands Discord message links to show the linked message content. +// Only links pointing to the same guild are expanded to prevent cross-guild leakage. +func (c *DiscordChannel) resolveDiscordRefs(s *discordgo.Session, text string, guildID string) string { + // 1. Resolve channel references: <#id> → #channel-name + text = channelRefRe.ReplaceAllStringFunc(text, func(match string) string { + parts := channelRefRe.FindStringSubmatch(match) + if len(parts) < 2 { + return match + } + // Prefer session state cache to avoid API calls + if ch, err := s.State.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + if ch, err := s.Channel(parts[1]); err == nil { + return "#" + ch.Name + } + return match + }) + + // 2. 
Expand Discord message links (max 3, same guild only) + matches := msgLinkRe.FindAllStringSubmatch(text, 3) + for _, m := range matches { + if len(m) < 4 { + continue + } + linkGuildID, channelID, messageID := m[1], m[2], m[3] + // Security: only expand links from the same guild + if linkGuildID != guildID { + continue + } + msg, err := s.ChannelMessage(channelID, messageID) + if err != nil || msg == nil || msg.Content == "" { + continue + } + author := "unknown" + if msg.Author != nil { + author = msg.Author.Username + } + text += fmt.Sprintf("\n[linked message from %s]: %s", author, msg.Content) + } + + return text +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_resolve_test.go b/pkg/channels/discord/discord_resolve_test.go new file mode 100644 index 000000000..4bc65cc18 --- /dev/null +++ b/pkg/channels/discord/discord_resolve_test.go @@ -0,0 +1,98 @@ +package discord + +import ( + "testing" +) + +func TestChannelRefRegex(t *testing.T) { + tests := []struct { + name string + input string + wantID string + wantOK bool + }{ + {"basic channel ref", "<#123456789>", "123456789", true}, + {"long id", "<#9876543210123456>", "9876543210123456", true}, + {"no match plain text", "hello world", "", false}, + {"no match partial", "<#>", "", false}, + {"no match letters", "<#abc>", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := channelRefRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 2 || matches[1] != tt.wantID { + t.Errorf("channelRefRe(%q) = %v, want ID %q", tt.input, matches, tt.wantID) + } + } else { + if len(matches) >= 2 { + t.Errorf("channelRefRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex(t *testing.T) { + tests := []struct { + name string 
+ input string + wantGuild string + wantChan string + wantMsg string + wantOK bool + }{ + { + "discord.com link", + "https://discord.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "discordapp.com link", + "https://discordapp.com/channels/111/222/333", + "111", "222", "333", true, + }, + { + "real world ids", + "check this https://discord.com/channels/9000000000000001/9000000000000002/9000000000000003 please", + "9000000000000001", "9000000000000002", "9000000000000003", true, + }, + {"no match http", "http://discord.com/channels/1/2/3", "", "", "", false}, + {"no match missing segment", "https://discord.com/channels/1/2", "", "", "", false}, + {"no match plain text", "hello world", "", "", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := msgLinkRe.FindStringSubmatch(tt.input) + if tt.wantOK { + if len(matches) < 4 { + t.Fatalf("msgLinkRe(%q) didn't match, want guild=%s chan=%s msg=%s", + tt.input, tt.wantGuild, tt.wantChan, tt.wantMsg) + } + if matches[1] != tt.wantGuild || matches[2] != tt.wantChan || matches[3] != tt.wantMsg { + t.Errorf("msgLinkRe(%q) = guild=%s chan=%s msg=%s, want %s/%s/%s", + tt.input, matches[1], matches[2], matches[3], + tt.wantGuild, tt.wantChan, tt.wantMsg) + } + } else { + if len(matches) >= 4 { + t.Errorf("msgLinkRe(%q) should not match, got %v", tt.input, matches) + } + } + }) + } +} + +func TestMsgLinkRegex_MultipleMatches(t *testing.T) { + input := "see https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6 and https://discord.com/channels/7/8/9 and https://discord.com/channels/10/11/12" + matches := msgLinkRe.FindAllStringSubmatch(input, 3) + if len(matches) != 3 { + t.Fatalf("expected 3 matches (capped), got %d", len(matches)) + } + // Verify the 3rd match is 7/8/9 (not 10/11/12) + if matches[2][1] != "7" || matches[2][2] != "8" || matches[2][3] != "9" { + t.Errorf("3rd match = %v, want guild=7 chan=8 msg=9", matches[2]) + } +} diff --git 
// CommandRegistrarCapable is implemented by channels that can register
// command menus with their upstream platform (e.g. Telegram BotCommand).
// Channels that do not support platform-level command menus can ignore it.
type CommandRegistrarCapable interface {
	// RegisterCommands submits defs to the platform as the channel's
	// command menu. A non-nil error means registration did not succeed;
	// the Telegram channel in this package tree reacts to that by retrying
	// with backoff until success or shutdown.
	RegisterCommands(ctx context.Context, defs []commands.Definition) error
}
successful connection (and on reconnect). +func (c *IRCChannel) onConnect(conn *ircevent.Connection) { + // NickServ auth (only if SASL is not configured) + if c.config.NickServPassword != "" && c.config.SASLUser == "" { + conn.Privmsg("NickServ", "IDENTIFY "+c.config.NickServPassword) + } + + // Join configured channels + for _, ch := range c.config.Channels { + conn.Join(ch) + logger.InfoCF("irc", "Joined IRC channel", map[string]any{ + "channel": ch, + }) + } +} + +// onPrivmsg handles incoming PRIVMSG events. +func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { + if len(e.Params) < 2 { + return + } + + nick := e.Nick() + currentNick := conn.CurrentNick() + + // Ignore own messages + if strings.EqualFold(nick, currentNick) { + return + } + + target := e.Params[0] // channel name or bot's nick + content := e.Params[1] // message text + + // Determine if this is a DM or channel message + isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&") + + var chatID string + var peer bus.Peer + + if isDM { + chatID = nick + peer = bus.Peer{Kind: "direct", ID: nick} + } else { + chatID = target + peer = bus.Peer{Kind: "group", ID: target} + } + + sender := bus.SenderInfo{ + Platform: "irc", + PlatformID: nick, + CanonicalID: identity.BuildCanonicalID("irc", nick), + Username: nick, + DisplayName: nick, + } + + if !c.IsAllowedSender(sender) { + return + } + + // For channel messages, check group trigger (mention detection) + if !isDM { + isMentioned := isBotMentioned(content, currentNick) + if isMentioned { + content = stripBotMention(content, currentNick) + } + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + return + } + content = cleaned + } + + if strings.TrimSpace(content) == "" { + return + } + + messageID := fmt.Sprintf("%s-%d", nick, time.Now().UnixNano()) + + metadata := map[string]string{ + "platform": "irc", + "server": c.config.Server, + } + if !isDM { + metadata["channel"] = target + 
// nickMentionedAt returns the byte index (within the lowercased content) of
// the first place botNick appears as a whole word, or -1 if it is not
// mentioned. Matching is case-insensitive. A match counts only when the
// characters immediately before and after it are neither letters nor digits,
// so "bot" is found in "hey bot" and "bot: hi" but not in "robotics".
//
// Every occurrence is considered, not just the first, so "robot bot" still
// reports the stand-alone "bot". All positions are handled as byte offsets
// and runes are decoded from the string itself, so non-ASCII text before or
// after the nick cannot mis-index or panic. An empty botNick is never
// considered mentioned.
func nickMentionedAt(content, botNick string) int {
	lower := strings.ToLower(content)
	lowerNick := strings.ToLower(botNick)
	if lowerNick == "" {
		return -1
	}

	// "nick:" or "nick," at start (most common IRC convention)
	if strings.HasPrefix(lower, lowerNick+":") || strings.HasPrefix(lower, lowerNick+",") {
		return 0
	}

	isWordRune := func(r rune) bool {
		return unicode.IsLetter(r) || unicode.IsDigit(r)
	}
	// startsWithWordRune reports whether the first rune of s is a letter or
	// digit; an empty s (end of message) counts as a boundary.
	startsWithWordRune := func(s string) bool {
		for _, r := range s {
			return isWordRune(r)
		}
		return false
	}

	// Walk rune by rune so i is always a valid rune start and prevIsWord
	// always describes the rune directly before position i.
	prevIsWord := false
	for i, r := range lower {
		if !prevIsWord && strings.HasPrefix(lower[i:], lowerNick) &&
			!startsWithWordRune(lower[i+len(lowerNick):]) {
			return i
		}
		prevIsWord = isWordRune(r)
	}
	return -1
}
+func stripBotMention(content, botNick string) string { + idx := nickMentionedAt(content, botNick) + if idx != 0 { + return content + } + lowerNick := strings.ToLower(botNick) + lower := strings.ToLower(content) + for _, sep := range []string{":", ","} { + prefix := lowerNick + sep + if strings.HasPrefix(lower, prefix) { + return strings.TrimSpace(content[len(prefix):]) + } + } + return content +} diff --git a/pkg/channels/irc/init.go b/pkg/channels/irc/init.go new file mode 100644 index 000000000..221d41b62 --- /dev/null +++ b/pkg/channels/irc/init.go @@ -0,0 +1,16 @@ +package irc + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("irc", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + if !cfg.Channels.IRC.Enabled { + return nil, nil + } + return NewIRCChannel(cfg.Channels.IRC, b) + }) +} diff --git a/pkg/channels/irc/irc.go b/pkg/channels/irc/irc.go new file mode 100644 index 000000000..28c59b540 --- /dev/null +++ b/pkg/channels/irc/irc.go @@ -0,0 +1,194 @@ +package irc + +import ( + "context" + "crypto/tls" + "fmt" + "strings" + + "github.com/ergochat/irc-go/ircevent" + "github.com/ergochat/irc-go/ircmsg" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// IRCChannel implements the Channel interface for IRC servers. +type IRCChannel struct { + *channels.BaseChannel + config config.IRCConfig + conn *ircevent.Connection + ctx context.Context + cancel context.CancelFunc +} + +// NewIRCChannel creates a new IRC channel. 
+func NewIRCChannel(cfg config.IRCConfig, messageBus *bus.MessageBus) (*IRCChannel, error) { + if cfg.Server == "" { + return nil, fmt.Errorf("irc server is required") + } + if cfg.Nick == "" { + return nil, fmt.Errorf("irc nick is required") + } + + base := channels.NewBaseChannel("irc", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(400), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) + + return &IRCChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// Start connects to the IRC server and begins listening. +func (c *IRCChannel) Start(ctx context.Context) error { + logger.InfoC("irc", "Starting IRC channel") + c.ctx, c.cancel = context.WithCancel(ctx) + + user := c.config.User + if user == "" { + user = c.config.Nick + } + realName := c.config.RealName + if realName == "" { + realName = c.config.Nick + } + caps := []string(c.config.RequestCaps) + if len(caps) == 0 { + caps = []string{"server-time", "message-tags"} + } + + conn := &ircevent.Connection{ + Server: c.config.Server, + Nick: c.config.Nick, + User: user, + RealName: realName, + Password: c.config.Password, + UseTLS: c.config.TLS, + RequestCaps: caps, + QuitMessage: "Goodbye", + Debug: false, + Log: nil, + } + + if c.config.TLS { + conn.TLSConfig = &tls.Config{ + ServerName: extractHost(c.config.Server), + } + } + + // SASL auth (takes priority over NickServ) + if c.config.SASLUser != "" && c.config.SASLPassword != "" { + conn.SASLLogin = c.config.SASLUser + conn.SASLPassword = c.config.SASLPassword + } + + // Register event handlers + conn.AddConnectCallback(func(e ircmsg.Message) { + c.onConnect(conn) + }) + conn.AddCallback("PRIVMSG", func(e ircmsg.Message) { + c.onPrivmsg(conn, e) + }) + + if err := conn.Connect(); err != nil { + return fmt.Errorf("irc connect failed: %w", err) + } + + c.conn = conn + + // ircevent.Connection.Loop() handles reconnection internally. 
+ go conn.Loop() + + c.SetRunning(true) + logger.InfoCF("irc", "IRC channel started", map[string]any{ + "server": c.config.Server, + "nick": c.config.Nick, + }) + return nil +} + +// Stop disconnects from the IRC server. +func (c *IRCChannel) Stop(ctx context.Context) error { + logger.InfoC("irc", "Stopping IRC channel") + c.SetRunning(false) + + if c.conn != nil { + c.conn.Quit() + } + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("irc", "IRC channel stopped") + return nil +} + +// Send sends a message to an IRC channel or user. +func (c *IRCChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + target := msg.ChatID + if target == "" { + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) + } + + if strings.TrimSpace(msg.Content) == "" { + return nil + } + + // Send each line separately (IRC is line-oriented) + lines := strings.Split(msg.Content, "\n") + for _, line := range lines { + line = strings.TrimRight(line, "\r") + if line == "" { + continue + } + c.conn.Privmsg(target, line) + } + + logger.DebugCF("irc", "Message sent", map[string]any{ + "target": target, + "lines": len(lines), + }) + return nil +} + +// StartTyping implements channels.TypingCapable using IRCv3 +typing client tag. +// Requires typing.enabled in config and server support for message-tags capability. 
// extractHost returns the hostname portion of a host:port string.
//
// Accepted forms:
//   - "host:port" and "host"      → "host"
//   - "[ipv6]:port" and "[ipv6]"  → "ipv6" (brackets removed)
//   - bare IPv6 such as "::1"     → returned unchanged (no port to strip)
//
// The result is used as the TLS server name, so an IPv6 literal must come
// back whole and without brackets rather than cut at its first colon.
// Malformed bracketed input (no closing "]") is returned unchanged.
func extractHost(server string) string {
	// Bracketed IPv6 literal, with or without a port.
	if strings.HasPrefix(server, "[") {
		if end := strings.Index(server, "]"); end > 0 {
			return server[1:end]
		}
		return server
	}

	// More than one colon and no brackets: a bare IPv6 address. There is no
	// unambiguous port separator, so keep the whole string.
	if strings.Count(server, ":") > 1 {
		return server
	}

	host, _, _ := strings.Cut(server, ":")
	return host
}
channel should not be running") + } + }) +} + +func TestExtractHost(t *testing.T) { + tests := []struct { + server string + want string + }{ + {"irc.libera.chat:6697", "irc.libera.chat"}, + {"localhost:6667", "localhost"}, + {"irc.example.com", "irc.example.com"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.server, func(t *testing.T) { + got := extractHost(tt.server) + if got != tt.want { + t.Errorf("extractHost(%q) = %q, want %q", tt.server, got, tt.want) + } + }) + } +} + +func TestNickMentionedAt(t *testing.T) { + tests := []struct { + name string + content string + nick string + want int + }{ + {"colon prefix", "bot: hello", "bot", 0}, + {"comma prefix", "bot, hello", "bot", 0}, + {"case insensitive", "BOT: hello", "bot", 0}, + {"word boundary mid", "hey bot what's up", "bot", 4}, + {"no mention", "hello world", "bot", -1}, + {"substring mismatch", "robotics are cool", "bot", -1}, + {"nick at end", "hello bot", "bot", 6}, + {"empty content", "", "bot", -1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := nickMentionedAt(tt.content, tt.nick) + if got != tt.want { + t.Errorf("nickMentionedAt(%q, %q) = %d, want %d", tt.content, tt.nick, got, tt.want) + } + }) + } +} + +func TestIsBotMentioned(t *testing.T) { + tests := []struct { + name string + content string + nick string + want bool + }{ + {"colon prefix", "bot: hello", "bot", true}, + {"comma prefix", "bot, hello", "bot", true}, + {"case insensitive", "BOT: hello", "bot", true}, + {"word boundary mid", "hey bot what's up", "bot", true}, + {"no mention", "hello world", "bot", false}, + {"substring mismatch", "robotics are cool", "bot", false}, + {"nick at end", "hello bot", "bot", true}, + {"empty content", "", "bot", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isBotMentioned(tt.content, tt.nick) + if got != tt.want { + t.Errorf("isBotMentioned(%q, %q) = %v, want %v", tt.content, tt.nick, got, tt.want) + } + }) + } +} 
+ +func TestStripBotMention(t *testing.T) { + tests := []struct { + name string + content string + nick string + want string + }{ + {"colon prefix", "bot: hello there", "bot", "hello there"}, + {"comma prefix", "bot, help me", "bot", "help me"}, + {"case insensitive", "BOT: hello", "bot", "hello"}, + {"no prefix match", "hello bot", "bot", "hello bot"}, + {"only prefix", "bot:", "bot", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBotMention(tt.content, tt.nick) + if got != tt.want { + t.Errorf("stripBotMention(%q, %q) = %q, want %q", tt.content, tt.nick, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index fdd6d0c1f..2b1cf8e84 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -62,6 +62,7 @@ var channelRateConfig = map[string]float64{ "discord": 1, "slack": 1, "line": 10, + "irc": 2, } type channelWorker struct { @@ -267,6 +268,10 @@ func (m *Manager) initChannels() error { m.initChannel("pico", "Pico") } + if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" { + m.initChannel("irc", "IRC") + } + logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) diff --git a/pkg/channels/telegram/command_registration.go b/pkg/channels/telegram/command_registration.go new file mode 100644 index 000000000..d3152ec3d --- /dev/null +++ b/pkg/channels/telegram/command_registration.go @@ -0,0 +1,116 @@ +package telegram + +import ( + "context" + "math/rand" + "slices" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/commands" + "github.com/sipeed/picoclaw/pkg/logger" +) + +var commandRegistrationBackoff = []time.Duration{ + 5 * time.Second, + 15 * time.Second, + 60 * time.Second, + 5 * time.Minute, + 10 * time.Minute, +} + +func commandRegistrationDelay(attempt int) time.Duration { + if len(commandRegistrationBackoff) == 0 { + return 0 + } + base := 
commandRegistrationBackoff[min(attempt, len(commandRegistrationBackoff)-1)] + // Full jitter in [0.5, 1.0) to avoid synchronized retries across instances. + return time.Duration(float64(base) * (0.5 + rand.Float64()*0.5)) +} + +// RegisterCommands registers bot commands on Telegram platform. +func (c *TelegramChannel) RegisterCommands(ctx context.Context, defs []commands.Definition) error { + botCommands := make([]telego.BotCommand, 0, len(defs)) + for _, def := range defs { + if def.Name == "" || def.Description == "" { + continue + } + botCommands = append(botCommands, telego.BotCommand{ + Command: def.Name, + Description: def.Description, + }) + } + + current, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{}) + if err != nil { + // If we can't read current commands, fall through to set them. + logger.WarnCF("telegram", "Failed to get current commands, will set unconditionally", + map[string]any{"error": err.Error()}) + } else if slices.Equal(current, botCommands) { + logger.DebugCF("telegram", "Bot commands are up to date", nil) + return nil + } + + return c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ + Commands: botCommands, + }) +} + +func (c *TelegramChannel) startCommandRegistration(ctx context.Context, defs []commands.Definition) { + if len(defs) == 0 { + return + } + + register := c.registerFunc + if register == nil { + register = c.RegisterCommands + } + + regCtx, cancel := context.WithCancel(ctx) + c.commandRegCancel = cancel + + // Registration runs asynchronously so Telegram message intake is never blocked + // by temporary upstream API failures. Retry stops on success or channel shutdown. 
+ go func() { + attempt := 0 + timer := time.NewTimer(0) + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + defer timer.Stop() + for { + err := register(regCtx, defs) + if err == nil { + logger.InfoCF("telegram", "Telegram commands registered", map[string]any{ + "count": len(defs), + }) + return + } + + delay := commandRegistrationDelay(attempt) + logger.WarnCF("telegram", "Telegram command registration failed; will retry", map[string]any{ + "error": err.Error(), + "retry_after": delay.String(), + }) + attempt++ + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(delay) + + select { + case <-regCtx.Done(): + return + case <-timer.C: + } + } + }() +} diff --git a/pkg/channels/telegram/command_registration_test.go b/pkg/channels/telegram/command_registration_test.go new file mode 100644 index 000000000..26f891b2e --- /dev/null +++ b/pkg/channels/telegram/command_registration_test.go @@ -0,0 +1,96 @@ +package telegram + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/commands" +) + +func TestStartCommandRegistration_DoesNotBlock(t *testing.T) { + ch := &TelegramChannel{} + started := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch.registerFunc = func(context.Context, []commands.Definition) error { + started <- struct{}{} + return errors.New("temporary failure") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help"}}) + + select { + case <-started: + case <-time.After(time.Second): + t.Fatal("registration did not start asynchronously") + } +} + +func TestStartCommandRegistration_RetriesUntilSuccessThenStops(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = 
origBackoff }() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + n := attempts.Add(1) + if n < 3 { + return errors.New("temporary failure") + } + return nil + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + deadline := time.Now().Add(250 * time.Millisecond) + for time.Now().Before(deadline) { + if attempts.Load() >= 3 { + break + } + time.Sleep(5 * time.Millisecond) + } + if attempts.Load() < 3 { + t.Fatalf("expected at least 3 attempts, got %d", attempts.Load()) + } + + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to stop after success, got %d -> %d", stable, attempts.Load()) + } +} + +func TestStartCommandRegistration_StopsAfterCancel(t *testing.T) { + ch := &TelegramChannel{} + ctx, cancel := context.WithCancel(context.Background()) + + origBackoff := commandRegistrationBackoff + commandRegistrationBackoff = []time.Duration{5 * time.Millisecond} + defer func() { commandRegistrationBackoff = origBackoff }() + defer cancel() + + var attempts atomic.Int32 + ch.registerFunc = func(context.Context, []commands.Definition) error { + attempts.Add(1) + return errors.New("always fail") + } + + ch.startCommandRegistration(ctx, []commands.Definition{{Name: "help", Description: "Help"}}) + + time.Sleep(20 * time.Millisecond) + cancel() + time.Sleep(20 * time.Millisecond) // allow in-flight attempt to settle + stable := attempts.Load() + time.Sleep(30 * time.Millisecond) + if attempts.Load() != stable { + t.Fatalf("expected retries to quiesce after cancel, got %d -> %d", stable, attempts.Load()) + } +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index f328f32b8..0a36247a6 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "regexp" - "slices" "strconv" "strings" "time" @@ -18,6 
+17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/commands" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" @@ -40,13 +40,15 @@ var ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *th.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc + bot *telego.Bot + bh *th.BotHandler + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc + + registerFunc func(context.Context, []commands.Definition) error + commandRegCancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -86,14 +88,13 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann telegramCfg, bus, telegramCfg.AllowFrom, - channels.WithMaxMessageLength(4096), + channels.WithMaxMessageLength(4000), channels.WithGroupTrigger(telegramCfg.GroupTrigger), channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID), ) return &TelegramChannel{ BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), bot: bot, config: cfg, chatIDs: make(map[string]int64), @@ -105,12 +106,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) - if err := c.initBotCommands(c.ctx); err != nil { - logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{ - "error": err.Error(), - }) - } - updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) @@ -126,21 +121,6 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } c.bh = bh - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Start(ctx, message) - }, th.CommandEqual("start")) - bh.HandleMessage(func(ctx *th.Context, 
message telego.Message) error { - return c.commands.Help(ctx, message) - }, th.CommandEqual("help")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Show(ctx, message) - }, th.CommandEqual("show")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.List(ctx, message) - }, th.CommandEqual("list")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) @@ -150,6 +130,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { "username": c.bot.Username(), }) + c.startCommandRegistration(c.ctx, commands.BuiltinDefinitions()) + go func() { if err = bh.Start(); err != nil { logger.ErrorCF("telegram", "Bot handler failed", map[string]any{ @@ -174,50 +156,8 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { if c.cancel != nil { c.cancel() } - - return nil -} - -func (c *TelegramChannel) initBotCommands(ctx context.Context) error { - currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{ - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("get commands: %w", err) - } - - commands := []telego.BotCommand{ - { - Command: "start", - Description: "Start the bot", - }, - { - Command: "help", - Description: "Show a help message", - }, - { - Command: "show", - Description: "Show current configuration", - }, - { - Command: "list", - Description: "List available options", - }, - } - - // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed - if !slices.Equal(currentCommands, commands) { - logger.InfoC("telegram", "Updating bot commands") - - err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ - Commands: commands, - Scope: tu.ScopeDefault(), - }) - if err != nil { - return fmt.Errorf("set commands: %w", err) - } - } else { - logger.DebugC("telegram", "Bot commands are up to date") + if 
c.commandRegCancel != nil { + c.commandRegCancel() } return nil @@ -233,22 +173,57 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } - htmlContent := markdownToTelegramHTML(msg.Content) + if msg.Content == "" { + return nil + } - // Typing/placeholder handled by Manager.preSend — just send the message + // The Manager already splits messages to ≤4000 chars (WithMaxMessageLength), + // so msg.Content is guaranteed to be within that limit. We still need to + // check if HTML expansion pushes it beyond Telegram's 4096-char API limit. + queue := []string{msg.Content} + for len(queue) > 0 { + chunk := queue[0] + queue = queue[1:] + + htmlContent := markdownToTelegramHTML(chunk) + + if len([]rune(htmlContent)) > 4096 { + ratio := float64(len([]rune(chunk))) / float64(len([]rune(htmlContent))) + smallerLen := int(float64(4096) * ratio * 0.95) // 5% safety margin + if smallerLen < 100 { + smallerLen = 100 + } + // Push sub-chunks back to the front of the queue for + // re-validation instead of sending them blindly. + subChunks := channels.SplitMessage(chunk, smallerLen) + queue = append(subChunks, queue...) + continue + } + + if err := c.sendHTMLChunk(ctx, chatID, htmlContent, chunk); err != nil { + return err + } + } + + return nil +} + +// sendHTMLChunk sends a single HTML message, falling back to the original +// markdown as plain text on parse failure so users never see raw HTML tags. 
+func (c *TelegramChannel) sendHTMLChunk(ctx context.Context, chatID int64, htmlContent, mdFallback string) error { tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML - if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil { logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{ "error": err.Error(), }) + tgMsg.Text = mdFallback tgMsg.ParseMode = "" if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { return fmt.Errorf("telegram send: %w", channels.ErrTemporary) } } - return nil } @@ -721,34 +696,34 @@ func escapeHTML(text string) string { // isBotMentioned checks if the bot is mentioned in the message via entities. func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { - botUsername := c.bot.Username() - if botUsername == "" { + text, entities := telegramEntityTextAndList(message) + if text == "" || len(entities) == 0 { return false } - entities := message.Entities - if entities == nil { - entities = message.CaptionEntities + botUsername := "" + if c.bot != nil { + botUsername = c.bot.Username() } + runes := []rune(text) for _, entity := range entities { - if entity.Type == "mention" { - // Extract the mention text from the message - text := message.Text - if text == "" { - text = message.Caption - } - runes := []rune(text) - end := entity.Offset + entity.Length - if end <= len(runes) { - mention := string(runes[entity.Offset:end]) - if strings.EqualFold(mention, "@"+botUsername) { - return true - } - } + entityText, ok := telegramEntityText(runes, entity) + if !ok { + continue } - if entity.Type == "text_mention" && entity.User != nil { - if entity.User.Username == botUsername { + + switch entity.Type { + case telego.EntityTypeMention: + if botUsername != "" && strings.EqualFold(entityText, "@"+botUsername) { + return true + } + case telego.EntityTypeTextMention: + if botUsername != "" && entity.User != nil && 
strings.EqualFold(entity.User.Username, botUsername) { + return true + } + case telego.EntityTypeBotCommand: + if isBotCommandEntityForThisBot(entityText, botUsername) { return true } } @@ -756,6 +731,46 @@ func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { return false } +func telegramEntityTextAndList(message *telego.Message) (string, []telego.MessageEntity) { + if message.Text != "" { + return message.Text, message.Entities + } + return message.Caption, message.CaptionEntities +} + +func telegramEntityText(runes []rune, entity telego.MessageEntity) (string, bool) { + if entity.Offset < 0 || entity.Length <= 0 { + return "", false + } + end := entity.Offset + entity.Length + if entity.Offset >= len(runes) || end > len(runes) { + return "", false + } + return string(runes[entity.Offset:end]), true +} + +func isBotCommandEntityForThisBot(entityText, botUsername string) bool { + if !strings.HasPrefix(entityText, "/") { + return false + } + command := strings.TrimPrefix(entityText, "/") + if command == "" { + return false + } + + at := strings.IndexRune(command, '@') + if at == -1 { + // A bare /command delivered to this bot is intended for this bot. + return true + } + + mentionUsername := command[at+1:] + if mentionUsername == "" || botUsername == "" { + return false + } + return strings.EqualFold(mentionUsername, botUsername) +} + // stripBotMention removes the @bot mention from the content. 
func (c *TelegramChannel) stripBotMention(content string) string { botUsername := c.bot.Username() diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go deleted file mode 100644 index 496fc5e4f..000000000 --- a/pkg/channels/telegram/telegram_commands.go +++ /dev/null @@ -1,156 +0,0 @@ -package telegram - -import ( - "context" - "fmt" - "strings" - - "github.com/mymmrac/telego" - - "github.com/sipeed/picoclaw/pkg/config" -) - -type TelegramCommander interface { - Help(ctx context.Context, message telego.Message) error - Start(ctx context.Context, message telego.Message) error - Show(ctx context.Context, message telego.Message) error - List(ctx context.Context, message telego.Message) error -} - -type cmd struct { - bot *telego.Bot - config *config.Config -} - -func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { - return &cmd{ - bot: bot, - config: cfg, - } -} - -func commandArgs(text string) string { - parts := strings.SplitN(text, " ", 2) - if len(parts) < 2 { - return "" - } - return strings.TrimSpace(parts[1]) -} - -func (c *cmd) Help(ctx context.Context, message telego.Message) error { - msg := `/start - Start the bot -/help - Show this help message -/show [model|channel] - Show current configuration -/list [models|channels] - List available options - ` - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: msg, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Start(ctx context.Context, message telego.Message) error { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Hello! 
I am PicoClaw 🦞", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Show(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /show [model|channel]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "model": - response = fmt.Sprintf("Current Model: %s (Provider: %s)", - c.config.Agents.Defaults.GetModelName(), - c.config.Agents.Defaults.Provider) - case "channel": - response = "Current Channel: telegram" - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) List(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /list [models|channels]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "models": - provider := c.config.Agents.Defaults.Provider - if provider == "" { - provider = "configured default" - } - response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", - c.config.Agents.Defaults.GetModelName(), provider) - - case "channels": - var enabled []string - if c.config.Channels.Telegram.Enabled { - enabled = append(enabled, "telegram") - } - if c.config.Channels.WhatsApp.Enabled { - enabled = append(enabled, "whatsapp") 
- } - if c.config.Channels.Feishu.Enabled { - enabled = append(enabled, "feishu") - } - if c.config.Channels.Discord.Enabled { - enabled = append(enabled, "discord") - } - if c.config.Channels.Slack.Enabled { - enabled = append(enabled, "slack") - } - response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) - - default: - response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} diff --git a/pkg/channels/telegram/telegram_dispatch_test.go b/pkg/channels/telegram/telegram_dispatch_test.go new file mode 100644 index 000000000..1ea4a4824 --- /dev/null +++ b/pkg/channels/telegram/telegram_dispatch_test.go @@ -0,0 +1,52 @@ +package telegram + +import ( + "context" + "testing" + "time" + + "github.com/mymmrac/telego" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" +) + +func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + + msg := &telego.Message{ + Text: "/new", + MessageID: 9, + Chat: telego.Chat{ + ID: 123, + Type: "private", + }, + From: &telego.User{ + ID: 42, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "telegram" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { 
+ t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/telegram/telegram_group_command_filter_test.go b/pkg/channels/telegram/telegram_group_command_filter_test.go new file mode 100644 index 000000000..0d5b985fe --- /dev/null +++ b/pkg/channels/telegram/telegram_group_command_filter_test.go @@ -0,0 +1,147 @@ +package telegram + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/mymmrac/telego" + ta "github.com/mymmrac/telego/telegoapi" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +type getMeCaller struct { + username string +} + +func (c getMeCaller) Call(_ context.Context, url string, _ *ta.RequestData) (*ta.Response, error) { + if strings.HasSuffix(url, "/getMe") { + result := fmt.Sprintf(`{"id":1,"is_bot":true,"first_name":"bot","username":%q}`, c.username) + return &ta.Response{Ok: true, Result: []byte(result)}, nil + } + return &ta.Response{Ok: true, Result: []byte("true")}, nil +} + +func newTestTelegramBot(t *testing.T, username string) *telego.Bot { + t.Helper() + + token := "123456:" + strings.Repeat("a", 35) + bot, err := telego.NewBot(token, + telego.WithAPICaller(getMeCaller{username: username}), + telego.WithDiscardLogger(), + ) + if err != nil { + t.Fatalf("NewBot error: %v", err) + } + return bot +} + +func newGroupMentionOnlyChannel(t *testing.T, botUsername string) (*TelegramChannel, *bus.MessageBus) { + t.Helper() + + messageBus := bus.NewMessageBus() + ch := &TelegramChannel{ + BaseChannel: channels.NewBaseChannel("telegram", nil, messageBus, nil, + channels.WithGroupTrigger(config.GroupTriggerConfig{MentionOnly: true}), + ), + bot: newTestTelegramBot(t, botUsername), + chatIDs: make(map[string]int64), + ctx: context.Background(), + } + return ch, messageBus +} + +func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) { + tests := []struct { + name string + text string + wantForwarded bool + 
wantContent string + }{ + { + name: "command with bot username", + text: "/new@testbot", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "bare command", + text: "/new", + wantForwarded: true, + wantContent: "/new", + }, + { + name: "command for another bot", + text: "/new@otherbot", + wantForwarded: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ch, messageBus := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: tc.text, + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeBotCommand, + Offset: 0, + Length: len([]rune(tc.text)), + }}, + MessageID: 42, + Chat: telego.Chat{ + ID: 123, + Type: "group", + }, + From: &telego.User{ + ID: 7, + FirstName: "Alice", + }, + } + + if err := ch.handleMessage(context.Background(), msg); err != nil { + t.Fatalf("handleMessage error: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if tc.wantForwarded { + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Content != tc.wantContent { + t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) + } + return + } + + if ok { + t.Fatalf("expected message to be filtered, got content=%q", inbound.Content) + } + }) + } +} + +func TestIsBotMentioned_MentionEntityUnaffected(t *testing.T) { + ch, _ := newGroupMentionOnlyChannel(t, "testbot") + + msg := &telego.Message{ + Text: "@testbot hello", + Entities: []telego.MessageEntity{{ + Type: telego.EntityTypeMention, + Offset: 0, + Length: len("@testbot"), + }}, + } + + if !ch.isBotMentioned(msg) { + t.Fatal("expected mention entity to be treated as bot mention") + } +} diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go new file mode 100644 index 000000000..3a2f1aa66 --- /dev/null +++ b/pkg/channels/telegram/telegram_test.go @@ -0,0 +1,273 @@ +package telegram + 
+import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/mymmrac/telego" + ta "github.com/mymmrac/telego/telegoapi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" +) + +const testToken = "1234567890:aaaabbbbaaaabbbbaaaabbbbaaaabbbbccc" + +// stubCaller implements ta.Caller for testing. +type stubCaller struct { + calls []stubCall + callFn func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) +} + +type stubCall struct { + URL string + Data *ta.RequestData +} + +func (s *stubCaller) Call(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + s.calls = append(s.calls, stubCall{URL: url, Data: data}) + return s.callFn(ctx, url, data) +} + +// stubConstructor implements ta.RequestConstructor for testing. +type stubConstructor struct{} + +func (s *stubConstructor) JSONRequest(parameters any) (*ta.RequestData, error) { + return &ta.RequestData{}, nil +} + +func (s *stubConstructor) MultipartRequest( + parameters map[string]string, + files map[string]ta.NamedReader, +) (*ta.RequestData, error) { + return &ta.RequestData{}, nil +} + +// successResponse returns a ta.Response that telego will treat as a successful SendMessage. +func successResponse(t *testing.T) *ta.Response { + t.Helper() + msg := &telego.Message{MessageID: 1} + b, err := json.Marshal(msg) + require.NoError(t, err) + return &ta.Response{Ok: true, Result: b} +} + +// newTestChannel creates a TelegramChannel with a mocked bot for unit testing. 
+func newTestChannel(t *testing.T, caller *stubCaller) *TelegramChannel { + t.Helper() + + bot, err := telego.NewBot(testToken, + telego.WithAPICaller(caller), + telego.WithRequestConstructor(&stubConstructor{}), + telego.WithDiscardLogger(), + ) + require.NoError(t, err) + + base := channels.NewBaseChannel("telegram", nil, nil, nil, + channels.WithMaxMessageLength(4000), + ) + base.SetRunning(true) + + return &TelegramChannel{ + BaseChannel: base, + bot: bot, + chatIDs: make(map[string]int64), + } +} + +func TestSend_EmptyContent(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + t.Fatal("SendMessage should not be called for empty content") + return nil, nil + }, + } + ch := newTestChannel(t, caller) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "", + }) + + assert.NoError(t, err) + assert.Empty(t, caller.calls, "no API calls should be made for empty content") +} + +func TestSend_ShortMessage_SingleCall(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "Hello, world!", + }) + + assert.NoError(t, err) + assert.Len(t, caller.calls, 1, "short message should result in exactly one SendMessage call") +} + +func TestSend_LongMessage_SingleCall(t *testing.T) { + // With WithMaxMessageLength(4000), the Manager pre-splits messages before + // they reach Send(). A message at exactly 4000 chars should go through + // as a single SendMessage call (no re-split needed since HTML expansion + // won't exceed 4096 for plain text). 
+ caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + longContent := strings.Repeat("a", 4000) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: longContent, + }) + + assert.NoError(t, err) + assert.Len(t, caller.calls, 1, "pre-split message within limit should result in one SendMessage call") +} + +func TestSend_HTMLFallback_PerChunk(t *testing.T) { + callCount := 0 + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + callCount++ + // Fail on odd calls (HTML attempt), succeed on even calls (plain text fallback) + if callCount%2 == 1 { + return nil, errors.New("Bad Request: can't parse entities") + } + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "Hello **world**", + }) + + assert.NoError(t, err) + // One short message → 1 HTML attempt (fail) + 1 plain text fallback (success) = 2 calls + assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text fallback") +} + +func TestSend_HTMLFallback_BothFail(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return nil, errors.New("send failed") + }, + } + ch := newTestChannel(t, caller) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "Hello", + }) + + assert.Error(t, err) + assert.True(t, errors.Is(err, channels.ErrTemporary), "error should wrap ErrTemporary") + assert.Equal(t, 2, len(caller.calls), "should have HTML attempt + plain text attempt") +} + +func TestSend_LongMessage_HTMLFallback_StopsOnError(t *testing.T) { + // With a long message that gets split into 2 chunks, if both HTML and + // plain text 
fail on the first chunk, Send should return early. + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return nil, errors.New("send failed") + }, + } + ch := newTestChannel(t, caller) + + longContent := strings.Repeat("x", 4001) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: longContent, + }) + + assert.Error(t, err) + // Should fail on the first chunk (2 calls: HTML + fallback), never reaching the second chunk. + assert.Equal(t, 2, len(caller.calls), "should stop after first chunk fails both HTML and plain text") +} + +func TestSend_MarkdownShortButHTMLLong_MultipleCalls(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + // Create markdown whose length is <= 4000 but whose HTML expansion is much longer. + // "**a** " (6 chars) becomes "a " (9 chars) in HTML, so repeating it many times + // yields HTML that exceeds Telegram's limit while markdown stays within it. 
+ markdownContent := strings.Repeat("**a** ", 600) // 3600 chars markdown, HTML ~5400+ chars + assert.LessOrEqual(t, len([]rune(markdownContent)), 4000, "markdown content must not exceed chunk size") + + htmlExpanded := markdownToTelegramHTML(markdownContent) + assert.Greater( + t, len([]rune(htmlExpanded)), 4096, + "HTML expansion must exceed Telegram limit for this test to be meaningful", + ) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: markdownContent, + }) + + assert.NoError(t, err) + assert.Greater( + t, len(caller.calls), 1, + "markdown-short but HTML-long message should be split into multiple SendMessage calls", + ) +} + +func TestSend_NotRunning(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + t.Fatal("should not be called") + return nil, nil + }, + } + ch := newTestChannel(t, caller) + ch.SetRunning(false) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "12345", + Content: "Hello", + }) + + assert.ErrorIs(t, err, channels.ErrNotRunning) + assert.Empty(t, caller.calls) +} + +func TestSend_InvalidChatID(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + t.Fatal("should not be called") + return nil, nil + }, + } + ch := newTestChannel(t, caller) + + err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "not-a-number", + Content: "Hello", + }) + + assert.Error(t, err) + assert.True(t, errors.Is(err, channels.ErrSendFailed), "error should wrap ErrSendFailed") + assert.Empty(t, caller.calls) +} diff --git a/pkg/channels/whatsapp/whatsapp_command_test.go b/pkg/channels/whatsapp/whatsapp_command_test.go new file mode 100644 index 000000000..ee8aa4a52 --- /dev/null +++ b/pkg/channels/whatsapp/whatsapp_command_test.go @@ -0,0 +1,41 @@ +package whatsapp + +import ( + "context" + "testing" + "time" + + 
"github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp", config.WhatsAppConfig{}, messageBus, nil), + ctx: context.Background(), + } + + ch.handleIncomingMessage(map[string]any{ + "type": "message", + "id": "mid1", + "from": "user1", + "chat": "chat1", + "content": "/help", + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/help" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/channels/whatsapp_native/whatsapp_command_test.go b/pkg/channels/whatsapp_native/whatsapp_command_test.go new file mode 100644 index 000000000..cc2dcb619 --- /dev/null +++ b/pkg/channels/whatsapp_native/whatsapp_command_test.go @@ -0,0 +1,56 @@ +//go:build whatsapp_native + +package whatsapp + +import ( + "context" + "testing" + "time" + + "go.mau.fi/whatsmeow/proto/waE2E" + "go.mau.fi/whatsmeow/types" + "go.mau.fi/whatsmeow/types/events" + "google.golang.org/protobuf/proto" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) { + messageBus := bus.NewMessageBus() + ch := &WhatsAppNativeChannel{ + BaseChannel: channels.NewBaseChannel("whatsapp_native", config.WhatsAppConfig{}, messageBus, nil), + runCtx: context.Background(), + } + + evt := &events.Message{ + Info: types.MessageInfo{ + MessageSource: types.MessageSource{ + Sender: types.NewJID("1001", types.DefaultUserServer), + Chat: 
types.NewJID("1001", types.DefaultUserServer), + }, + ID: "mid1", + PushName: "Alice", + }, + Message: &waE2E.Message{ + Conversation: proto.String("/new"), + }, + } + + ch.handleIncoming(evt) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + inbound, ok := messageBus.ConsumeInbound(ctx) + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp_native" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } +} diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go new file mode 100644 index 000000000..a36dd3eba --- /dev/null +++ b/pkg/commands/builtin.go @@ -0,0 +1,16 @@ +package commands + +// BuiltinDefinitions returns all built-in command definitions. +// Each command group is defined in its own cmd_*.go file. +// Definitions are stateless — runtime dependencies are provided +// via the Runtime parameter passed to handlers at execution time. 
+func BuiltinDefinitions() []Definition { + return []Definition{ + startCommand(), + helpCommand(), + showCommand(), + listCommand(), + switchCommand(), + checkCommand(), + } +} diff --git a/pkg/commands/builtin_test.go b/pkg/commands/builtin_test.go new file mode 100644 index 000000000..66a84825e --- /dev/null +++ b/pkg/commands/builtin_test.go @@ -0,0 +1,145 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func findDefinitionByName(t *testing.T, defs []Definition, name string) Definition { + t.Helper() + for _, def := range defs { + if def.Name == name { + return def + } + } + t.Fatalf("missing /%s definition", name) + return Definition{} +} + +func TestBuiltinHelpHandler_ReturnsFormattedMessage(t *testing.T) { + defs := BuiltinDefinitions() + helpDef := findDefinitionByName(t, defs, "help") + if helpDef.Handler == nil { + t.Fatalf("/help handler should not be nil") + } + + var reply string + err := helpDef.Handler(context.Background(), Request{ + Text: "/help", + Reply: func(text string) error { + reply = text + return nil + }, + }, nil) + if err != nil { + t.Fatalf("/help handler error: %v", err) + } + // Now uses auto-generated EffectiveUsage which includes agents + if !strings.Contains(reply, "/show [model|channel|agents]") { + t.Fatalf("/help reply missing /show usage, got %q", reply) + } + if !strings.Contains(reply, "/list [models|channels|agents]") { + t.Fatalf("/help reply missing /list usage, got %q", reply) + } +} + +func TestBuiltinShowChannel_PreservesUserVisibleBehavior(t *testing.T) { + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), nil) + + cases := []string{"telegram", "whatsapp"} + for _, channel := range cases { + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: channel, + Text: "/show channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show channel on %s: outcome=%v, want=%v", channel, 
res.Outcome, OutcomeHandled) + } + want := "Current Channel: " + channel + if reply != want { + t.Fatalf("/show channel reply=%q, want=%q", reply, want) + } + } +} + +func TestBuiltinListChannels_UsesGetEnabledChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram", "slack"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list channels: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "telegram") || !strings.Contains(reply, "slack") { + t.Fatalf("/list channels reply=%q, want telegram and slack", reply) + } +} + +func TestBuiltinShowAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/show agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/show agents reply=%q, want agent IDs", reply) + } +} + +func TestBuiltinListAgents_RestoresOldBehavior(t *testing.T) { + rt := &Runtime{ + ListAgentIDs: func() []string { + return []string{"default", "coder"} + }, + } + defs := BuiltinDefinitions() + ex := NewExecutor(NewRegistry(defs), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/list agents", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("/list 
agents: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "default") || !strings.Contains(reply, "coder") { + t.Fatalf("/list agents reply=%q, want agent IDs", reply) + } +} diff --git a/pkg/commands/cmd_check.go b/pkg/commands/cmd_check.go new file mode 100644 index 000000000..f0193dc4f --- /dev/null +++ b/pkg/commands/cmd_check.go @@ -0,0 +1,33 @@ +package commands + +import ( + "context" + "fmt" +) + +func checkCommand() Definition { + return Definition{ + Name: "check", + Description: "Check channel availability", + SubCommands: []SubCommand{ + { + Name: "channel", + Description: "Check if a channel is available", + ArgsUsage: "", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchChannel == nil { + return req.Reply(unavailableMsg) + } + value := nthToken(req.Text, 2) + if value == "" { + return req.Reply("Usage: /check channel ") + } + if err := rt.SwitchChannel(value); err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Channel '%s' is available and enabled", value)) + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_help.go b/pkg/commands/cmd_help.go new file mode 100644 index 000000000..94f7f0101 --- /dev/null +++ b/pkg/commands/cmd_help.go @@ -0,0 +1,44 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func helpCommand() Definition { + return Definition{ + Name: "help", + Description: "Show this help message", + Usage: "/help", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + var defs []Definition + if rt != nil && rt.ListDefinitions != nil { + defs = rt.ListDefinitions() + } else { + defs = BuiltinDefinitions() + } + return req.Reply(formatHelpMessage(defs)) + }, + } +} + +func formatHelpMessage(defs []Definition) string { + if len(defs) == 0 { + return "No commands available." 
+ } + + lines := make([]string, 0, len(defs)) + for _, def := range defs { + usage := def.EffectiveUsage() + if usage == "" { + usage = "/" + def.Name + } + desc := def.Description + if desc == "" { + desc = "No description" + } + lines = append(lines, fmt.Sprintf("%s - %s", usage, desc)) + } + return strings.Join(lines, "\n") +} diff --git a/pkg/commands/cmd_list.go b/pkg/commands/cmd_list.go new file mode 100644 index 000000000..bf47b6e9c --- /dev/null +++ b/pkg/commands/cmd_list.go @@ -0,0 +1,52 @@ +package commands + +import ( + "context" + "fmt" + "strings" +) + +func listCommand() Definition { + return Definition{ + Name: "list", + Description: "List available options", + SubCommands: []SubCommand{ + { + Name: "models", + Description: "Configured models", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + if provider == "" { + provider = "configured default" + } + return req.Reply(fmt.Sprintf( + "Configured Model: %s\nProvider: %s\n\nTo change models, update config.json", + name, provider, + )) + }, + }, + { + Name: "channels", + Description: "Enabled channels", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetEnabledChannels == nil { + return req.Reply(unavailableMsg) + } + enabled := rt.GetEnabledChannels() + if len(enabled) == 0 { + return req.Reply("No channels enabled") + } + return req.Reply(fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_show.go b/pkg/commands/cmd_show.go new file mode 100644 index 000000000..c655e6880 --- /dev/null +++ b/pkg/commands/cmd_show.go @@ -0,0 +1,38 @@ +package commands + +import ( + "context" + "fmt" +) + +func showCommand() Definition { + return Definition{ + Name: "show", + 
Description: "Show current configuration", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Current model and provider", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.GetModelInfo == nil { + return req.Reply(unavailableMsg) + } + name, provider := rt.GetModelInfo() + return req.Reply(fmt.Sprintf("Current Model: %s (Provider: %s)", name, provider)) + }, + }, + { + Name: "channel", + Description: "Current channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply(fmt.Sprintf("Current Channel: %s", req.Channel)) + }, + }, + { + Name: "agents", + Description: "Registered agents", + Handler: agentsHandler(), + }, + }, + } +} diff --git a/pkg/commands/cmd_start.go b/pkg/commands/cmd_start.go new file mode 100644 index 000000000..8b500aa10 --- /dev/null +++ b/pkg/commands/cmd_start.go @@ -0,0 +1,14 @@ +package commands + +import "context" + +func startCommand() Definition { + return Definition{ + Name: "start", + Description: "Start the bot", + Usage: "/start", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("Hello! 
I am PicoClaw 🦞") + }, + } +} diff --git a/pkg/commands/cmd_switch.go b/pkg/commands/cmd_switch.go new file mode 100644 index 000000000..fb8fc109e --- /dev/null +++ b/pkg/commands/cmd_switch.go @@ -0,0 +1,42 @@ +package commands + +import ( + "context" + "fmt" +) + +func switchCommand() Definition { + return Definition{ + Name: "switch", + Description: "Switch model", + SubCommands: []SubCommand{ + { + Name: "model", + Description: "Switch to a different model", + ArgsUsage: "to ", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.SwitchModel == nil { + return req.Reply(unavailableMsg) + } + // Parse: /switch model to + value := nthToken(req.Text, 3) // tokens: [/switch, model, to, ] + if nthToken(req.Text, 2) != "to" || value == "" { + return req.Reply("Usage: /switch model to ") + } + oldModel, err := rt.SwitchModel(value) + if err != nil { + return req.Reply(err.Error()) + } + return req.Reply(fmt.Sprintf("Switched model from %s to %s", oldModel, value)) + }, + }, + { + Name: "channel", + Description: "Moved to /check channel", + Handler: func(_ context.Context, req Request, _ *Runtime) error { + return req.Reply("This command has moved. 
Please use: /check channel ") + }, + }, + }, + } +} diff --git a/pkg/commands/cmd_switch_test.go b/pkg/commands/cmd_switch_test.go new file mode 100644 index 000000000..59ed305bb --- /dev/null +++ b/pkg/commands/cmd_switch_test.go @@ -0,0 +1,279 @@ +package commands + +import ( + "context" + "fmt" + "testing" +) + +func TestSwitchModel_Success(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old-model", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Switched model from old-model to gpt-4" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestSwitchModel_MissingToKeyword(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, 
OutcomeHandled) + } + if reply != "Usage: /switch model to " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitchModel_Error(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "", fmt.Errorf("model not found") + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to bad-model", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "model not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestSwitchModel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." { + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestSwitchChannel_Redirect(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch channel to telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "This command has moved. 
Please use: /check channel " + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Success(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + want := "Channel 'telegram' is available and enabled" + if reply != want { + t.Fatalf("reply=%q, want=%q", reply, want) + } +} + +func TestCheckChannel_Error(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return fmt.Errorf("channel '%s' not found", value) + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel unknown", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "channel 'unknown' not found" { + t.Fatalf("reply=%q, want error message", reply) + } +} + +func TestCheckChannel_NilDep(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel telegram", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Command unavailable in current context." 
{ + t.Fatalf("reply=%q, want unavailable message", reply) + } +} + +func TestCheckChannel_MissingValue(t *testing.T) { + rt := &Runtime{ + SwitchChannel: func(value string) error { + return nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/check channel", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /check channel " { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestSwitch_BangPrefix(t *testing.T) { + rt := &Runtime{ + SwitchModel: func(value string) (string, error) { + return "old", nil + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "!switch model to gpt-4", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("! prefix: outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Switched model from old to gpt-4" { + t.Fatalf("! 
prefix: reply=%q, want success message", reply) + } +} + +func TestSwitch_NoSubCommand(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), &Runtime{}) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/switch", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + // Should get usage message from executor's sub-command routing + if reply == "" { + t.Fatal("expected usage reply for bare /switch") + } +} diff --git a/pkg/commands/definition.go b/pkg/commands/definition.go new file mode 100644 index 000000000..7309df317 --- /dev/null +++ b/pkg/commands/definition.go @@ -0,0 +1,48 @@ +package commands + +import ( + "fmt" + "strings" +) + +// SubCommand defines a single sub-command within a parent command. +type SubCommand struct { + Name string + Description string + ArgsUsage string // optional, e.g. "" + Handler Handler +} + +// Definition is the single-source metadata and behavior contract for a slash command. +// +// Design notes (phase 1): +// - Every channel reads command shape from this type instead of keeping local copies. +// - Visibility is global: all definitions are considered available to all channels. +// - Platform menu registration (for example Telegram BotCommand) also derives from this +// same definition so UI labels and runtime behavior stay aligned. +type Definition struct { + Name string + Description string + Usage string // for simple commands; ignored when SubCommands is set + Aliases []string + SubCommands []SubCommand // optional; when set, Executor routes to sub-command handlers + Handler Handler // for simple commands without sub-commands +} + +// EffectiveUsage returns the usage string. When SubCommands are present, +// it is auto-generated from sub-command names so metadata and behavior +// cannot drift. 
+func (d Definition) EffectiveUsage() string { + if len(d.SubCommands) == 0 { + return d.Usage + } + names := make([]string, 0, len(d.SubCommands)) + for _, sc := range d.SubCommands { + name := sc.Name + if sc.ArgsUsage != "" { + name += " " + sc.ArgsUsage + } + names = append(names, name) + } + return fmt.Sprintf("/%s [%s]", d.Name, strings.Join(names, "|")) +} diff --git a/pkg/commands/definition_test.go b/pkg/commands/definition_test.go new file mode 100644 index 000000000..27ad4a0a2 --- /dev/null +++ b/pkg/commands/definition_test.go @@ -0,0 +1,41 @@ +package commands + +import ( + "testing" +) + +func TestDefinition_EffectiveUsage_NoSubCommands(t *testing.T) { + d := Definition{Name: "start", Usage: "/start"} + if got := d.EffectiveUsage(); got != "/start" { + t.Fatalf("EffectiveUsage()=%q, want %q", got, "/start") + } +} + +func TestDefinition_EffectiveUsage_WithSubCommands(t *testing.T) { + d := Definition{ + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + {Name: "agents"}, + }, + } + want := "/show [model|channel|agents]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} + +func TestDefinition_EffectiveUsage_WithArgsUsage(t *testing.T) { + d := Definition{ + Name: "session", + SubCommands: []SubCommand{ + {Name: "list"}, + {Name: "resume", ArgsUsage: ""}, + }, + } + want := "/session [list|resume ]" + if got := d.EffectiveUsage(); got != want { + t.Fatalf("EffectiveUsage()=%q, want %q", got, want) + } +} diff --git a/pkg/commands/executor.go b/pkg/commands/executor.go new file mode 100644 index 000000000..78a50e6c2 --- /dev/null +++ b/pkg/commands/executor.go @@ -0,0 +1,89 @@ +package commands + +import ( + "context" + "fmt" +) + +type Outcome int + +const ( + // OutcomePassthrough means this input should continue through normal agent flow. + OutcomePassthrough Outcome = iota + // OutcomeHandled means a command handler executed (with or without handler error). 
+ OutcomeHandled +) + +type ExecuteResult struct { + Outcome Outcome + Command string + Err error +} + +type Executor struct { + reg *Registry + rt *Runtime +} + +func NewExecutor(reg *Registry, rt *Runtime) *Executor { + return &Executor{reg: reg, rt: rt} +} + +// Execute implements a two-state command decision: +// 1) handled: execute command immediately; +// 2) passthrough: not a command or intentionally deferred to agent logic. +func (e *Executor) Execute(ctx context.Context, req Request) ExecuteResult { + cmdName, ok := parseCommandName(req.Text) + if !ok { + return ExecuteResult{Outcome: OutcomePassthrough} + } + + if e == nil || e.reg == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + def, found := e.reg.Lookup(cmdName) + if !found { + return ExecuteResult{Outcome: OutcomePassthrough, Command: cmdName} + } + + return e.executeDefinition(ctx, req, def) +} + +func (e *Executor) executeDefinition(ctx context.Context, req Request, def Definition) ExecuteResult { + // Ensure Reply is always non-nil so handlers don't need to check. 
+ if req.Reply == nil { + req.Reply = func(string) error { return nil } + } + + // Simple command — no sub-commands + if len(def.SubCommands) == 0 { + if def.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := def.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + // Sub-command routing + subName := nthToken(req.Text, 1) + if subName == "" { + err := req.Reply("Usage: " + def.EffectiveUsage()) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + + normalized := normalizeCommandName(subName) + for _, sc := range def.SubCommands { + if normalizeCommandName(sc.Name) == normalized { + if sc.Handler == nil { + return ExecuteResult{Outcome: OutcomePassthrough, Command: def.Name} + } + err := sc.Handler(ctx, req, e.rt) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} + } + } + + // Unknown sub-command + err := req.Reply(fmt.Sprintf("Unknown option: %s. 
Usage: %s", subName, def.EffectiveUsage())) + return ExecuteResult{Outcome: OutcomeHandled, Command: def.Name, Err: err} +} diff --git a/pkg/commands/executor_test.go b/pkg/commands/executor_test.go new file mode 100644 index 000000000..09350f1b6 --- /dev/null +++ b/pkg/commands/executor_test.go @@ -0,0 +1,260 @@ +package commands + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestExecutor_RegisteredWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/show"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_UnknownSlashCommand_ReturnsPassthrough(t *testing.T) { + defs := []Definition{{Name: "show"}} + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/unknown"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} + +func TestExecutor_SupportedCommandWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help@my_bot"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_AliasWithoutHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + Aliases: []string{"display"}, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "whatsapp", Text: "/display"}) + if res.Outcome 
!= OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "show" { + t.Fatalf("command=%q, want=%q", res.Command, "show") + } +} + +func TestExecutor_AliasWithHandler_ReturnsHandled(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "clear", + Aliases: []string{"reset"}, + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/reset"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "clear" { + t.Fatalf("command=%q, want=%q", res.Command, "clear") + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SupportedCommandWithNilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder list"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_NilHandlerDoesNotMaskLaterHandler(t *testing.T) { + // With Lookup-based dispatch, the first registered definition for a name wins. + // A definition with nil Handler and no SubCommands returns Passthrough. 
+ defs := []Definition{ + {Name: "placeholder"}, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/placeholder"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } + if res.Command != "placeholder" { + t.Fatalf("command=%q, want=%q", res.Command, "placeholder") + } +} + +func TestExecutor_HandlerErrorIsPropagated(t *testing.T) { + wantErr := errors.New("handler failed") + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + return wantErr + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "/help"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !errors.Is(res.Err, wantErr) { + t.Fatalf("err=%v, want=%v", res.Err, wantErr) + } +} + +func TestExecutor_SupportsBangPrefixAndCaseInsensitiveCommand(t *testing.T) { + called := false + defs := []Definition{ + { + Name: "help", + Handler: func(context.Context, Request, *Runtime) error { + called = true + return nil + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Channel: "telegram", Text: "!HELP"}) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !called { + t.Fatalf("expected handler to be called") + } +} + +func TestExecutor_SubCommand_RoutesToCorrectHandler(t *testing.T) { + modelCalled := false + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model", Handler: func(_ context.Context, _ Request, _ *Runtime) error { + modelCalled = true + return nil + }}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != 
OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !modelCalled { + t.Fatal("model sub-command handler was not called") + } +} + +func TestExecutor_SubCommand_NoArg_RepliesUsage(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + {Name: "channel"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if reply != "Usage: /show [model|channel]" { + t.Fatalf("reply=%q, want usage message", reply) + } +} + +func TestExecutor_SubCommand_UnknownArg_RepliesError(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + var reply string + res := ex.Execute(context.Background(), Request{ + Text: "/show foobar", + Reply: func(text string) error { reply = text; return nil }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if !strings.Contains(reply, "foobar") { + t.Fatalf("reply=%q, should mention unknown sub-command", reply) + } +} + +func TestExecutor_SubCommand_NilHandler_ReturnsPassthrough(t *testing.T) { + defs := []Definition{ + { + Name: "show", + SubCommands: []SubCommand{ + {Name: "model"}, // nil Handler + }, + }, + } + ex := NewExecutor(NewRegistry(defs), nil) + + res := ex.Execute(context.Background(), Request{Text: "/show model"}) + if res.Outcome != OutcomePassthrough { + t.Fatalf("outcome=%v, want=%v", res.Outcome, OutcomePassthrough) + } +} diff --git a/pkg/commands/handler_agents.go b/pkg/commands/handler_agents.go new file mode 100644 index 000000000..c459516eb --- /dev/null +++ b/pkg/commands/handler_agents.go @@ -0,0 +1,21 @@ 
+package commands + +import ( + "context" + "fmt" + "strings" +) + +// agentsHandler returns a shared handler for both /show agents and /list agents. +func agentsHandler() Handler { + return func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.ListAgentIDs == nil { + return req.Reply(unavailableMsg) + } + ids := rt.ListAgentIDs() + if len(ids) == 0 { + return req.Reply("No agents registered") + } + return req.Reply(fmt.Sprintf("Registered agents: %s", strings.Join(ids, ", "))) + } +} diff --git a/pkg/commands/registry.go b/pkg/commands/registry.go new file mode 100644 index 000000000..e17d489a6 --- /dev/null +++ b/pkg/commands/registry.go @@ -0,0 +1,55 @@ +package commands + +type Registry struct { + defs []Definition + index map[string]int +} + +// NewRegistry stores the canonical command set used by both dispatch and +// optional platform registration adapters. +func NewRegistry(defs []Definition) *Registry { + stored := make([]Definition, len(defs)) + copy(stored, defs) + + index := make(map[string]int, len(stored)*2) + for i, def := range stored { + registerCommandName(index, def.Name, i) + for _, alias := range def.Aliases { + registerCommandName(index, alias, i) + } + } + + return &Registry{defs: stored, index: index} +} + +// Definitions returns all registered command definitions. +// Command availability is global and no longer channel-scoped. +func (r *Registry) Definitions() []Definition { + out := make([]Definition, len(r.defs)) + copy(out, r.defs) + return out +} + +// Lookup returns a command definition by normalized command name or alias. 
+func (r *Registry) Lookup(name string) (Definition, bool) { + key := normalizeCommandName(name) + if key == "" { + return Definition{}, false + } + idx, ok := r.index[key] + if !ok { + return Definition{}, false + } + return r.defs[idx], true +} + +func registerCommandName(index map[string]int, name string, defIndex int) { + key := normalizeCommandName(name) + if key == "" { + return + } + if _, exists := index[key]; exists { + return + } + index[key] = defIndex +} diff --git a/pkg/commands/registry_test.go b/pkg/commands/registry_test.go new file mode 100644 index 000000000..bfff76b7c --- /dev/null +++ b/pkg/commands/registry_test.go @@ -0,0 +1,49 @@ +package commands + +import "testing" + +func TestRegistry_Definitions_ReturnsCopy(t *testing.T) { + defs := []Definition{ + {Name: "help", Description: "Show help"}, + {Name: "admin", Description: "Admin command"}, + } + r := NewRegistry(defs) + + got := r.Definitions() + if len(got) != 2 { + t.Fatalf("definitions len = %d, want 2", len(got)) + } + + got[0].Name = "mutated" + again := r.Definitions() + if again[0].Name != "help" { + t.Fatalf("registry should not be mutated by caller, got first name %q", again[0].Name) + } +} + +func TestRegistry_Lookup_MatchesByLowercaseNameAndAlias(t *testing.T) { + r := NewRegistry([]Definition{ + {Name: "Help", Aliases: []string{"Assist"}}, + {Name: "List"}, + }) + + def, ok := r.Lookup("help") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("HELP") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase name failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("assist") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by lowercase alias failed: ok=%v def=%+v", ok, def) + } + + def, ok = r.Lookup("ASSIST") + if !ok || def.Name != "Help" { + t.Fatalf("lookup by uppercase alias failed: ok=%v def=%+v", ok, def) + } +} diff --git a/pkg/commands/request.go b/pkg/commands/request.go new 
file mode 100644 index 000000000..62ee600f2 --- /dev/null +++ b/pkg/commands/request.go @@ -0,0 +1,75 @@ +package commands + +import ( + "context" + "strings" +) + +type Handler func(ctx context.Context, req Request, rt *Runtime) error + +type Request struct { + Channel string + ChatID string + SenderID string + Text string + Reply func(text string) error +} + +const unavailableMsg = "Command unavailable in current context." + +var commandPrefixes = []string{"/", "!"} + +// parseCommandName accepts "/name", "!name", and Telegram's "/name@bot", then +// normalizes to lowercase command names. +func parseCommandName(input string) (string, bool) { + token := nthToken(input, 0) + if token == "" { + return "", false + } + + name, ok := trimCommandPrefix(token) + if !ok { + return "", false + } + if i := strings.Index(name, "@"); i >= 0 { + name = name[:i] + } + name = normalizeCommandName(name) + if name == "" { + return "", false + } + return name, true +} + +func trimCommandPrefix(token string) (string, bool) { + for _, prefix := range commandPrefixes { + if strings.HasPrefix(token, prefix) { + return strings.TrimPrefix(token, prefix), true + } + } + return "", false +} + +// HasCommandPrefix returns true if the input starts with a recognized +// command prefix (e.g. "/" or "!"). +func HasCommandPrefix(input string) bool { + token := nthToken(input, 0) + if token == "" { + return false + } + _, ok := trimCommandPrefix(token) + return ok +} + +// nthToken returns the 0-indexed token from whitespace-split input. 
+func nthToken(input string, n int) string { + parts := strings.Fields(strings.TrimSpace(input)) + if n >= len(parts) { + return "" + } + return parts[n] +} + +func normalizeCommandName(name string) string { + return strings.ToLower(strings.TrimSpace(name)) +} diff --git a/pkg/commands/request_test.go b/pkg/commands/request_test.go new file mode 100644 index 000000000..4389e453b --- /dev/null +++ b/pkg/commands/request_test.go @@ -0,0 +1,28 @@ +package commands + +import "testing" + +func TestHasCommandPrefix(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"/help", true}, + {"!help", true}, + {"/switch model to gpt-4", true}, + {"!switch model to gpt-4", true}, + {"hello", false}, + {"", false}, + {" ", false}, + {"hello /world", false}, + {"/", true}, + {"!", true}, + {" /help", true}, + } + for _, tt := range tests { + got := HasCommandPrefix(tt.input) + if got != tt.want { + t.Errorf("HasCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go new file mode 100644 index 000000000..227d495f4 --- /dev/null +++ b/pkg/commands/runtime.go @@ -0,0 +1,16 @@ +package commands + +import "github.com/sipeed/picoclaw/pkg/config" + +// Runtime provides runtime dependencies to command handlers. It is constructed +// per-request by the agent loop so that per-request state (like session scope) +// can coexist with long-lived callbacks (like GetModelInfo). 
+type Runtime struct { + Config *config.Config + GetModelInfo func() (name, provider string) + ListAgentIDs func() []string + ListDefinitions func() []Definition + GetEnabledChannels func() []string + SwitchModel func(value string) (oldModel string, err error) + SwitchChannel func(value string) error +} diff --git a/pkg/commands/show_list_handlers_test.go b/pkg/commands/show_list_handlers_test.go new file mode 100644 index 000000000..047708f0f --- /dev/null +++ b/pkg/commands/show_list_handlers_test.go @@ -0,0 +1,85 @@ +package commands + +import ( + "context" + "strings" + "testing" +) + +func TestShowListHandlers_ChannelPolicy(t *testing.T) { + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), nil) + + var telegramReply string + handled := ex.Execute(context.Background(), Request{ + Channel: "telegram", + Text: "/show channel", + Reply: func(text string) error { + telegramReply = text + return nil + }, + }) + if handled.Outcome != OutcomeHandled { + t.Fatalf("telegram /show outcome=%v, want=%v", handled.Outcome, OutcomeHandled) + } + if telegramReply != "Current Channel: telegram" { + t.Fatalf("telegram /show reply=%q, want=%q", telegramReply, "Current Channel: telegram") + } + + var whatsappReply string + handledWhatsApp := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/show channel", + Reply: func(text string) error { + whatsappReply = text + return nil + }, + }) + if handledWhatsApp.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /show outcome=%v, want=%v", handledWhatsApp.Outcome, OutcomeHandled) + } + if handledWhatsApp.Command != "show" { + t.Fatalf("whatsapp /show command=%q, want=%q", handledWhatsApp.Command, "show") + } + if whatsappReply != "Current Channel: whatsapp" { + t.Fatalf("whatsapp /show reply=%q, want=%q", whatsappReply, "Current Channel: whatsapp") + } + + passthrough := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/foo", + }) + if passthrough.Outcome != OutcomePassthrough { + 
t.Fatalf("whatsapp /foo outcome=%v, want=%v", passthrough.Outcome, OutcomePassthrough) + } + if passthrough.Command != "foo" { + t.Fatalf("whatsapp /foo command=%q, want=%q", passthrough.Command, "foo") + } +} + +func TestShowListHandlers_ListHandledOnAllChannels(t *testing.T) { + rt := &Runtime{ + GetEnabledChannels: func() []string { + return []string{"telegram"} + }, + } + ex := NewExecutor(NewRegistry(BuiltinDefinitions()), rt) + + var reply string + res := ex.Execute(context.Background(), Request{ + Channel: "whatsapp", + Text: "/list channels", + Reply: func(text string) error { + reply = text + return nil + }, + }) + if res.Outcome != OutcomeHandled { + t.Fatalf("whatsapp /list outcome=%v, want=%v", res.Outcome, OutcomeHandled) + } + if res.Command != "list" { + t.Fatalf("whatsapp /list command=%q, want=%q", res.Command, "list") + } + if !strings.Contains(reply, "telegram") { + t.Fatalf("whatsapp /list reply=%q, expected enabled channels content", reply) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index b517d8c70..73733be64 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -167,22 +167,35 @@ type SessionConfig struct { IdentityLinks map[string][]string `json:"identity_links,omitempty"` } +// RoutingConfig controls the intelligent model routing feature. +// When enabled, each incoming message is scored against structural features +// (message length, code blocks, tool call history, conversation depth, attachments). +// Messages scoring below Threshold are sent to LightModel; all others use the +// agent's primary model. This reduces cost and latency for simple tasks without +// requiring any keyword matching — all scoring is language-agnostic. 
+type RoutingConfig struct { + Enabled bool `json:"enabled"` + LightModel string `json:"light_model"` // model_name from model_list to use for simple tasks + Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model +} + type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead - ModelFallbacks []string `json:"model_fallbacks,omitempty"` - ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` - ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` - SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` - SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` - MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" 
env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Routing *RoutingConfig `json:"routing,omitempty"` } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB @@ -218,6 +231,7 @@ type ChannelsConfig struct { WeComApp WeComAppConfig `json:"wecom_app"` WeComAIBot WeComAIBotConfig `json:"wecom_aibot"` Pico PicoConfig `json:"pico"` + IRC IRCConfig `json:"irc"` } // GroupTriggerConfig controls when the bot responds in group chats. 
@@ -402,6 +416,25 @@ type PicoConfig struct { Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } +type IRCConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_IRC_ENABLED"` + Server string `json:"server" env:"PICOCLAW_CHANNELS_IRC_SERVER"` + TLS bool `json:"tls" env:"PICOCLAW_CHANNELS_IRC_TLS"` + Nick string `json:"nick" env:"PICOCLAW_CHANNELS_IRC_NICK"` + User string `json:"user,omitempty" env:"PICOCLAW_CHANNELS_IRC_USER"` + RealName string `json:"real_name,omitempty" env:"PICOCLAW_CHANNELS_IRC_REAL_NAME"` + Password string `json:"password" env:"PICOCLAW_CHANNELS_IRC_PASSWORD"` + NickServPassword string `json:"nickserv_password" env:"PICOCLAW_CHANNELS_IRC_NICKSERV_PASSWORD"` + SASLUser string `json:"sasl_user" env:"PICOCLAW_CHANNELS_IRC_SASL_USER"` + SASLPassword string `json:"sasl_password" env:"PICOCLAW_CHANNELS_IRC_SASL_PASSWORD"` + Channels FlexibleStringSlice `json:"channels" env:"PICOCLAW_CHANNELS_IRC_CHANNELS"` + RequestCaps FlexibleStringSlice `json:"request_caps,omitempty" env:"PICOCLAW_CHANNELS_IRC_REQUEST_CAPS"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_IRC_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_IRC_REASONING_CHANNEL_ID"` +} + type HeartbeatConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 @@ -427,6 +460,7 @@ type ProvidersConfig struct { ShengSuanYun ProviderConfig `json:"shengsuanyun"` DeepSeek ProviderConfig `json:"deepseek"` Cerebras ProviderConfig `json:"cerebras"` + Vivgrid ProviderConfig `json:"vivgrid"` VolcEngine ProviderConfig `json:"volcengine"` GitHubCopilot ProviderConfig `json:"github_copilot"` Antigravity ProviderConfig `json:"antigravity"` @@ -452,6 +486,7 @@ func (p ProvidersConfig) IsEmpty() bool { 
p.ShengSuanYun.APIKey == "" && p.ShengSuanYun.APIBase == "" && p.DeepSeek.APIKey == "" && p.DeepSeek.APIBase == "" && p.Cerebras.APIKey == "" && p.Cerebras.APIBase == "" && + p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" && p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" && p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" && @@ -595,6 +630,7 @@ type ExecConfig struct { EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` + TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s) } type SkillsToolsConfig struct { @@ -627,6 +663,7 @@ type ToolsConfig struct { ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` ReadFile ToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` + SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"` Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"` Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"` @@ -900,6 +937,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool { return t.Subagent.Enabled case "web_fetch": return t.WebFetch.Enabled + case "send_file": + return t.SendFile.Enabled case "write_file": return t.WriteFile.Enabled case "mcp": diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index e87d7aa0a..4eef6a79e 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -261,6 +261,14 @@ func DefaultConfig() *Config { APIKey: "", }, + // Vivgrid - https://vivgrid.com + { + ModelName: 
"vivgrid-auto", + Model: "vivgrid/auto", + APIBase: "https://api.vivgrid.com/v1", + APIKey: "", + }, + // Volcengine (火山引擎) - https://console.volcengine.com/ark { ModelName: "doubao-pro", @@ -386,6 +394,7 @@ func DefaultConfig() *Config { Enabled: true, }, EnableDenyPatterns: true, + TimeoutSeconds: 60, }, Skills: SkillsToolsConfig{ ToolConfig: ToolConfig{ @@ -403,6 +412,9 @@ func DefaultConfig() *Config { TTLSeconds: 300, }, }, + SendFile: ToolConfig{ + Enabled: true, + }, MCP: MCPConfig{ ToolConfig: ToolConfig{ Enabled: false, diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 4a17dd6c9..51f21e4f4 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -292,6 +292,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"vivgrid"}, + protocol: "vivgrid", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Vivgrid.APIKey == "" && p.Vivgrid.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "vivgrid", + Model: "vivgrid/auto", + APIKey: p.Vivgrid.APIKey, + APIBase: p.Vivgrid.APIBase, + Proxy: p.Vivgrid.Proxy, + RequestTimeout: p.Vivgrid.RequestTimeout, + }, true + }, + }, { providerNames: []string{"volcengine", "doubao"}, protocol: "volcengine", diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 67ad73db9..d3019aab0 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -155,7 +155,8 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { ShengSuanYun: ProviderConfig{APIKey: "key11"}, DeepSeek: ProviderConfig{APIKey: "key12"}, Cerebras: ProviderConfig{APIKey: "key13"}, - VolcEngine: ProviderConfig{APIKey: "key14"}, + Vivgrid: ProviderConfig{APIKey: "key14"}, + VolcEngine: ProviderConfig{APIKey: "key15"}, GitHubCopilot: ProviderConfig{ConnectMode: "grpc"}, Antigravity: ProviderConfig{AuthMethod: "oauth"}, Qwen: ProviderConfig{APIKey: "key17"}, @@ -166,9 +167,9 @@ func 
TestConvertProvidersToModelList_AllProviders(t *testing.T) { result := ConvertProvidersToModelList(cfg) - // All 20 providers should be converted - if len(result) != 20 { - t.Errorf("len(result) = %d, want 20", len(result)) + // All 21 providers should be converted + if len(result) != 21 { + t.Errorf("len(result) = %d, want 21", len(result)) } } diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 6962041c1..04775ac42 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -190,14 +190,21 @@ func (cs *CronService) executeJobByID(jobID string) { cs.mu.RUnlock() if callbackJob == nil { + log.Printf("[cron] job %s not found, skipping", jobID) return } + // Log job execution start + log.Printf("[cron] ▶ executing job '%s' (id: %s, schedule: %s, channel: %s)", + callbackJob.Name, jobID, callbackJob.Schedule.Kind, callbackJob.Payload.Channel) + var err error if cs.onJob != nil { _, err = cs.onJob(callbackJob) } + execDuration := time.Now().UnixMilli() - startTime + // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() @@ -220,22 +227,35 @@ func (cs *CronService) executeJobByID(jobID string) { if err != nil { job.State.LastStatus = "error" job.State.LastError = err.Error() + log.Printf("[cron] ✗ job '%s' failed after %dms: %v", job.Name, execDuration, err) } else { job.State.LastStatus = "ok" job.State.LastError = "" } // Compute next run time + var nextRunStr string if job.Schedule.Kind == "at" { if job.DeleteAfterRun { cs.removeJobUnsafe(job.ID) + nextRunStr = "(deleted)" } else { job.Enabled = false job.State.NextRunAtMS = nil + nextRunStr = "(disabled)" } } else { nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) job.State.NextRunAtMS = nextRun + if nextRun != nil { + nextRunStr = time.UnixMilli(*nextRun).Format("2006-01-02 15:04:05") + } else { + nextRunStr = "(none)" + } + } + + if err == nil { + log.Printf("[cron] ✓ job '%s' completed in %dms, next run: %s", job.Name, execDuration, nextRunStr) } if err := 
cs.saveStoreUnsafe(); err != nil { diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 1b250b9b4..242ded175 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -23,7 +23,10 @@ type ( ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition ) -const defaultBaseURL = "https://api.anthropic.com" +const ( + defaultBaseURL = "https://api.anthropic.com" + anthropicBetaHeader = "oauth-2025-04-20" +) type Provider struct { client *anthropic.Client @@ -80,7 +83,10 @@ func (p *Provider) Chat( if err != nil { return nil, fmt.Errorf("refreshing token: %w", err) } - opts = append(opts, option.WithAuthToken(tok)) + opts = append(opts, + option.WithAuthToken(tok), + option.WithHeader("anthropic-beta", anthropicBetaHeader), + ) } params, err := buildParams(messages, tools, model, options) @@ -88,6 +94,11 @@ func (p *Provider) Chat( return nil, err } + // OAuth/setup-tokens require streaming; API keys use non-streaming. + if p.tokenSource != nil { + return p.chatStreaming(ctx, params, opts) + } + resp, err := p.client.Messages.New(ctx, params, opts...) if err != nil { return nil, fmt.Errorf("claude API call: %w", err) @@ -96,6 +107,28 @@ func (p *Provider) Chat( return parseResponse(resp), nil } +func (p *Provider) chatStreaming( + ctx context.Context, + params anthropic.MessageNewParams, + opts []option.RequestOption, +) (*LLMResponse, error) { + stream := p.client.Messages.NewStreaming(ctx, params, opts...) 
+ defer stream.Close() + + var msg anthropic.Message + for stream.Next() { + event := stream.Current() + if err := msg.Accumulate(event); err != nil { + return nil, fmt.Errorf("claude streaming accumulate: %w", err) + } + } + if err := stream.Err(); err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(&msg), nil +} + func (p *Provider) GetDefaultModel() string { return "claude-sonnet-4.6" } @@ -147,7 +180,16 @@ func buildParams( blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) } for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + args := tc.Arguments + if args == nil && tc.Function != nil && tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + args = map[string]any{} + } + } + if args == nil { + args = map[string]any{} + } + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, args, tc.Name)) } anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) } else { @@ -167,8 +209,12 @@ func buildParams( maxTokens = int64(mt) } + // Normalize model ID: Anthropic API uses hyphens (claude-sonnet-4-6), + // but config may use dots (claude-sonnet-4.6). 
+ apiModel := strings.ReplaceAll(model, ".", "-") + params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), + Model: anthropic.Model(apiModel), Messages: anthropicMessages, MaxTokens: maxTokens, } diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go index 3d21c1d0b..b1aed17b5 100644 --- a/pkg/providers/anthropic/provider_test.go +++ b/pkg/providers/anthropic/provider_test.go @@ -21,8 +21,8 @@ func TestBuildParams_BasicMessage(t *testing.T) { if err != nil { t.Fatalf("buildParams() error: %v", err) } - if string(params.Model) != "claude-sonnet-4.6" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4.6") + if string(params.Model) != "claude-sonnet-4-6" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-6") } if params.MaxTokens != 1024 { t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) @@ -262,6 +262,65 @@ func TestProvider_ChatUsesTokenSource(t *testing.T) { } } +func TestProvider_ChatStreamingRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + t.Errorf("Authorization = %q, want %q", got, "Bearer refreshed-token") + } + if got := r.Header.Get("Anthropic-Beta"); got != anthropicBetaHeader { + t.Errorf("Anthropic-Beta = %q, want %q", got, anthropicBetaHeader) + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + events := []string{ + "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"usage\":{\"input_tokens\":12,\"output_tokens\":0}}}\n\n", + "event: content_block_start\ndata: 
{\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n", + "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n", + "event: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "event: message_delta\ndata: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n", + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", + } + for _, e := range events { + w.Write([]byte(e)) + if flusher != nil { + flusher.Flush() + } + } + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + resp, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "Hello"}}, + nil, + "claude-sonnet-4.6", + map[string]any{}, + ) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello world" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello world") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.CompletionTokens != 5 { + t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens) + } +} + func createAnthropicTestClient(baseURL, token string) *anthropic.Client { c := anthropic.NewClient( anthropicoption.WithAuthToken(token), diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index a0d09a835..25916ad03 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -153,6 +153,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = "https://integrate.api.nvidia.com/v1" } } + case "vivgrid": + if cfg.Providers.Vivgrid.APIKey != "" 
{ + sel.apiKey = cfg.Providers.Vivgrid.APIKey + sel.apiBase = cfg.Providers.Vivgrid.APIBase + sel.proxy = cfg.Providers.Vivgrid.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.vivgrid.com/v1" + } + } case "claude-cli", "claude-code", "claudecode": workspace := cfg.WorkspacePath() if workspace == "" { @@ -295,6 +304,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if sel.apiBase == "" { sel.apiBase = "https://integrate.api.nvidia.com/v1" } + case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "": + sel.apiKey = cfg.Providers.Vivgrid.APIKey + sel.apiBase = cfg.Providers.Vivgrid.APIBase + sel.proxy = cfg.Providers.Vivgrid.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.vivgrid.com/v1" + } case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": sel.apiKey = cfg.Providers.Ollama.APIKey sel.apiBase = cfg.Providers.Ollama.APIBase diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index c05fb0ad4..941985964 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", - "volcengine", "vllm", "qwen", "mistral", "avian": + "vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian": // All other OpenAI-compatible HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) @@ -200,6 +200,8 @@ func getDefaultAPIBase(protocol string) string { return "https://api.deepseek.com/v1" case "cerebras": return "https://api.cerebras.ai/v1" + case "vivgrid": + return "https://api.vivgrid.com/v1" case "volcengine": return "https://ark.cn-beijing.volces.com/api/v3" 
case "qwen": diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 78389f331..17bc55d25 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -108,6 +108,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { {"groq", "groq"}, {"openrouter", "openrouter"}, {"cerebras", "cerebras"}, + {"vivgrid", "vivgrid"}, {"qwen", "qwen"}, {"vllm", "vllm"}, {"deepseek", "deepseek"}, diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index f7a916d9e..36ccda4a1 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -88,6 +88,17 @@ func TestResolveProviderSelection(t *testing.T) { wantAPIBase: "https://integrate.api.nvidia.com/v1", wantProxy: "http://127.0.0.1:7890", }, + { + name: "explicit vivgrid provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "vivgrid" + cfg.Providers.Vivgrid.APIKey = "vivgrid-key" + cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.vivgrid.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, { name: "openrouter model uses openrouter defaults", setup: func(cfg *config.Config) { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 1904ee153..5c868626a 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -1,6 +1,7 @@ package openai_compat import ( + "bufio" "bytes" "context" "encoding/json" @@ -183,19 +184,94 @@ func (p *Provider) Chat( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } + contentType := resp.Header.Get("Content-Type") + // Non-200: read a prefix to tell HTML error page apart from JSON error body. 
if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256)) + if readErr != nil { + return nil, fmt.Errorf("failed to read response: %w", readErr) + } + if looksLikeHTML(body, contentType) { + return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase) + } + return nil, fmt.Errorf( + "API request failed:\n Status: %d\n Body: %s", + resp.StatusCode, + responsePreview(body, 128), + ) } - return parseResponse(body) + // Peek without consuming so the full stream reaches the JSON decoder. + reader := bufio.NewReader(resp.Body) + prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort + if err != nil && err != io.EOF && err != bufio.ErrBufferFull { + return nil, fmt.Errorf("failed to inspect response: %w", err) + } + if looksLikeHTML(prefix, contentType) { + return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase) + } + + out, err := parseResponse(reader) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return out, nil } -func parseResponse(body []byte) (*LLMResponse, error) { +func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error { + respPreview := responsePreview(body, 128) + return fmt.Errorf( + "API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s", + apiBase, + contentType, + statusCode, + respPreview, + ) +} + +func looksLikeHTML(body []byte, contentType string) bool { + contentType = strings.ToLower(strings.TrimSpace(contentType)) + if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") { + return true + } + prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128)) + return bytes.HasPrefix(prefix, []byte(" len(body) { + end = len(body) 
+ } + return body[i:end] + } + } + return nil +} + +func responsePreview(body []byte, maxLen int) string { + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 { + return "" + } + if len(trimmed) <= maxLen { + return string(trimmed) + } + return string(trimmed[:maxLen]) + "..." +} + +func parseResponse(body io.Reader) (*LLMResponse, error) { var apiResponse struct { Choices []struct { Message struct { @@ -222,8 +298,8 @@ func parseResponse(body []byte) (*LLMResponse, error) { Usage *UsageInfo `json:"usage"` } - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + if err := json.NewDecoder(body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) } if len(apiResponse.Choices) == 0 { @@ -363,7 +439,8 @@ func normalizeModel(model, apiBase string) string { prefix := strings.ToLower(before) switch prefix { - case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral": + case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", + "openrouter", "zhipu", "mistral", "vivgrid": return after default: return model diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 174bcf00d..9a3a7acc5 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -1,7 +1,10 @@ package openai_compat import ( + "bytes" "encoding/json" + "fmt" + "io" "net/http" "net/http/httptest" "net/url" @@ -212,6 +215,132 @@ func TestProviderChat_HTTPError(t *testing.T) { } } +func TestProviderChat_JSONHTTPErrorDoesNotReportHTML(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"bad request"}`)) + })) + defer server.Close() + 
+ p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "Status: 400") { + t.Fatalf("expected status code in error, got %v", err) + } + if strings.Contains(err.Error(), "returned HTML instead of JSON") { + t.Fatalf("expected non-HTML http error, got %v", err) + } +} + +func TestProviderChat_HTMLResponsesReturnHelpfulError(t *testing.T) { + tests := []struct { + name string + contentType string + statusCode int + body string + }{ + { + name: "html success response", + contentType: "text/html; charset=utf-8", + statusCode: http.StatusOK, + body: "gateway login", + }, + { + name: "html error response", + contentType: "text/html; charset=utf-8", + statusCode: http.StatusBadGateway, + body: "bad gateway", + }, + { + name: "mislabeled html success response", + contentType: "application/json", + statusCode: http.StatusOK, + body: " \r\n\tgateway login", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), fmt.Sprintf("Status: %d", tt.statusCode)) { + t.Fatalf("expected status code in error, got %v", err) + } + if !strings.Contains(err.Error(), "returned HTML instead of JSON") { + t.Fatalf("expected helpful HTML error, got %v", err) + } + if !strings.Contains(err.Error(), "check api_base or proxy configuration") { + t.Fatalf("expected configuration hint, got %v", err) + } + }) + } +} + +func 
TestProviderChat_SuccessResponseUsesStreamingDecoder(t *testing.T) { + content := strings.Repeat("a", 1024) + body := `{"choices":[{"message":{"content":"` + content + `"},"finish_reason":"stop"}]}` + + p := NewProvider("key", "https://example.com/v1", "") + p.httpClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: &errAfterDataReadCloser{ + data: []byte(body), + chunkSize: 64, + }, + }, nil + }), + } + + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if out.Content != content { + t.Fatalf("Content = %q, want %q", out.Content, content) + } +} + +func TestProviderChat_LargeHTMLResponsePreviewIsTruncated(t *testing.T) { + body := append([]byte(""), bytes.Repeat([]byte("A"), 2048)...) + body = append(body, []byte("")...) 
+ + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write(body) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "Body: ") { + t.Fatalf("expected html preview in error, got %v", err) + } + if !strings.Contains(err.Error(), "...") { + t.Fatalf("expected truncated preview, got %v", err) + } +} + func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { var requestBody map[string]any @@ -253,7 +382,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin } } -func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { +func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) { tests := []struct { name string input string @@ -279,6 +408,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { input: "deepseek/deepseek-chat", wantModel: "deepseek-chat", }, + { + name: "strips vivgrid prefix", + input: "vivgrid/auto", + wantModel: "auto", + }, } for _, tt := range tests { @@ -383,6 +517,12 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) { if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" { t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") } + if got := normalizeModel("vivgrid/managed", "https://api.vivgrid.com/v1"); got != "managed" { + t.Fatalf("normalizeModel(vivgrid) = %q, want %q", got, "managed") + } + if got := normalizeModel("vivgrid/auto", "https://api.vivgrid.com/v1"); got != "auto" { + t.Fatalf("normalizeModel(vivgrid auto) = %q, want %q", got, "auto") + } } func TestProvider_RequestTimeoutDefault(t *testing.T) { @@ 
-399,6 +539,40 @@ func TestProvider_RequestTimeoutOverride(t *testing.T) { } } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +type errAfterDataReadCloser struct { + data []byte + chunkSize int + offset int +} + +func (r *errAfterDataReadCloser) Read(p []byte) (int, error) { + if r.offset >= len(r.data) { + return 0, io.ErrUnexpectedEOF + } + + n := r.chunkSize + if n <= 0 || n > len(p) { + n = len(p) + } + remaining := len(r.data) - r.offset + if n > remaining { + n = remaining + } + copy(p, r.data[r.offset:r.offset+n]) + r.offset += n + return n, nil +} + +func (r *errAfterDataReadCloser) Close() error { + return nil +} + func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) { p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens")) if p.maxTokensField != "max_completion_tokens" { diff --git a/pkg/routing/classifier.go b/pkg/routing/classifier.go new file mode 100644 index 000000000..8cddaf069 --- /dev/null +++ b/pkg/routing/classifier.go @@ -0,0 +1,80 @@ +package routing + +// Classifier evaluates a feature set and returns a complexity score in [0, 1]. +// A higher score indicates a more complex task that benefits from a heavy model. +// The score is compared against the configured threshold: score >= threshold selects +// the primary (heavy) model; score < threshold selects the light model. +// +// Classifier is an interface so that future implementations (ML-based, embedding-based, +// or any other approach) can be swapped in without changing routing infrastructure. +type Classifier interface { + Score(f Features) float64 +} + +// RuleClassifier is the v1 implementation. +// It uses a weighted sum of structural signals with no external dependencies, +// no API calls, and sub-microsecond latency. 
The raw sum is capped at 1.0 so +// that the returned score always falls within the [0, 1] contract. +// +// Individual weights (multiple signals can fire simultaneously): +// +// token > 200 (≈800 non-CJK chars): 0.35 — very long prompts are almost always complex +// token 50-200: 0.15 — medium length; may or may not be complex +// code block present: 0.40 — coding tasks need the heavy model +// tool calls > 3 (recent): 0.25 — dense tool usage signals an agentic workflow +// tool calls 1-3 (recent): 0.10 — some tool activity +// conversation depth > 10: 0.10 — long sessions carry implicit complexity +// attachments present: 1.00 — hard gate; multi-modal always needs heavy model +// +// Default threshold is 0.35, so: +// - Pure greetings / trivial Q&A: 0.00 → light ✓ +// - Medium prose message (50–200 tokens): 0.15 → light ✓ +// - Message with code block: 0.40 → heavy ✓ +// - Long message (>200 tokens): 0.35 → heavy ✓ +// - Active tool session (>3 calls) + short message: 0.25 → light (acceptable) +// - Any message with an image/audio attachment: 1.00 → heavy ✓ +type RuleClassifier struct{} + +// Score computes the complexity score for the given feature set. +// The returned value is in [0, 1]. Attachments short-circuit to 1.0. +func (c *RuleClassifier) Score(f Features) float64 { + // Hard gate: multi-modal inputs always require the heavy model.
+ if f.HasAttachments { + return 1.0 + } + + var score float64 + + // Token estimate — primary verbosity signal + switch { + case f.TokenEstimate > 200: + score += 0.35 + case f.TokenEstimate > 50: + score += 0.15 + } + + // Fenced code blocks — strongest indicator of a coding/technical task + if f.CodeBlockCount > 0 { + score += 0.40 + } + + // Recent tool call density — indicates an ongoing agentic workflow + switch { + case f.RecentToolCalls > 3: + score += 0.25 + case f.RecentToolCalls > 0: + score += 0.10 + } + + // Conversation depth — accumulated context implies compound task + if f.ConversationDepth > 10 { + score += 0.10 + } + + // Cap at 1.0 to honor the [0, 1] contract even when multiple signals fire + // simultaneously (e.g., long message + code block + tool chain = 1.10 raw). + if score > 1.0 { + score = 1.0 + } + return score +} diff --git a/pkg/routing/features.go b/pkg/routing/features.go new file mode 100644 index 000000000..c371e21aa --- /dev/null +++ b/pkg/routing/features.go @@ -0,0 +1,127 @@ +package routing + +import ( + "strings" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// lookbackWindow is the number of recent history entries scanned for tool calls. +// Six entries covers roughly one full tool-use round-trip (user → assistant+tool_call → tool_result → assistant). +const lookbackWindow = 6 + +// Features holds the structural signals extracted from a message and its session context. +// Every dimension is language-agnostic by construction — no keyword or pattern matching +// against natural-language content. This ensures consistent routing for all locales. +type Features struct { + // TokenEstimate is a proxy for token count. + // CJK runes count as 1 token each; non-CJK runes as 0.25 tokens each. + // This avoids API calls while giving accurate estimates for all scripts. + TokenEstimate int + + // CodeBlockCount is the number of fenced code blocks (``` pairs) in the message. 
+ // Coding tasks almost always require the heavy model. + CodeBlockCount int + + // RecentToolCalls is the total number of tool calls attached to messages in the last + // lookbackWindow history entries. A high density indicates an active agentic workflow. + RecentToolCalls int + + // ConversationDepth is the total number of messages in the session history. + // Deep sessions tend to carry implicit complexity built up over many turns. + ConversationDepth int + + // HasAttachments is true when the message appears to contain media (images, + // audio, video). Multi-modal inputs require vision-capable heavy models. + HasAttachments bool +} + +// ExtractFeatures computes the structural feature vector for a message. +// It is a pure function with no side effects. Attachment detection lower-cases +// msg, which may allocate a temporary copy. +func ExtractFeatures(msg string, history []providers.Message) Features { + return Features{ + TokenEstimate: estimateTokens(msg), + CodeBlockCount: countCodeBlocks(msg), + RecentToolCalls: countRecentToolCalls(history), + ConversationDepth: len(history), + HasAttachments: hasAttachments(msg), + } +} + +// estimateTokens returns a token count proxy that handles both CJK and Latin text. +// CJK runes (U+2E80–U+9FFF, U+F900–U+FAFF, U+AC00–U+D7AF) map to roughly one +// token each, while non-CJK runes average ~0.25 tokens/rune (≈4 chars per token +// for English). Splitting the count this way avoids the 3x underestimation that a +// flat rune_count/3 would produce for Chinese, Japanese, and Korean text. +func estimateTokens(msg string) int { + total := utf8.RuneCountInString(msg) + if total == 0 { + return 0 + } + cjk := 0 + for _, r := range msg { + if r >= 0x2E80 && r <= 0x9FFF || r >= 0xF900 && r <= 0xFAFF || r >= 0xAC00 && r <= 0xD7AF { + cjk++ + } + } + return cjk + (total-cjk)/4 +} + +// countCodeBlocks counts the number of complete fenced code blocks. +// Each ``` delimiter increments a counter; pairs of delimiters form one block.
+// An unmatched fence (odd count) is dropped by the integer division, so a lone +// opening fence yields zero blocks; it may just be an inline code span or a typo. +func countCodeBlocks(msg string) int { + n := strings.Count(msg, "```") + return n / 2 +} + +// countRecentToolCalls sums the tool calls attached to messages in the last lookbackWindow +// entries of history. It examines the ToolCalls field rather than parsing +// the content string, so it is robust to any message format. +func countRecentToolCalls(history []providers.Message) int { + start := len(history) - lookbackWindow + if start < 0 { + start = 0 + } + + count := 0 + for _, msg := range history[start:] { + if len(msg.ToolCalls) > 0 { + count += len(msg.ToolCalls) + } + } + return count +} + +// hasAttachments returns true when the message content contains embedded media. +// It checks for base64 data URIs (data:image/, data:audio/, data:video/) and +// common image/audio/video file extensions anywhere in the text. This is intentionally +// liberal — false positives (e.g. a file name merely mentioned in prose) just mean the +// routing falls back to the primary model anyway.
+func hasAttachments(msg string) bool { + lower := strings.ToLower(msg) + + // Base64 data URIs embedded directly in the message + if strings.Contains(lower, "data:image/") || + strings.Contains(lower, "data:audio/") || + strings.Contains(lower, "data:video/") { + return true + } + + // Common image/audio/video extensions in URLs or file references + mediaExts := []string{ + ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", + ".mp3", ".wav", ".ogg", ".m4a", ".flac", + ".mp4", ".avi", ".mov", ".webm", + } + for _, ext := range mediaExts { + if strings.Contains(lower, ext) { + return true + } + } + + return false +} diff --git a/pkg/routing/router.go b/pkg/routing/router.go new file mode 100644 index 000000000..b1fa347e9 --- /dev/null +++ b/pkg/routing/router.go @@ -0,0 +1,82 @@ +package routing + +import ( + "github.com/sipeed/picoclaw/pkg/providers" +) + +// defaultThreshold is used when the config threshold is zero or negative. +// At 0.35 a message needs one strong signal (code block, long text, or an +// attachment) or several weaker ones combined before the heavy model is chosen. +const defaultThreshold = 0.35 + +// RouterConfig holds the validated model routing settings. +// It mirrors config.RoutingConfig but lives in pkg/routing to keep the +// dependency graph simple: pkg/agent resolves config → routing, not the reverse. +type RouterConfig struct { + // LightModel is the model_name (from model_list) used for simple tasks. + LightModel string + + // Threshold is the complexity score cutoff in [0, 1]. + // score >= Threshold → primary (heavy) model. + // score < Threshold → light model. + Threshold float64 +} + +// Router selects the appropriate model tier for each incoming message. +// It is safe for concurrent use from multiple goroutines. +type Router struct { + cfg RouterConfig + classifier Classifier +} + +// New creates a Router with the given config and the default RuleClassifier. +// If cfg.Threshold is zero or negative, defaultThreshold (0.35) is used.
+func New(cfg RouterConfig) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{ + cfg: cfg, + classifier: &RuleClassifier{}, + } +} + +// newWithClassifier creates a Router with a custom Classifier. +// Intended for unit tests that need to inject a deterministic scorer. +func newWithClassifier(cfg RouterConfig, c Classifier) *Router { + if cfg.Threshold <= 0 { + cfg.Threshold = defaultThreshold + } + return &Router{cfg: cfg, classifier: c} +} + +// SelectModel returns the model to use for this conversation turn along with +// the computed complexity score (for logging and debugging). +// +// - If score < cfg.Threshold: returns (cfg.LightModel, true, score) +// - Otherwise: returns (primaryModel, false, score) +// +// The caller is responsible for resolving the returned model name into +// provider candidates (see AgentInstance.LightCandidates). +func (r *Router) SelectModel( + msg string, + history []providers.Message, + primaryModel string, +) (model string, usedLight bool, score float64) { + features := ExtractFeatures(msg, history) + score = r.classifier.Score(features) + if score < r.cfg.Threshold { + return r.cfg.LightModel, true, score + } + return primaryModel, false, score +} + +// LightModel returns the configured light model name. +func (r *Router) LightModel() string { + return r.cfg.LightModel +} + +// Threshold returns the complexity threshold in use. 
+func (r *Router) Threshold() float64 { + return r.cfg.Threshold +} diff --git a/pkg/routing/router_test.go b/pkg/routing/router_test.go new file mode 100644 index 000000000..2824d10ab --- /dev/null +++ b/pkg/routing/router_test.go @@ -0,0 +1,414 @@ +package routing + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// ── ExtractFeatures ────────────────────────────────────────────────────────── + +func TestExtractFeatures_EmptyMessage(t *testing.T) { + f := ExtractFeatures("", nil) + if f.TokenEstimate != 0 { + t.Errorf("TokenEstimate: got %d, want 0", f.TokenEstimate) + } + if f.CodeBlockCount != 0 { + t.Errorf("CodeBlockCount: got %d, want 0", f.CodeBlockCount) + } + if f.RecentToolCalls != 0 { + t.Errorf("RecentToolCalls: got %d, want 0", f.RecentToolCalls) + } + if f.ConversationDepth != 0 { + t.Errorf("ConversationDepth: got %d, want 0", f.ConversationDepth) + } + if f.HasAttachments { + t.Error("HasAttachments: got true, want false") + } +} + +func TestExtractFeatures_TokenEstimate(t *testing.T) { + // 30 ASCII runes: 0 CJK + 30/4 = 7 tokens + msg := strings.Repeat("a", 30) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 7 { + t.Errorf("TokenEstimate: got %d, want 7", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_CJK(t *testing.T) { + // 9 CJK runes → 9 tokens (each CJK rune ≈ 1 token). + // Using a rune slice literal avoids CJK string literals in source. + msg := string([]rune{ + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, 0x597D, 0x4E16, 0x754C, + 0x4F60, + }) + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 9 { + t.Errorf("CJK TokenEstimate: got %d, want 9", f.TokenEstimate) + } +} + +func TestExtractFeatures_TokenEstimate_Mixed(t *testing.T) { + // Mixed: 4 CJK runes + 8 ASCII runes → 4 + 8/4 = 6 tokens. 
+ msg := string([]rune{0x4F60, 0x597D, 0x4E16, 0x754C}) + "hello ok" + f := ExtractFeatures(msg, nil) + if f.TokenEstimate != 6 { + t.Errorf("Mixed TokenEstimate: got %d, want 6", f.TokenEstimate) + } +} + +func TestExtractFeatures_CodeBlocks(t *testing.T) { + cases := []struct { + msg string + want int + }{ + {"no code here", 0}, + {"```go\nfmt.Println()\n```", 1}, + {"```python\npass\n```\n```js\nconsole.log()\n```", 2}, + {"```unclosed", 0}, // odd number of fences = 0 complete blocks + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.CodeBlockCount != tc.want { + t.Errorf("msg=%q: CodeBlockCount got %d, want %d", tc.msg, f.CodeBlockCount, tc.want) + } + } +} + +func TestExtractFeatures_RecentToolCalls(t *testing.T) { + // History longer than lookbackWindow — only last lookbackWindow entries count. + history := make([]providers.Message, 10) + // Put 2 tool calls at positions 8 and 9 (within the last 6) + history[8] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}}} + history[9] = providers.Message{ + Role: "assistant", + ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}, + } + // Position 3 is outside the lookback window and must NOT be counted + history[3] = providers.Message{Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "old_tool"}}} + + f := ExtractFeatures("test", history) + // 1 (position 8) + 2 (position 9) = 3 + if f.RecentToolCalls != 3 { + t.Errorf("RecentToolCalls: got %d, want 3", f.RecentToolCalls) + } +} + +func TestExtractFeatures_ConversationDepth(t *testing.T) { + history := make([]providers.Message, 7) + f := ExtractFeatures("msg", history) + if f.ConversationDepth != 7 { + t.Errorf("ConversationDepth: got %d, want 7", f.ConversationDepth) + } +} + +func TestExtractFeatures_HasAttachments_DataURI(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"plain text", false}, + {"here is an image: data:image/png;base64,abc123", true}, + 
{"audio: data:audio/mp3;base64,xyz", true}, + {"video: data:video/mp4;base64,xyz", true}, + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +func TestExtractFeatures_HasAttachments_Extension(t *testing.T) { + cases := []struct { + msg string + want bool + }{ + {"check out photo.jpg", true}, + {"see screenshot.png", true}, + {"listen to audio.mp3", true}, + {"watch clip.mp4", true}, + {"just a .go file", false}, + {"document.pdf", false}, // pdf is not in the media list + } + for _, tc := range cases { + f := ExtractFeatures(tc.msg, nil) + if f.HasAttachments != tc.want { + t.Errorf("msg=%q: HasAttachments got %v, want %v", tc.msg, f.HasAttachments, tc.want) + } + } +} + +// ── RuleClassifier ─────────────────────────────────────────────────────────── + +func TestRuleClassifier_ZeroFeatures(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{}) + if score != 0.0 { + t.Errorf("zero features: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_AttachmentsHardGate(t *testing.T) { + c := &RuleClassifier{} + score := c.Score(Features{HasAttachments: true}) + if score != 1.0 { + t.Errorf("attachments: got %f, want 1.0", score) + } +} + +func TestRuleClassifier_CodeBlockAlone(t *testing.T) { + c := &RuleClassifier{} + // Code block alone = 0.40, above default threshold 0.35 + score := c.Score(Features{CodeBlockCount: 1}) + if score < 0.35 { + t.Errorf("code block: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_LongMessage(t *testing.T) { + c := &RuleClassifier{} + // >200 tokens = 0.35, exactly at default threshold → heavy + score := c.Score(Features{TokenEstimate: 250}) + if score < 0.35 { + t.Errorf("long message: score %f is below default threshold 0.35", score) + } +} + +func TestRuleClassifier_MediumMessage(t *testing.T) { + c := &RuleClassifier{} + // 
50-200 tokens = 0.15, below threshold → light + score := c.Score(Features{TokenEstimate: 100}) + if score >= 0.35 { + t.Errorf("medium message: score %f should be below default threshold 0.35", score) + } +} + +func TestRuleClassifier_ShortMessage(t *testing.T) { + c := &RuleClassifier{} + // <50 tokens, no other signals = 0.0 → light + score := c.Score(Features{TokenEstimate: 10}) + if score != 0.0 { + t.Errorf("short message: got %f, want 0.0", score) + } +} + +func TestRuleClassifier_ToolCallDensity(t *testing.T) { + c := &RuleClassifier{} + + scoreNone := c.Score(Features{RecentToolCalls: 0}) + scoreLow := c.Score(Features{RecentToolCalls: 2}) + scoreHigh := c.Score(Features{RecentToolCalls: 5}) + + if scoreNone != 0.0 { + t.Errorf("no tools: got %f, want 0.0", scoreNone) + } + if scoreLow <= scoreNone { + t.Errorf("low tools should score higher than none: %f vs %f", scoreLow, scoreNone) + } + if scoreHigh <= scoreLow { + t.Errorf("high tools should score higher than low: %f vs %f", scoreHigh, scoreLow) + } +} + +func TestRuleClassifier_DeepConversation(t *testing.T) { + c := &RuleClassifier{} + shallow := c.Score(Features{ConversationDepth: 5}) + deep := c.Score(Features{ConversationDepth: 15}) + if deep <= shallow { + t.Errorf("deep conversation should score higher: %f vs %f", deep, shallow) + } +} + +func TestRuleClassifier_ScoreDoesNotExceedOne(t *testing.T) { + c := &RuleClassifier{} + // Max all signals simultaneously + f := Features{ + TokenEstimate: 500, + CodeBlockCount: 3, + RecentToolCalls: 10, + ConversationDepth: 20, + } + score := c.Score(f) + if score > 1.0 { + t.Errorf("score %f exceeds 1.0", score) + } +} + +// ── Router ─────────────────────────────────────────────────────────────────── + +func TestRouter_DefaultThreshold(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash"}) + if r.Threshold() != defaultThreshold { + t.Errorf("default threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func 
TestRouter_NegativeThresholdFallsBackToDefault(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: -0.1}) + if r.Threshold() != defaultThreshold { + t.Errorf("negative threshold: got %f, want %f", r.Threshold(), defaultThreshold) + } +} + +func TestRouter_SelectModel_SimpleMessageUsesLight(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "hi" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if !usedLight { + t.Error("simple message: expected light model to be selected") + } + if model != "gemini-flash" { + t.Errorf("simple message: model got %q, want %q", model, "gemini-flash") + } +} + +func TestRouter_SelectModel_CodeBlockUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "```go\nfmt.Println(\"hello\")\n```" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("code block: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("code block: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_AttachmentUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + msg := "can you analyze this? 
data:image/png;base64,abc123" + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("attachment: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("attachment: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_LongMessageUsesPrimary(t *testing.T) { + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + // >200 token estimate: 210 * 5 = 1050 chars → 1050/4 = 262 tokens + msg := strings.Repeat("word ", 210) + model, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("long message: expected primary model to be selected") + } + if model != "claude-sonnet-4-6" { + t.Errorf("long message: model got %q, want %q", model, "claude-sonnet-4-6") + } +} + +func TestRouter_SelectModel_DeepToolChainUsesLight(t *testing.T) { + // Tool calls alone (0.25) don't cross the 0.35 threshold — acceptable behavior. + // Routing is conservative: only promote to heavy when the signal is unambiguous.
+ r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + history := []providers.Message{ + {Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "read_file"}, {Name: "write_file"}}}, + {Role: "assistant", ToolCalls: []providers.ToolCall{{Name: "exec"}, {Name: "search"}}}, + } + msg := "ok" + _, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6") + if !usedLight { + t.Error("short message + moderate tool calls: expected light model (score 0.25 < 0.35)") + } +} + +func TestRouter_SelectModel_ToolChainPlusMediumUsesHeavy(t *testing.T) { + // Tool calls (0.25) + medium message (0.15) = 0.40 >= 0.35 → heavy + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.35}) + history := []providers.Message{ + {Role: "assistant", ToolCalls: []providers.ToolCall{ + {Name: "a"}, {Name: "b"}, {Name: "c"}, {Name: "d"}, + }}, + } + // 55 * 5 = 275 chars → 275/4 = 68 tokens (medium band, 50 < n <= 200) + msg := strings.Repeat("word ", 55) + _, usedLight, _ := r.SelectModel(msg, history, "claude-sonnet-4-6") + if usedLight { + t.Error("tool chain + medium message: expected primary model (score >= 0.35)") + } +} + +func TestRouter_SelectModel_CustomThreshold(t *testing.T) { + // Very low threshold: even a medium message triggers heavy model + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.05}) + msg := strings.Repeat("word ", 55) // medium message → 0.15 >= 0.05 + _, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if usedLight { + t.Error("low threshold: medium message should use primary model") + } +} + +func TestRouter_SelectModel_HighThreshold(t *testing.T) { + // Very high threshold: even code blocks route to light + r := New(RouterConfig{LightModel: "gemini-flash", Threshold: 0.99}) + msg := "```go\nfmt.Println()\n```" + _, usedLight, _ := r.SelectModel(msg, nil, "claude-sonnet-4-6") + if !usedLight { + t.Error("very high threshold: code block (0.40) should route to light model") + } +} + +func TestRouter_LightModel(t *testing.T) { + r :=
New(RouterConfig{LightModel: "my-fast-model", Threshold: 0.35}) + if r.LightModel() != "my-fast-model" { + t.Errorf("LightModel: got %q, want %q", r.LightModel(), "my-fast-model") + } +} + +// ── newWithClassifier (internal testing hook) ───────────────────────────────── + +type fixedScoreClassifier struct{ score float64 } + +func (f *fixedScoreClassifier) Score(_ Features) float64 { return f.score } + +func TestRouter_CustomClassifier_LowScore_SelectsLight(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.2}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if !usedLight { + t.Error("low score with custom classifier: expected light model") + } +} + +func TestRouter_CustomClassifier_HighScore_SelectsPrimary(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.8}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("high score with custom classifier: expected primary model") + } +} + +func TestRouter_CustomClassifier_ExactThreshold_SelectsPrimary(t *testing.T) { + // score == threshold → primary (uses >= comparison) + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.5}, + ) + _, usedLight, _ := r.SelectModel("anything", nil, "heavy") + if usedLight { + t.Error("score == threshold: expected primary model (>= threshold → primary)") + } +} + +func TestRouter_SelectModel_ReturnsScore(t *testing.T) { + r := newWithClassifier( + RouterConfig{LightModel: "light", Threshold: 0.5}, + &fixedScoreClassifier{score: 0.42}, + ) + _, _, score := r.SelectModel("anything", nil, "heavy") + if score != 0.42 { + t.Errorf("score: got %f, want 0.42", score) + } +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 31ac9ab88..6af0aa9e1 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -141,6 +141,12 @@ func 
(t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult everySeconds, hasEvery := args["every_seconds"].(float64) cronExpr, hasCron := args["cron_expr"].(string) + // Fix: type assertions return true for zero values, need additional validity checks + // This prevents LLMs that fill unused optional parameters with defaults (0) from triggering wrong type + hasAt = hasAt && atSeconds > 0 + hasEvery = hasEvery && everySeconds > 0 + hasCron = hasCron && cronExpr != "" + // Priority: at_seconds > every_seconds > cron_expr if hasAt { atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 diff --git a/pkg/tools/send_file.go b/pkg/tools/send_file.go new file mode 100644 index 000000000..1a03e58ed --- /dev/null +++ b/pkg/tools/send_file.go @@ -0,0 +1,150 @@ +package tools + +import ( + "context" + "fmt" + "mime" + "os" + "path/filepath" + "strings" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +// SendFileTool allows the LLM to send a local file (image, document, etc.) +// to the user on the current chat channel via the MediaStore pipeline. +type SendFileTool struct { + workspace string + restrict bool + maxFileSize int + mediaStore media.MediaStore + + defaultChannel string + defaultChatID string +} + +func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool { + if maxFileSize <= 0 { + maxFileSize = config.DefaultMaxMediaSize + } + return &SendFileTool{ + workspace: workspace, + restrict: restrict, + maxFileSize: maxFileSize, + mediaStore: store, + } +} + +func (t *SendFileTool) Name() string { return "send_file" } +func (t *SendFileTool) Description() string { + return "Send a local file (image, document, etc.) to the user on the current chat channel." 
+} + +func (t *SendFileTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the local file. Relative paths are resolved from workspace.", + }, + "filename": map[string]any{ + "type": "string", + "description": "Optional display filename. Defaults to the basename of path.", + }, + }, + "required": []string{"path"}, + } +} + +func (t *SendFileTool) SetContext(channel, chatID string) { + t.defaultChannel = channel + t.defaultChatID = chatID +} + +func (t *SendFileTool) SetMediaStore(store media.MediaStore) { + t.mediaStore = store +} + +func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + path, _ := args["path"].(string) + if strings.TrimSpace(path) == "" { + return ErrorResult("path is required") + } + + // Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values. + channel := ToolChannel(ctx) + if channel == "" { + channel = t.defaultChannel + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = t.defaultChatID + } + if channel == "" || chatID == "" { + return ErrorResult("no target channel/chat available") + } + + if t.mediaStore == nil { + return ErrorResult("media store not configured") + } + + resolved, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return ErrorResult(fmt.Sprintf("invalid path: %v", err)) + } + + info, err := os.Stat(resolved) + if err != nil { + return ErrorResult(fmt.Sprintf("file not found: %v", err)) + } + if info.IsDir() { + return ErrorResult("path is a directory, expected a file") + } + if info.Size() > int64(t.maxFileSize) { + return ErrorResult(fmt.Sprintf( + "file too large: %d bytes (max %d bytes)", + info.Size(), t.maxFileSize, + )) + } + + filename, _ := args["filename"].(string) + if filename == "" { + filename = filepath.Base(resolved) + } + + mediaType := detectMediaType(resolved) + scope := 
fmt.Sprintf("tool:send_file:%s:%s", channel, chatID) + + ref, err := t.mediaStore.Store(resolved, media.MediaMeta{ + Filename: filename, + ContentType: mediaType, + Source: "tool:send_file", + }, scope) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to register media: %v", err)) + } + + return MediaResult(fmt.Sprintf("File %q sent to user", filename), []string{ref}) +} + +// detectMediaType determines the MIME type of a file. +// Uses magic-bytes detection (h2non/filetype) first, then falls back to +// extension-based lookup via mime.TypeByExtension. +func detectMediaType(path string) string { + kind, err := filetype.MatchFile(path) + if err == nil && kind != filetype.Unknown { + return kind.MIME.Value + } + + if ext := filepath.Ext(path); ext != "" { + if t := mime.TypeByExtension(ext); t != "" { + return t + } + } + + return "application/octet-stream" +} diff --git a/pkg/tools/send_file_test.go b/pkg/tools/send_file_test.go new file mode 100644 index 000000000..08d129674 --- /dev/null +++ b/pkg/tools/send_file_test.go @@ -0,0 +1,176 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +func TestSendFileTool_MissingPath(t *testing.T) { + store := media.NewFileMediaStore() + tool := NewSendFileTool("/tmp", false, 0, store) + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{}) + if !result.IsError { + t.Fatal("expected error for missing path") + } +} + +func TestSendFileTool_NoContext(t *testing.T) { + store := media.NewFileMediaStore() + tool := NewSendFileTool("/tmp", false, 0, store) + // no SetContext call + + result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"}) + if !result.IsError { + t.Fatal("expected error when no channel context") + } +} + +func TestSendFileTool_NoMediaStore(t *testing.T) { + tool := NewSendFileTool("/tmp", false, 0, 
nil) + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{"path": "/tmp/test.txt"}) + if !result.IsError { + t.Fatal("expected error when no media store") + } +} + +func TestSendFileTool_Directory(t *testing.T) { + store := media.NewFileMediaStore() + tool := NewSendFileTool("/tmp", false, 0, store) + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{"path": "/tmp"}) + if !result.IsError { + t.Fatal("expected error for directory path") + } +} + +func TestSendFileTool_FileTooLarge(t *testing.T) { + dir := t.TempDir() + testFile := filepath.Join(dir, "big.bin") + // Create a file larger than the limit + if err := os.WriteFile(testFile, make([]byte, 1024), 0o644); err != nil { + t.Fatal(err) + } + + store := media.NewFileMediaStore() + tool := NewSendFileTool(dir, false, 512, store) // 512 byte limit + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{"path": testFile}) + if !result.IsError { + t.Fatal("expected error for oversized file") + } + if !strings.Contains(result.ForLLM, "too large") { + t.Errorf("expected 'too large' in error, got %q", result.ForLLM) + } +} + +func TestSendFileTool_DefaultMaxSize(t *testing.T) { + tool := NewSendFileTool("/tmp", false, 0, nil) + if tool.maxFileSize != config.DefaultMaxMediaSize { + t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize) + } +} + +func TestSendFileTool_Success(t *testing.T) { + dir := t.TempDir() + testFile := filepath.Join(dir, "photo.png") + if err := os.WriteFile(testFile, []byte("fake png"), 0o644); err != nil { + t.Fatal(err) + } + + store := media.NewFileMediaStore() + tool := NewSendFileTool(dir, false, 0, store) + tool.SetContext("feishu", "chat123") + + result := tool.Execute(context.Background(), map[string]any{"path": testFile}) + if result.IsError { + t.Fatalf("unexpected error: %s", result.ForLLM) + } + if 
len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } + if result.Media[0][:8] != "media://" { + t.Errorf("expected media:// ref, got %q", result.Media[0]) + } +} + +func TestSendFileTool_CustomFilename(t *testing.T) { + dir := t.TempDir() + testFile := filepath.Join(dir, "img.jpg") + if err := os.WriteFile(testFile, []byte("fake jpg"), 0o644); err != nil { + t.Fatal(err) + } + + store := media.NewFileMediaStore() + tool := NewSendFileTool(dir, false, 0, store) + tool.SetContext("telegram", "chat456") + + result := tool.Execute(context.Background(), map[string]any{ + "path": testFile, + "filename": "my-photo.jpg", + }) + if result.IsError { + t.Fatalf("unexpected error: %s", result.ForLLM) + } + if len(result.Media) != 1 { + t.Fatalf("expected 1 media ref, got %d", len(result.Media)) + } +} + +func TestDetectMediaType_MagicBytes(t *testing.T) { + dir := t.TempDir() + + // Minimal valid PNG header + pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + pngFile := filepath.Join(dir, "image.dat") // wrong extension, but valid PNG bytes + if err := os.WriteFile(pngFile, pngHeader, 0o644); err != nil { + t.Fatal(err) + } + + got := detectMediaType(pngFile) + if got != "image/png" { + t.Errorf("expected image/png from magic bytes, got %q", got) + } +} + +func TestDetectMediaType_FallbackToExtension(t *testing.T) { + dir := t.TempDir() + + // File with unrecognizable content but known extension + txtFile := filepath.Join(dir, "readme.txt") + if err := os.WriteFile(txtFile, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + + got := detectMediaType(txtFile) + // text/plain or similar — just verify it's not application/octet-stream + if got == "application/octet-stream" { + t.Errorf("expected extension-based MIME for .txt, got %q", got) + } +} + +func TestDetectMediaType_UnknownFallsToOctetStream(t *testing.T) { + dir := t.TempDir() + + // File with no extension and random bytes + unknownFile := 
filepath.Join(dir, "mystery") + if err := os.WriteFile(unknownFile, []byte{0x00, 0x01, 0x02}, 0o644); err != nil { + t.Fatal(err) + } + + got := detectMediaType(unknownFile) + if got != "application/octet-stream" { + t.Errorf("expected application/octet-stream, got %q", got) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index a0c83eb1e..b8a811d03 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -59,7 +59,7 @@ var ( regexp.MustCompile(`\bchown\b`), regexp.MustCompile(`\bpkill\b`), regexp.MustCompile(`\bkillall\b`), - regexp.MustCompile(`\bkill\s+-[9]\b`), + regexp.MustCompile(`\bkill\b`), regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`), regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`), regexp.MustCompile(`\bnpm\s+install\s+-g\b`), @@ -131,9 +131,14 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf denyPatterns = append(denyPatterns, defaultDenyPatterns...) } + timeout := 60 * time.Second + if config != nil && config.Tools.Exec.TimeoutSeconds > 0 { + timeout = time.Duration(config.Tools.Exec.TimeoutSeconds) * time.Second + } + return &ExecTool{ workingDir: workingDir, - timeout: 60 * time.Second, + timeout: timeout, denyPatterns: denyPatterns, allowPatterns: nil, customAllowPatterns: customAllowPatterns, diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index a6abca8ea..ff9ea4a15 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -151,6 +151,26 @@ func TestShellTool_DangerousCommand(t *testing.T) { } } +func TestShellTool_DangerousCommand_KillBlocked(t *testing.T) { + tool, err := NewExecTool("", false) + if err != nil { + t.Errorf("unable to configure exec tool: %s", err) + } + + ctx := context.Background() + args := map[string]any{ + "command": "kill 12345", + } + + result := tool.Execute(ctx, args) + if !result.IsError { + t.Errorf("Expected kill command to be blocked") + } + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + 
t.Errorf("Expected blocked message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + // TestShellTool_MissingCommand verifies error handling for missing command func TestShellTool_MissingCommand(t *testing.T) { tool, err := NewExecTool("", false) diff --git a/pkg/tools/spawn_test.go b/pkg/tools/spawn_test.go index 0646c82a9..43223b8db 100644 --- a/pkg/tools/spawn_test.go +++ b/pkg/tools/spawn_test.go @@ -8,7 +8,7 @@ import ( func TestSpawnTool_Execute_EmptyTask(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSpawnTool(manager) ctx := context.Background() @@ -42,7 +42,7 @@ func TestSpawnTool_Execute_EmptyTask(t *testing.T) { func TestSpawnTool_Execute_ValidTask(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSpawnTool(manager) ctx := context.Background() diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 429340047..e51cbaafa 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -6,7 +6,6 @@ import ( "sync" "time" - "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -27,7 +26,6 @@ type SubagentManager struct { mu sync.RWMutex provider providers.LLMProvider defaultModel string - bus *bus.MessageBus workspace string tools *ToolRegistry maxIterations int @@ -41,13 +39,11 @@ type SubagentManager struct { func NewSubagentManager( provider providers.LLMProvider, defaultModel, workspace string, - bus *bus.MessageBus, ) *SubagentManager { return &SubagentManager{ tasks: make(map[string]*SubagentTask), provider: provider, defaultModel: defaultModel, - bus: bus, workspace: workspace, tools: NewToolRegistry(), maxIterations: 10, @@ -214,20 +210,6 @@ After completing the task, provide a clear 
summary of what was done.` Async: false, } } - - // Send announce message back to main agent - if sm.bus != nil { - announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - sm.bus.PublishInbound(pubCtx, bus.InboundMessage{ - Channel: "system", - SenderID: fmt.Sprintf("subagent:%s", task.ID), - // Format: "original_channel:original_chat_id" for routing back - ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID), - Content: announceContent, - }) - } } func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) { diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index a1450410a..4b6f130a5 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -5,7 +5,6 @@ import ( "strings" "testing" - "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/providers" ) @@ -47,7 +46,7 @@ func (m *MockLLMProvider) GetContextWindow() int { func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") manager.SetLLMOptions(2048, 0.6) tool := NewSubagentTool(manager) @@ -73,7 +72,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { // TestSubagentTool_Name verifies tool name func TestSubagentTool_Name(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) if tool.Name() != "subagent" { @@ -84,7 +83,7 @@ func TestSubagentTool_Name(t *testing.T) { // TestSubagentTool_Description verifies tool description func TestSubagentTool_Description(t *testing.T) { provider := 
&MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) desc := tool.Description() @@ -99,7 +98,7 @@ func TestSubagentTool_Description(t *testing.T) { // TestSubagentTool_Parameters verifies tool parameters schema func TestSubagentTool_Parameters(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) params := tool.Parameters() @@ -149,8 +148,7 @@ func TestSubagentTool_Parameters(t *testing.T) { // TestSubagentTool_Execute_Success tests successful execution func TestSubagentTool_Execute_Success(t *testing.T) { provider := &MockLLMProvider{} - msgBus := bus.NewMessageBus() - manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) ctx := WithToolContext(context.Background(), "telegram", "chat-123") @@ -204,8 +202,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) { // TestSubagentTool_Execute_NoLabel tests execution without label func TestSubagentTool_Execute_NoLabel(t *testing.T) { provider := &MockLLMProvider{} - msgBus := bus.NewMessageBus() - manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) ctx := context.Background() @@ -228,7 +225,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) { // TestSubagentTool_Execute_MissingTask tests error handling for missing task func TestSubagentTool_Execute_MissingTask(t *testing.T) { provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := 
NewSubagentTool(manager) ctx := context.Background() @@ -278,8 +275,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) { // TestSubagentTool_Execute_ContextPassing verifies context is properly used func TestSubagentTool_Execute_ContextPassing(t *testing.T) { provider := &MockLLMProvider{} - msgBus := bus.NewMessageBus() - manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) channel := "test-channel" @@ -304,8 +300,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) { func TestSubagentTool_ForUserTruncation(t *testing.T) { // Create a mock provider that returns very long content provider := &MockLLMProvider{} - msgBus := bus.NewMessageBus() - manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + manager := NewSubagentManager(provider, "test-model", "/tmp/test") tool := NewSubagentTool(manager) ctx := context.Background() diff --git a/scripts/test-irc.sh b/scripts/test-irc.sh new file mode 100755 index 000000000..40db01756 --- /dev/null +++ b/scripts/test-irc.sh @@ -0,0 +1,56 @@ +#!/bin/sh +# Starts a local Ergo IRC server for testing the IRC channel. +# +# Requirements: docker +# Usage: ./scripts/test-irc.sh + +set -e + +CONTAINER_NAME="picoclaw-test-ergo" +IRC_PORT=6667 + +# Clean up any previous instance +docker rm -f "$CONTAINER_NAME" >/dev/null 2>&1 || true + +echo "Starting Ergo IRC server on port $IRC_PORT..." 
+docker run -d \ + --name "$CONTAINER_NAME" \ + -p "$IRC_PORT:6667" \ + ghcr.io/ergochat/ergo:stable + +for i in $(seq 1 10); do + if nc -z localhost "$IRC_PORT" 2>/dev/null; then + break + fi + if [ "$i" -eq 10 ]; then + echo "ERROR: Server did not start within 10s" + exit 1 + fi + sleep 1 +done + +echo "" +echo "IRC server ready on localhost:$IRC_PORT" +echo "" +echo "Add this to your ~/.picoclaw/config.json under \"channels\":" +echo "" +echo ' "irc": {' +echo ' "enabled": true,' +echo ' "server": "localhost:6667",' +echo ' "tls": false,' +echo ' "nick": "picobot",' +echo ' "channels": ["#test"],' +echo ' "allow_from": [],' +echo ' "group_trigger": { "mention_only": true }' +echo ' }' +echo "" +echo "Then run picoclaw:" +echo " cd packages/picoclaw && go run ./cmd/picoclaw gateway" +echo "" +echo "Connect with an IRC client:" +echo " irssi: /connect localhost $IRC_PORT" +echo " weechat: /server add test localhost/$IRC_PORT && /connect test" +echo " Join #test, then: picobot: hello" +echo "" +echo "To stop the IRC server:" +echo " docker rm -f $CONTAINER_NAME"