diff --git a/README.md b/README.md index 2420df864..2aa3b631f 100644 --- a/README.md +++ b/README.md @@ -191,15 +191,510 @@ make install For detailed guides, see the docs below. The README covers quick start only. -| Topic | Description | -|-------|-------------| -| 🐳 [Docker & Quick Start](docs/docker.md) | Docker Compose setup, Launcher/Agent modes, Quick Start configuration | -| 💬 [Chat Apps](docs/chat-apps.md) | Telegram, Discord, WhatsApp, Matrix, QQ, Slack, IRC, DingTalk, LINE, Feishu, WeCom, and more | -| ⚙️ [Configuration](docs/configuration.md) | Environment variables, workspace layout, skill sources, security sandbox, heartbeat | -| 🔌 [Providers & Models](docs/providers.md) | 20+ LLM providers, model routing, model_list configuration, provider architecture | -| 🔄 [Spawn & Async Tasks](docs/spawn-tasks.md) | Quick tasks, long tasks with spawn, async sub-agent orchestration | -| 🐛 [Troubleshooting](docs/troubleshooting.md) | Common issues and solutions | -| 🔧 [Tools Configuration](docs/tools_configuration.md) | Per-tool enable/disable, exec policies | +```bash +# 1. Clone this repo +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. First run — auto-generates docker/data/config.json then exits +docker compose -f docker/docker-compose.yml --profile gateway up +# The container prints "First-run setup complete." and stops. + +# 3. Set your API keys +vim docker/data/config.json # Set provider API keys, bot tokens, etc. + +# 4. Start +docker compose -f docker/docker-compose.yml --profile gateway up -d +``` + +> [!TIP] +> **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`. + +```bash +# 5. Check logs +docker compose -f docker/docker-compose.yml logs -f picoclaw-gateway + +# 6. Stop +docker compose -f docker/docker-compose.yml --profile gateway down +``` + +### Launcher Mode (Web Console) + +The `launcher` image includes all three binaries (`picoclaw`, `picoclaw-launcher`, `picoclaw-launcher-tui`) and starts the web console by default, which provides a browser-based UI for configuration and chat. + +```bash +docker compose -f docker/docker-compose.yml --profile launcher up -d +``` + +Open http://localhost:18800 in your browser. The launcher manages the gateway process automatically. + +> [!WARNING] +> The web console does not yet support authentication. Avoid exposing it to the public internet. + +### Agent Mode (One-shot) + +```bash +# Ask a question +docker compose -f docker/docker-compose.yml run --rm picoclaw-agent -m "What is 2+2?" + +# Interactive mode +docker compose -f docker/docker-compose.yml run --rm picoclaw-agent +``` + +### Update + +```bash +docker compose -f docker/docker-compose.yml pull +docker compose -f docker/docker-compose.yml --profile gateway up -d +``` + +### 🚀 Quick Start + +> [!TIP] +> Set your API Key in `~/.picoclaw/config.json`. Get API Keys: [Volcengine (CodingPlan)](https://console.volcengine.com) (LLM) · [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM). Web search is optional — get a free [Tavily API](https://tavily.com) (1000 free queries/month) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month). + +**1. Initialize** + +```bash +picoclaw onboard +``` + +**2. Configure** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model_name": "gpt-5.4", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "model_list": [ + { + "model_name": "ark-code-latest", + "model": "volcengine/ark-code-latest", + "api_key": "sk-your-api-key" + }, + { + "model_name": "gpt-5.4", + "model": "openai/gpt-5.4", + "api_key": "your-api-key", + "request_timeout": 300 + }, + { + "model_name": "claude-sonnet-4.6", + "model": "anthropic/claude-sonnet-4.6", + "api_key": "your-anthropic-key" + } + ], + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "tavily": { + "enabled": false, + "api_key": "YOUR_TAVILY_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://your-searxng-instance:8888", + "max_results": 5 + } + } + } +} +``` + +> **New**: The `model_list` configuration format allows zero-code provider addition. See [Model Configuration](#model-configuration-model_list) for details. +> `request_timeout` is optional and uses seconds. If omitted or set to `<= 0`, PicoClaw uses the default timeout (120s). + +**3. Get API Keys** + +* **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Web Search** (optional): + * [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month) + * [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface + * [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed) + * [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) + * DuckDuckGo - Built-in fallback (no API key required) + +> **Note**: See `config.example.json` for a complete configuration template. + +**4. Chat** + +```bash +picoclaw agent -m "What is 2+2?" +``` + +That's it! You have a working AI assistant in 2 minutes. + +--- + +## 💬 Chat Apps + +Talk to your picoclaw through Telegram, Discord, WhatsApp, Matrix, QQ, DingTalk, LINE, or WeCom + +> **Note**: All webhook-based channels (LINE, WeCom, etc.) are served on a single shared Gateway HTTP server (`gateway.host`:`gateway.port`, default `127.0.0.1:18790`). There are no per-channel ports to configure. Note: Feishu uses WebSocket/SDK mode and does not use the shared HTTP webhook server. + +| Channel | Setup | +| ------------ | ---------------------------------- | +| **Telegram** | Easy (just a token) | +| **Discord** | Easy (bot token + intents) | +| **WhatsApp** | Easy (native: QR scan; or bridge URL) | +| **Matrix** | Medium (homeserver + bot access token) | +| **QQ** | Easy (AppID + AppSecret) | +| **DingTalk** | Medium (app credentials) | +| **LINE** | Medium (credentials + webhook URL) | +| **WeCom AI Bot** | Medium (Token + AES key) | + +
+Telegram (Recommended) + +**1. Create a bot** + +* Open Telegram, search `@BotFather` +* Send `/newbot`, follow prompts +* Copy the token + +**2. Configure** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +> Get your user ID from `@userinfobot` on Telegram. + +**3. Run** + +```bash +picoclaw gateway +``` + +**4. Telegram command menu (auto-registered at startup)** + +PicoClaw now keeps command definitions in one shared registry. On startup, Telegram will automatically register supported bot commands (for example `/start`, `/help`, `/show`, `/list`) so command menu and runtime behavior stay in sync. +Telegram command menu registration remains channel-local discovery UX; generic command execution is handled centrally in the agent loop via the commands executor. + +If command registration fails (network/API transient errors), the channel still starts and PicoClaw retries registration in the background. + +
+ +
+Discord + +**1. Create a bot** + +* Go to +* Create an application → Bot → Add Bot +* Copy the bot token + +**2. Enable intents** + +* In the Bot settings, enable **MESSAGE CONTENT INTENT** +* (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data + +**3. Get your User ID** +* Discord Settings → Advanced → enable **Developer Mode** +* Right-click your avatar → **Copy User ID** + +**4. Configure** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Invite the bot** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Open the generated invite URL and add the bot to your server + +**Optional: Group trigger mode** + +By default the bot responds to all messages in a server channel. To restrict responses to @-mentions only, add: + +```json +{ + "channels": { + "discord": { + "group_trigger": { "mention_only": true } + } + } +} +``` + +You can also trigger by keyword prefixes (e.g. `!bot`): + +```json +{ + "channels": { + "discord": { + "group_trigger": { "prefixes": ["!bot"] } + } + } +} +``` + +**6. Run** + +```bash +picoclaw gateway +``` + +
+ +
+WhatsApp (native via whatsmeow) + +PicoClaw can connect to WhatsApp in two ways: + +- **Native (recommended):** In-process using [whatsmeow](https://github.com/tulir/whatsmeow). No separate bridge. Set `"use_native": true` and leave `bridge_url` empty. On first run, scan the QR code with WhatsApp (Linked Devices). Session is stored under your workspace (e.g. `workspace/whatsapp/`). The native channel is **optional** to keep the default binary small; build with `-tags whatsapp_native` (e.g. `make build-whatsapp-native` or `go build -tags whatsapp_native ./cmd/...`). +- **Bridge:** Connect to an external WebSocket bridge. Set `bridge_url` (e.g. `ws://localhost:3001`) and keep `use_native` false. + +**Configure (native)** + +```json +{ + "channels": { + "whatsapp": { + "enabled": true, + "use_native": true, + "session_store_path": "", + "allow_from": [] + } + } +} +``` + +If `session_store_path` is empty, the session is stored in `<workspace>/whatsapp/`. Run `picoclaw gateway`; on first run, scan the QR code printed in the terminal with WhatsApp → Linked Devices. + +
+ +
+QQ + +**1. Create a bot** + +- Go to [QQ Open Platform](https://q.qq.com/#) +- Create an application → Get **AppID** and **AppSecret** + +**2. Configure** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Set `allow_from` to empty to allow all users, or specify QQ numbers to restrict access. + +**3. Run** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Create a bot** + +* Go to [Open Platform](https://open.dingtalk.com/) +* Create an internal app +* Copy Client ID and Client Secret + +**2. Configure** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Set `allow_from` to empty to allow all users, or specify DingTalk user IDs to restrict access. + +**3. Run** + +```bash +picoclaw gateway +``` +
+ +
+Matrix + +**1. Prepare bot account** + +* Use your preferred homeserver (e.g. `https://matrix.org` or self-hosted) +* Create a bot user and obtain its access token + +**2. Configure** + +```json +{ + "channels": { + "matrix": { + "enabled": true, + "homeserver": "https://matrix.org", + "user_id": "@your-bot:matrix.org", + "access_token": "YOUR_MATRIX_ACCESS_TOKEN", + "allow_from": [] + } + } +} +``` + +**3. Run** + +```bash +picoclaw gateway +``` + +For full options (`device_id`, `join_on_invite`, `group_trigger`, `placeholder`, `reasoning_channel_id`), see [Matrix Channel Configuration Guide](docs/channels/matrix/README.md). + +
+ +
+LINE + +**1. Create a LINE Official Account** + +- Go to [LINE Developers Console](https://developers.line.biz/) +- Create a provider → Create a Messaging API channel +- Copy **Channel Secret** and **Channel Access Token** + +**2. Configure** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +> LINE webhook is served on the shared Gateway server (`gateway.host`:`gateway.port`, default `127.0.0.1:18790`). + +**3. Set up Webhook URL** + +LINE requires HTTPS for webhooks. Use a reverse proxy or tunnel: + +```bash +# Example with ngrok (gateway default port is 18790) +ngrok http 18790 +``` + +Then set the Webhook URL in LINE Developers Console to `https://your-domain/webhook/line` and enable **Use webhook**. + +**4. Run** + +```bash +picoclaw gateway +``` + +> In group chats, the bot responds only when @mentioned. Replies quote the original message. + +
+ +
+WeCom (企业微信) + +PicoClaw supports three types of WeCom integration: + +**Option 1: WeCom Bot (Bot)** - Easier setup, supports group chats +**Option 2: WeCom App (Custom App)** - More features, proactive messaging, private chat only +**Option 3: WeCom AI Bot (AI Bot)** - Official AI Bot, streaming replies, supports group & private chat + +See [WeCom AI Bot Configuration Guide](docs/channels/wecom/wecom_aibot/README.zh.md) for detailed setup instructions. + +**Quick Setup - WeCom AI Bot:** + +**1. Create an AI Bot** + +* Go to WeCom Admin Console → AI Bot +* Create a new AI Bot → Set name, avatar, etc. +* Copy **Bot ID** and **Secret** + +**2. Configure** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", + "allow_from": [], + "welcome_message": "Hello! How can I help you?" + } + } +} +``` + +**3. Run** + +```bash +picoclaw gateway +``` + +> **Note**: WeCom AI Bot uses streaming pull protocol — no reply timeout concerns. Long tasks (>30 seconds) automatically switch to `response_url` push delivery. + +
## ClawdChat Join the Agent Social Network diff --git a/config/config.example.json b/config/config.example.json index c214f26fa..221e89491 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -8,7 +8,11 @@ "temperature": 0.7, "max_tool_iterations": 20, "summarize_message_threshold": 20, - "summarize_token_percent": 75 + "summarize_token_percent": 75, + "tool_feedback": { + "enabled": false, + "max_args_length": 300 + } } }, "model_list": [ @@ -200,6 +204,8 @@ "wecom_aibot": { "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.", "enabled": false, + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", "token": "YOUR_TOKEN", "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", "webhook_path": "/webhook/wecom-aibot", diff --git a/docs/channels/wecom/wecom_aibot/README.zh.md b/docs/channels/wecom/wecom_aibot/README.zh.md index d210528af..de4fba445 100644 --- a/docs/channels/wecom/wecom_aibot/README.zh.md +++ b/docs/channels/wecom/wecom_aibot/README.zh.md @@ -1,6 +1,6 @@ # 企业微信智能机器人 (AI Bot) -企业微信智能机器人(AI Bot)是企业微信官方提供的 AI 对话接入方式,支持私聊与群聊,内置流式响应协议,并支持超时后通过 `response_url` 主动推送最终回复。 +企业微信智能机器人(AI Bot)是企业微信官方提供的 AI 对话接入方式,支持私聊与群聊,内置流式响应协议。 ## 与其他 WeCom 通道的对比 @@ -19,9 +19,8 @@ "channels": { "wecom_aibot": { "enabled": true, - "token": "YOUR_TOKEN", - "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", - "webhook_path": "/webhook/wecom-aibot", + "bot_id": "YOUR_BOT_ID", + "secret": "YOUR_SECRET", "allow_from": [], "welcome_message": "你好!有什么可以帮助你的吗?", "max_steps": 10 @@ -32,9 +31,8 @@ | 字段 | 类型 | 必填 | 描述 | | ---------------- | ------ | ---- | -------------------------------------------------- | -| token | string | 是 | 回调验证令牌,在 AI Bot 管理页面配置 | -| encoding_aes_key | string | 是 | 43 字符 AES 密钥,在 AI Bot 管理页面随机生成 | -| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-aibot) | +| bot_id | string | 是 | AI Bot 的唯一标识,在 AI Bot 管理页面配置 | +| secret | string | 是 | AI Bot 的密钥,在 AI Bot 管理页面配置 | | allow_from | array | 否 | 用户 ID 白名单,空数组表示允许所有用户 | | welcome_message | string | 否 | 用户进入聊天时发送的欢迎语,留空则不发送 | | reply_timeout | int | 否 | 回复超时时间(秒,默认:5) | @@ -44,42 +42,8 @@ 1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin) 2. 进入"应用管理" → "智能机器人",创建或选择一个 AI Bot -3. 在 AI Bot 配置页面,填写"消息接收"信息: - - **URL**:`http://:18791/webhook/wecom-aibot` - - **Token**:随机生成或自定义 - - **EncodingAESKey**:点击"随机生成",得到 43 字符密钥 -4. 将 Token 和 EncodingAESKey 填入 PicoClaw 配置文件,启动服务后回到管理后台保存(企业微信会发送验证请求) - -> [!TIP] -> 服务器需要能被企业微信服务器访问。如在内网/本地开发,可使用 [ngrok](https://ngrok.com) 或 frp 做内网穿透。 - -## 流式响应协议 - -WeCom AI Bot 使用"流式拉取"协议,区别于普通 Webhook 的一次性回复: - -``` -用户发消息 - │ - ▼ -PicoClaw 立即返回 {finish: false}(Agent 开始处理) - │ - ▼ -企业微信每隔约 1 秒拉取一次 {msgtype: "stream", stream: {id: "..."}} - │ - ├─ Agent 未完成 → 返回 {finish: false}(继续等待) - │ - └─ Agent 完成 → 返回 {finish: true, content: "回答内容"} -``` - -**超时处理**(任务超过 30 秒): - -若 Agent 处理时间超过约 30 秒(企业微信最大轮询窗口为 6 分钟),PicoClaw 会: - -1. 立即关闭流,向用户显示「⏳ 正在处理中,请稍候,结果将稍后发送。」 -2. Agent 继续在后台运行 -3. Agent 完成后,通过消息中携带的 `response_url` 将最终回复主动推送给用户 - -> `response_url` 由企业微信颁发,有效期 1 小时,只可使用一次,无需加密,直接 POST markdown 消息体即可。 +3. 在 AI Bot 配置页面,配置Bot的名称、头像等信息,获取 `Bot ID` 和 `Secret` +4. 在 PicoClaw 配置文件中添加上述配置,重启 PicoClaw ## 欢迎语 @@ -91,26 +55,12 @@ PicoClaw 立即返回 {finish: false}(Agent 开始处理) ## 常见问题 -### 回调 URL 验证失败 - -- 确认服务器防火墙已开放对应端口(默认 18791) -- 确认 `token` 与 `encoding_aes_key` 填写正确 -- 检查 PicoClaw 日志是否收到了来自企业微信的 GET 请求 - ### 消息没有回复 - 检查 `allow_from` 是否意外限制了发送者 - 查看日志中是否出现 `context canceled` 或 Agent 错误 - 确认 Agent 配置(`model_name` 等)正确 -### 超长任务没有收到最终推送 - -- 确认消息回调中携带了 `response_url`(仅企业微信新版 AI Bot 支持) -- 确认服务器能主动访问外网(需向 `response_url` POST 请求) -- 查看日志关键词 `response_url mode` 和 `Sending reply via response_url` - ## 参考文档 -- [企业微信 AI Bot 接入文档](https://developer.work.weixin.qq.com/document/path/100719) -- [流式响应协议说明](https://developer.work.weixin.qq.com/document/path/100719) -- [response_url 主动回复](https://developer.work.weixin.qq.com/document/path/101138) +- [企业微信 AI Bot 接入文档](https://developer.work.weixin.qq.com/document/path/101463) diff --git a/docs/debug.md b/docs/debug.md index 7e28a15f2..b9e776f0f 100644 --- a/docs/debug.md +++ b/docs/debug.md @@ -31,3 +31,69 @@ When this flag is active, the global truncation function is disabled. This is ex * Verifying the exact syntax of the messages sent to the provider. * Reading the complete output of tools like `exec`, `web_fetch`, or `read_file`. * Debugging the session history saved in memory. + +## Tool Call Visibility in Debug Logs + +When debug mode is active, the agent emits structured log entries at each stage of the tool execution lifecycle. These entries carry a `component=agent` label and use `INFO` or `DEBUG` level depending on the amount of detail: + +| Log message | Level | Key fields | Description | +|---|---|---|---| +| `LLM requested tool calls` | INFO | `tools`, `count`, `iteration` | List of tool names the model decided to call | +| `Tool call: ()` | INFO | `tool`, `iteration` | The tool name and a preview of its arguments (truncated to 200 chars) | +| `Sent tool result to user` | DEBUG | `tool`, `content_len` | Fired when a tool result is forwarded to the chat channel | +| `TTL tick after tool execution` | DEBUG | `agent_id`, `iteration` | MCP tool-discovery TTL decrement after each tool round | +| `Async tool completed, publishing result` | INFO | `tool`, `content_len`, `channel` | Only for tools that run asynchronously in the background | + +### Reading a tool call log entry + +A typical synchronous tool call produces two consecutive lines in the console: + +``` +[...] [INFO] agent: LLM requested tool calls {tools=[web_search], count=1, iteration=1} +[...] [INFO] agent: Tool call: web_search({"query":"picoclaw release notes"}) {tool=web_search, iteration=1} +``` + +The arguments preview is hard-capped at **200 characters** in the logs regardless of the `--no-truncate` flag, because it belongs to the `INFO`-level path. Use `--no-truncate` together with `--debug` to see the full `tools_json` field emitted by the `Full LLM request` DEBUG entry, which contains every tool definition sent to the model. + +## Real-Time Tool Feedback in Chat (tool_feedback) + +Debug logs are server-side only. If you want the agent to send a visible notification directly into the chat channel every time it executes a tool—useful when sharing the bot with other users or for transparency—enable the `tool_feedback` feature in `config.json`: + +```json +{ + "agents": { + "defaults": { + "tool_feedback": { + "enabled": true, + "max_args_length": 300 + } + } + } +} +``` + +When `enabled` is `true`, every tool call sends a short message to the chat before the tool result is returned to the model. The message looks like: + +```bash +🔧 `web_search` +{"query": "picoclaw release notes"} +``` + + +### Options + +| Field | Type | Default | Description | +|---|---|---|---| +| `enabled` | bool | `false` | Send a chat notification for each tool call | +| `max_args_length` | int | `300` | Maximum characters of the serialised arguments included in the notification | + +### Environment variables + +Both fields can also be set via environment variables: + +```bash +PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_ENABLED=true +PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH=300 +``` + +> **Note:** `tool_feedback` is independent of `--debug` mode. It works in production and does not require the gateway to be started with any special flag. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 36dc4a257..9d0c3c0dd 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1569,8 +1569,29 @@ func (al *AgentLoop) runLLMIteration( "iteration": iteration, }) + // Send tool feedback to chat channel if enabled + if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() && opts.Channel != "" { + feedbackPreview := utils.Truncate( + string(argsJSON), + al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(), + ) + feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", tc.Name, feedbackPreview) + fbCtx, fbCancel := context.WithTimeout(ctx, 3*time.Second) + _ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: feedbackMsg, + }) + fbCancel() + } + // Create async callback for tools that implement AsyncExecutor. + // When the background work completes, this publishes the result + // as an inbound system message so processSystemMessage routes it + // back to the user via the normal agent loop. asyncCallback := func(_ context.Context, result *tools.ToolResult) { + // Send ForUser content directly to the user (immediate feedback), + // mirroring the synchronous tool execution path. if !result.Silent && result.ForUser != "" { outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) defer outCancel() @@ -1581,6 +1602,7 @@ func (al *AgentLoop) runLLMIteration( }) } + // Determine content for the agent loop (ForLLM or error). content := result.ForLLM if content == "" && result.Err != nil { content = result.Err.Error() diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 9e5fea1b6..2e1e12ded 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -296,7 +296,9 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { m.initChannel("wecom", "WeCom") } - if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" { + if m.config.Channels.WeComAIBot.Enabled && + ((m.config.Channels.WeComAIBot.BotID != "" && m.config.Channels.WeComAIBot.Secret != "") || + m.config.Channels.WeComAIBot.Token != "") { m.initChannel("wecom_aibot", "WeCom AI Bot") } diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go index 93fe8c36d..999f4f13b 100644 --- a/pkg/channels/wecom/aibot.go +++ b/pkg/channels/wecom/aibot.go @@ -22,6 +22,10 @@ import ( "github.com/sipeed/picoclaw/pkg/utils" ) +// responseURLHTTPClient is a shared HTTP client for posting to WeCom response_url. +// Reusing it enables connection pooling across replies. +var responseURLHTTPClient = &http.Client{Timeout: 15 * time.Second} + // WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人) type WeComAIBotChannel struct { *channels.BaseChannel @@ -134,13 +138,25 @@ type WeComAIBotEncryptedResponse struct { Nonce string `json:"nonce"` } -// NewWeComAIBotChannel creates a new WeCom AI Bot channel instance +// NewWeComAIBotChannel creates a WeCom AI Bot channel instance. +// If cfg.BotID and cfg.Secret are both set, it returns a WeComAIBotWSChannel +// using the WebSocket long-connection API. +// Otherwise it returns the webhook-mode WeComAIBotChannel (requires Token + +// EncodingAESKey). func NewWeComAIBotChannel( cfg config.WeComAIBotConfig, messageBus *bus.MessageBus, -) (*WeComAIBotChannel, error) { +) (channels.Channel, error) { + // WebSocket long-connection mode takes priority when BotID + Secret are set. + if cfg.BotID != "" && cfg.Secret != "" { + logger.InfoC("wecom_aibot", "BotID and Secret provided, using WebSocket mode") + return newWeComAIBotWSChannel(cfg, messageBus) + } + // Webhook (short-connection) mode. if cfg.Token == "" || cfg.EncodingAESKey == "" { - return nil, fmt.Errorf("token and encoding_aes_key are required for WeCom AI Bot") + return nil, fmt.Errorf( + "WeCom AI Bot requires either (bot_id + secret) for WebSocket mode " + + "or (token + encoding_aes_key) for webhook mode") } base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, @@ -782,8 +798,7 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro } req.Header.Set("Content-Type", "application/json; charset=utf-8") - client := &http.Client{Timeout: 15 * time.Second} - resp, err := client.Do(req) + resp, err := responseURLHTTPClient.Do(req) if err != nil { return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err) } @@ -793,7 +808,8 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro return nil } - respBody, err := io.ReadAll(resp.Body) + const maxErrBody = 64 << 10 // 64 KB is more than enough for any error response + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxErrBody)) if err != nil { return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err) } @@ -895,17 +911,80 @@ func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, return base64.StdEncoding.EncodeToString(ciphertext), nil } -// generateStreamID generates a random stream ID -func (c *WeComAIBotChannel) generateStreamID() string { +// func (c *WeComAIBotChannel) downloadAndDecryptImage( +// ctx context.Context, +// imageURL string, +// ) ([]byte, error) { +// // Download image +// req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil) +// if err != nil { +// return nil, fmt.Errorf("failed to create request: %w", err) +// } + +// client := &http.Client{ +// Timeout: 15 * time.Second, +// } + +// resp, err := client.Do(req) +// if err != nil { +// return nil, fmt.Errorf("failed to download image: %w", err) +// } +// defer resp.Body.Close() + +// if resp.StatusCode != http.StatusOK { +// return nil, fmt.Errorf("download failed with status: %d", resp.StatusCode) +// } + +// // Limit image download to 20 MB to prevent memory exhaustion +// const maxImageSize = 20 << 20 // 20 MB +// encryptedData, err := io.ReadAll(io.LimitReader(resp.Body, maxImageSize+1)) +// if err != nil { +// return nil, fmt.Errorf("failed to read image data: %w", err) +// } +// if len(encryptedData) > maxImageSize { +// return nil, fmt.Errorf("image too large (exceeds %d MB)", maxImageSize>>20) +// } + +// logger.DebugCF("wecom_aibot", "Image downloaded", map[string]any{ +// "size": len(encryptedData), +// }) + +// // Decode AES key +// aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) +// if err != nil { +// return nil, err +// } + +// // Decrypt image (AES-CBC with IV = first 16 bytes of key, PKCS7 padding stripped) +// decryptedData, err := decryptAESCBC(aesKey, encryptedData) +// if err != nil { +// return nil, fmt.Errorf("failed to decrypt image: %w", err) +// } + +// logger.DebugCF("wecom_aibot", "Image decrypted", map[string]any{ +// "size": len(decryptedData), +// }) + +// return decryptedData, nil +// } + +// generateRandomID generates a cryptographically random alphanumeric ID of +// length n. Used for stream IDs and WebSocket request IDs. +func generateRandomID(n int) string { const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, 10) + b := make([]byte, n) for i := range b { - n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) - b[i] = letters[n.Int64()] + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + b[i] = letters[num.Int64()] } return string(b) } +// generateStreamID generates a random 10-character stream ID (webhook mode). +func (c *WeComAIBotChannel) generateStreamID() string { + return generateRandomID(10) +} + // cleanupLoop periodically cleans up old streaming tasks func (c *WeComAIBotChannel) cleanupLoop() { ticker := time.NewTicker(5 * time.Minute) diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go index 6f0664187..7c5ae67b1 100644 --- a/pkg/channels/wecom/aibot_test.go +++ b/pkg/channels/wecom/aibot_test.go @@ -3,12 +3,16 @@ package wecom import ( "context" "testing" + "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) -func TestNewWeComAIBotChannel(t *testing.T) { +// ---- Webhook mode tests ---- + +func TestNewWeComAIBotChannel_WebhookMode(t *testing.T) { t.Run("success with valid config", func(t *testing.T) { cfg := config.WeComAIBotConfig{ Enabled: true, @@ -22,14 +26,16 @@ func TestNewWeComAIBotChannel(t *testing.T) { if err != nil { t.Fatalf("Expected no error, got %v", err) } - if ch == nil { t.Fatal("Expected channel to be created") } - if ch.Name() != "wecom_aibot" { t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) } + // Webhook mode must implement WebhookHandler. + if _, ok := ch.(channels.WebhookHandler); !ok { + t.Error("Webhook mode channel should implement WebhookHandler") + } }) t.Run("error with missing token", func(t *testing.T) { @@ -37,10 +43,8 @@ func TestNewWeComAIBotChannel(t *testing.T) { Enabled: true, EncodingAESKey: "testkey1234567890123456789012345678901234567", } - messageBus := bus.NewMessageBus() _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { t.Fatal("Expected error for missing token, got nil") } @@ -51,17 +55,15 @@ func TestNewWeComAIBotChannel(t *testing.T) { Enabled: true, Token: "test_token", } - messageBus := bus.NewMessageBus() _, err := NewWeComAIBotChannel(cfg, messageBus) - if err == nil { t.Fatal("Expected error for missing encoding key, got nil") } }) } -func TestWeComAIBotChannelStartStop(t *testing.T) { +func TestWeComAIBotWebhookChannelStartStop(t *testing.T) { cfg := config.WeComAIBotConfig{ Enabled: true, Token: "test_token", @@ -76,22 +78,18 @@ func TestWeComAIBotChannelStartStop(t *testing.T) { ctx := context.Background() - // Test Start if err := ch.Start(ctx); err != nil { t.Fatalf("Failed to start channel: %v", err) } - if !ch.IsRunning() { - t.Error("Expected channel to be running") + t.Error("Expected channel to be running after Start") } - // Test Stop if err := ch.Stop(ctx); err != nil { t.Fatalf("Failed to stop channel: %v", err) } - if ch.IsRunning() { - t.Error("Expected channel to be stopped") + t.Error("Expected channel to be stopped after Stop") } } @@ -102,13 +100,16 @@ func TestWeComAIBotChannelWebhookPath(t *testing.T) { Token: "test_token", EncodingAESKey: "testkey1234567890123456789012345678901234567", } - messageBus := bus.NewMessageBus() ch, _ := NewWeComAIBotChannel(cfg, messageBus) + wh, ok := ch.(channels.WebhookHandler) + if !ok { + t.Fatal("Expected channel to implement WebhookHandler") + } expectedPath := "/webhook/wecom-aibot" - if ch.WebhookPath() != expectedPath { - t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath()) + if wh.WebhookPath() != expectedPath { + t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, wh.WebhookPath()) } }) @@ -120,12 +121,15 @@ func TestWeComAIBotChannelWebhookPath(t *testing.T) { EncodingAESKey: "testkey1234567890123456789012345678901234567", WebhookPath: customPath, } - messageBus := bus.NewMessageBus() ch, _ := NewWeComAIBotChannel(cfg, messageBus) - if ch.WebhookPath() != customPath { - t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath()) + wh, ok := ch.(channels.WebhookHandler) + if !ok { + t.Fatal("Expected channel to implement WebhookHandler") + } + if wh.WebhookPath() != customPath { + t.Errorf("Expected webhook path '%s', got '%s'", customPath, wh.WebhookPath()) } }) } @@ -136,19 +140,19 @@ func TestGenerateStreamID(t *testing.T) { Token: "test_token", EncodingAESKey: "testkey1234567890123456789012345678901234567", } - messageBus := bus.NewMessageBus() ch, _ := NewWeComAIBotChannel(cfg, messageBus) + webhookCh, ok := ch.(*WeComAIBotChannel) + if !ok { + t.Fatal("Expected webhook mode channel") + } - // Generate multiple IDs and check they are unique ids := make(map[string]bool) for i := 0; i < 100; i++ { - id := ch.generateStreamID() - + id := webhookCh.generateStreamID() if len(id) != 10 { t.Errorf("Expected stream ID length 10, got %d", len(id)) } - if ids[id] { t.Errorf("Duplicate stream ID generated: %s", id) } @@ -157,35 +161,33 @@ func TestGenerateStreamID(t *testing.T) { } func TestEncryptDecrypt(t *testing.T) { - // Use a valid 43-character base64 key (企业微信标准格式) cfg := config.WeComAIBotConfig{ Enabled: true, Token: "test_token", EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters } - messageBus := bus.NewMessageBus() ch, _ := NewWeComAIBotChannel(cfg, messageBus) + webhookCh, ok := ch.(*WeComAIBotChannel) + if !ok { + t.Fatal("Expected webhook mode channel") + } plaintext := "Hello, World!" receiveid := "" - // Encrypt - encrypted, err := ch.encryptMessage(plaintext, receiveid) + encrypted, err := webhookCh.encryptMessage(plaintext, receiveid) if err != nil { t.Fatalf("Failed to encrypt message: %v", err) } - if encrypted == "" { t.Fatal("Encrypted message is empty") } - // Decrypt decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid) if err != nil { t.Fatalf("Failed to decrypt message: %v", err) } - if decrypted != plaintext { t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted) } @@ -198,13 +200,256 @@ func TestGenerateSignature(t *testing.T) { encrypt := "encrypted_msg" signature := computeSignature(token, timestamp, nonce, encrypt) - if signature == "" { t.Error("Generated signature is empty") } - - // Verify signature using verifySignature function if !verifySignature(token, signature, timestamp, nonce, encrypt) { t.Error("Generated signature does not verify correctly") } } + +// ---- WebSocket long-connection mode tests ---- + +func TestNewWeComAIBotChannel_WSMode(t *testing.T) { + t.Run("success with bot_id and secret", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + BotID: "test_bot_id", + Secret: "test_secret", + } + messageBus := bus.NewMessageBus() + ch, err := NewWeComAIBotChannel(cfg, messageBus) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if ch == nil { + t.Fatal("Expected channel to be created") + } + if ch.Name() != "wecom_aibot" { + t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) + } + // WebSocket mode must NOT implement WebhookHandler. + if _, ok := ch.(channels.WebhookHandler); ok { + t.Error("WebSocket mode channel should NOT implement WebhookHandler") + } + }) + + t.Run("ws mode takes priority over webhook fields", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + BotID: "test_bot_id", + Secret: "test_secret", + Token: "also_set", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + } + messageBus := bus.NewMessageBus() + ch, err := NewWeComAIBotChannel(cfg, messageBus) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if _, ok := ch.(*WeComAIBotWSChannel); !ok { + t.Error("Expected WebSocket mode channel when both BotID+Secret and Token+Key are set") + } + }) + + t.Run("error with missing bot_id", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Secret: "test_secret", + } + messageBus := bus.NewMessageBus() + _, err := NewWeComAIBotChannel(cfg, messageBus) + // Missing bot_id alone means neither WS mode nor webhook mode is fully configured. + if err == nil { + t.Fatal("Expected error for missing bot_id, got nil") + } + }) + + t.Run("error with missing secret", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + BotID: "test_bot_id", + } + messageBus := bus.NewMessageBus() + _, err := NewWeComAIBotChannel(cfg, messageBus) + if err == nil { + t.Fatal("Expected error for missing secret, got nil") + } + }) +} + +func TestWeComAIBotWSChannelStartStop(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + BotID: "test_bot_id", + Secret: "test_secret", + } + messageBus := bus.NewMessageBus() + ch, err := NewWeComAIBotChannel(cfg, messageBus) + if err != nil { + t.Fatalf("Failed to create channel: %v", err) + } + + ctx := context.Background() + + // Start launches a background goroutine; it should not block or return an error. + if err := ch.Start(ctx); err != nil { + t.Fatalf("Failed to start channel: %v", err) + } + if !ch.IsRunning() { + t.Error("Expected channel to be running after Start") + } + + // Stop should work regardless of whether the WebSocket actually connected. + if err := ch.Stop(ctx); err != nil { + t.Fatalf("Failed to stop channel: %v", err) + } + if ch.IsRunning() { + t.Error("Expected channel to be stopped after Stop") + } +} + +func TestGenerateRandomID(t *testing.T) { + ids := make(map[string]bool) + for i := 0; i < 200; i++ { + id := generateRandomID(10) + if len(id) != 10 { + t.Errorf("Expected ID length 10, got %d", len(id)) + } + if ids[id] { + t.Errorf("Duplicate ID generated: %s", id) + } + ids[id] = true + } +} + +func TestWSGenerateID(t *testing.T) { + ids := make(map[string]bool) + for i := 0; i < 200; i++ { + id := wsGenerateID() + if len(id) != 10 { + t.Errorf("Expected ID length 10, got %d", len(id)) + } + if ids[id] { + t.Errorf("Duplicate wsGenerateID result: %s", id) + } + ids[id] = true + } +} + +// ---- Webhook streaming fallback tests ---- + +// makeWebhookChannel creates a started WeComAIBotChannel for testing. +func makeWebhookChannel(t *testing.T) *WeComAIBotChannel { + t.Helper() + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", + } + ch, err := NewWeComAIBotChannel(cfg, bus.NewMessageBus()) + if err != nil { + t.Fatalf("create channel: %v", err) + } + wc := ch.(*WeComAIBotChannel) + wc.ctx, wc.cancel = context.WithCancel(context.Background()) + return wc +} + +// makeStreamTask creates and registers a streamTask for testing. +func makeStreamTask(t *testing.T, ch *WeComAIBotChannel, streamID, chatID string, deadline time.Time) *streamTask { + t.Helper() + task := &streamTask{ + StreamID: streamID, + ChatID: chatID, + Deadline: deadline, + answerCh: make(chan string, 1), + } + task.ctx, task.cancel = context.WithCancel(ch.ctx) + ch.taskMu.Lock() + ch.streamTasks[streamID] = task + ch.chatTasks[chatID] = append(ch.chatTasks[chatID], task) + ch.taskMu.Unlock() + return task +} + +// TestGetStreamResponse_ImmediateAnswer verifies that when the agent has already +// placed its answer in answerCh, getStreamResponse returns a finish=true response +// and fully removes the task. +func TestGetStreamResponse_ImmediateAnswer(t *testing.T) { + ch := makeWebhookChannel(t) + defer ch.cancel() + + task := makeStreamTask(t, ch, "stream-1", "chat-1", time.Now().Add(30*time.Second)) + task.answerCh <- "hello from agent" + + result := ch.getStreamResponse(task, "ts123", "nonce123") + if result == "" { + t.Fatal("expected non-empty encrypted response") + } + + ch.taskMu.RLock() + _, exists := ch.streamTasks["stream-1"] + ch.taskMu.RUnlock() + if exists { + t.Error("task should have been removed from streamTasks after normal finish") + } + if !task.Finished { + t.Error("task.Finished should be true after normal finish") + } +} + +// TestGetStreamResponse_DeadlinePassed verifies that when the stream deadline has +// elapsed (no agent reply yet), getStreamResponse closes the stream but keeps the +// task alive so the response_url fallback can still deliver the answer. +func TestGetStreamResponse_DeadlinePassed(t *testing.T) { + ch := makeWebhookChannel(t) + defer ch.cancel() + + task := makeStreamTask(t, ch, "stream-2", "chat-2", time.Now().Add(-time.Millisecond)) + + result := ch.getStreamResponse(task, "ts456", "nonce456") + if result == "" { + t.Fatal("expected non-empty encrypted response") + } + + ch.taskMu.RLock() + _, stillStreaming := ch.streamTasks["stream-2"] + ch.taskMu.RUnlock() + if stillStreaming { + t.Error("task should have been removed from streamTasks after deadline") + } + if !task.StreamClosed { + t.Error("task.StreamClosed should be true after deadline") + } + if task.Finished { + t.Error("task.Finished must remain false: agent reply still expected via response_url") + } +} + +// TestGetStreamResponse_StillPending verifies that when neither the agent has +// replied nor the deadline has passed, getStreamResponse returns without altering +// task state (client should poll again). +func TestGetStreamResponse_StillPending(t *testing.T) { + ch := makeWebhookChannel(t) + defer ch.cancel() + + task := makeStreamTask(t, ch, "stream-3", "chat-3", time.Now().Add(30*time.Second)) + + result := ch.getStreamResponse(task, "ts789", "nonce789") + if result == "" { + t.Fatal("expected non-empty encrypted response") + } + + ch.taskMu.RLock() + _, exists := ch.streamTasks["stream-3"] + ch.taskMu.RUnlock() + if !exists { + t.Error("pending task should still be in streamTasks") + } + if task.Finished || task.StreamClosed { + t.Error("pending task should not be finished or stream-closed") + } + // Cleanup. + ch.removeTask(task) +} diff --git a/pkg/channels/wecom/aibot_ws.go b/pkg/channels/wecom/aibot_ws.go new file mode 100644 index 000000000..830e763b9 --- /dev/null +++ b/pkg/channels/wecom/aibot_ws.go @@ -0,0 +1,1346 @@ +package wecom + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// Long-connection WebSocket endpoint. +// Ref: https://developer.work.weixin.qq.com/document/path/101463 +const ( + wsEndpoint = "wss://openws.work.weixin.qq.com" + wsHeartbeatInterval = 30 * time.Second + wsConnectTimeout = 15 * time.Second + wsSubscribeTimeout = 10 * time.Second + wsSendMsgTimeout = 10 * time.Second + wsRespondMsgTimeout = 10 * time.Second + wsWelcomeMsgTimeout = 5 * time.Second // WeCom requires welcome reply within 5 seconds + wsMaxReconnectWait = 60 * time.Second + wsInitialReconnect = time.Second + + // WeCom requires finish=true within 6 minutes of the first stream frame. + // wsStreamTickInterval controls how often we send an in-progress hint. + // wsStreamMaxDuration is a safety margin below the 6-minute hard limit. + wsStreamTickInterval = 30 * time.Second + wsStreamMaxDuration = 5*time.Minute + 30*time.Second + + // wsImageDownloadTimeout caps the time we spend downloading an inbound image. + wsImageDownloadTimeout = 30 * time.Second + + // Keep req_id -> chat route for late fallback pushes after stream window closes. + wsLateReplyRouteTTL = 30 * time.Minute + + // wsStreamMaxContentBytes is the maximum UTF-8 byte length for the content field + // of a single WeCom AI Bot stream / text / markdown frame. + // Ref: https://developer.work.weixin.qq.com/document/path/101463 + wsStreamMaxContentBytes = 20480 +) + +// wsImageHTTPClient is a shared HTTP client for downloading inbound images. +// Reusing it enables connection pooling across multiple image downloads. +var wsImageHTTPClient = &http.Client{Timeout: wsImageDownloadTimeout} + +// WeComAIBotWSChannel implements channels.Channel for WeCom AI Bot using the +// WebSocket long-connection API. +// Unlike the webhook counterpart it does NOT implement WebhookHandler, so the +// HTTP manager will not register any callback URL for it. +type WeComAIBotWSChannel struct { + *channels.BaseChannel + config config.WeComAIBotConfig + ctx context.Context + cancel context.CancelFunc + + // conn is the active WebSocket connection; nil when disconnected. + // All writes are serialized through connMu. + conn *websocket.Conn + connMu sync.Mutex + + // dedupe prevents duplicate message processing (WeCom may re-deliver). + dedupe *MessageDeduplicator + + // reqStates holds per-req_id runtime state. + // It unifies active task state and late-reply fallback routing. + reqStates map[string]*wsReqState + reqStatesMu sync.Mutex + + // reqPending correlates command req_ids with response channels. + // Used only for subscribe/ping command-response pairs. + reqPending map[string]chan wsEnvelope + reqPendingMu sync.Mutex +} + +// wsTask tracks one in-progress agent reply for a single chat turn. +type wsTask struct { + ReqID string // req_id echoed in all replies for this turn + ChatID string + ChatType uint32 + StreamID string // our generated stream.id + answerCh chan string // agent delivers its reply here via Send() + ctx context.Context + cancel context.CancelFunc +} + +type wsReqState struct { + Task *wsTask + Route wsLateReplyRoute +} + +type wsLateReplyRoute struct { + ChatID string + ChatType uint32 + ReadyAt time.Time + ExpiresAt time.Time +} + +// ---- WebSocket protocol types ---- + +// wsEnvelope is the generic JSON envelope for all WebSocket messages. +type wsEnvelope struct { + Cmd string `json:"cmd,omitempty"` + Headers wsHeaders `json:"headers"` + Body json.RawMessage `json:"body,omitempty"` + ErrCode int `json:"errcode,omitempty"` + ErrMsg string `json:"errmsg,omitempty"` +} + +type wsHeaders struct { + ReqID string `json:"req_id"` +} + +// wsCommand is an outgoing request sent over the WebSocket. +type wsCommand struct { + Cmd string `json:"cmd"` + Headers wsHeaders `json:"headers"` + Body any `json:"body,omitempty"` +} + +type wsSendMsgBody struct { + ChatID string `json:"chatid"` + ChatType uint32 `json:"chat_type,omitempty"` + MsgType string `json:"msgtype"` + Markdown *wsMarkdownContent `json:"markdown,omitempty"` +} + +// wsRespondMsgBody is the body for aibot_respond_msg / aibot_respond_welcome_msg. +type wsRespondMsgBody struct { + MsgType string `json:"msgtype"` + Stream *wsStreamContent `json:"stream,omitempty"` + Text *wsTextContent `json:"text,omitempty"` + Markdown *wsMarkdownContent `json:"markdown,omitempty"` + Image *wsImageContent `json:"image,omitempty"` +} + +type wsStreamContent struct { + ID string `json:"id"` + Finish bool `json:"finish"` + Content string `json:"content,omitempty"` +} + +// wsImageContent carries a base64-encoded image payload for outbound messages. +type wsImageContent struct { + Base64 string `json:"base64"` + MD5 string `json:"md5"` +} + +type wsTextContent struct { + Content string `json:"content"` +} + +type wsMarkdownContent struct { + Content string `json:"content"` +} + +// WeComAIBotWSMessage is the decoded body of aibot_msg_callback / +// aibot_event_callback in WebSocket long-connection mode. +// The structure mirrors WeComAIBotMessage but includes extra fields +// that only appear in long-connection callbacks (Voice, AESKey on Image/File). +type WeComAIBotWSMessage struct { + MsgID string `json:"msgid"` + CreateTime int64 `json:"create_time,omitempty"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid,omitempty"` + ChatType string `json:"chattype,omitempty"` // "single" | "group" + From struct { + UserID string `json:"userid"` + } `json:"from"` + MsgType string `json:"msgtype"` + Text *struct { + Content string `json:"content"` + } `json:"text,omitempty"` + Image *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` // long-connection: per-resource decrypt key + } `json:"image,omitempty"` + Voice *struct { + Content string `json:"content"` // WeCom transcribes voice to text in callbacks + } `json:"voice,omitempty"` + Mixed *struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text *struct { + Content string `json:"content"` + } `json:"text,omitempty"` + Image *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` + } `json:"image,omitempty"` + } `json:"msg_item"` + } `json:"mixed,omitempty"` + Event *struct { + EventType string `json:"eventtype"` + } `json:"event,omitempty"` + File *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` + } `json:"file,omitempty"` + Video *struct { + URL string `json:"url"` + AESKey string `json:"aeskey,omitempty"` + } `json:"video,omitempty"` +} + +// ---- Constructor ---- + +// newWeComAIBotWSChannel creates a WeComAIBotWSChannel for WebSocket mode. +func newWeComAIBotWSChannel( + cfg config.WeComAIBotConfig, + messageBus *bus.MessageBus, +) (*WeComAIBotWSChannel, error) { + if cfg.BotID == "" || cfg.Secret == "" { + return nil, fmt.Errorf("bot_id and secret are required for WeCom AI Bot WebSocket mode") + } + + base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) + + return &WeComAIBotWSChannel{ + BaseChannel: base, + config: cfg, + dedupe: NewMessageDeduplicator(wecomMaxProcessedMessages), + reqStates: make(map[string]*wsReqState), + reqPending: make(map[string]chan wsEnvelope), + }, nil +} + +// ---- Channel interface ---- + +// Name implements channels.Channel. +func (c *WeComAIBotWSChannel) Name() string { return "wecom_aibot" } + +// Start connects to the WeCom WebSocket endpoint and begins message processing. +func (c *WeComAIBotWSChannel) Start(ctx context.Context) error { + logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel (WebSocket long-connection mode)...") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + go c.connectLoop() + logger.InfoC("wecom_aibot", "WeCom AI Bot channel started (WebSocket mode)") + return nil +} + +// Stop shuts down the channel and closes the WebSocket connection. +func (c *WeComAIBotWSChannel) Stop(_ context.Context) error { + logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel (WebSocket mode)...") + if c.cancel != nil { + c.cancel() + } + c.connMu.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.connMu.Unlock() + c.SetRunning(false) + logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") + return nil +} + +// Send delivers the agent reply for msg.ChatID. +// The waiting task goroutine picks it up and writes the final stream response. +func (c *WeComAIBotWSChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // msg.ChatID carries the inbound req_id (set by dispatchWSAgentTask). + // For cron-triggered messages, msg.ChatID is the real WeCom chat/user ID + // and there will be no matching entry in reqStates; fall through to proactive push. + task, route, ok := c.getReqState(msg.ChatID) + if !ok { + // No req_id record found — this is a cron/scheduler-originated message. + // Send it as a proactive markdown push using the chat ID directly. + logger.InfoCF("wecom_aibot", "Send: no req_id state, delivering via proactive push (cron/scheduler)", + map[string]any{"chat_id": msg.ChatID}) + if err := c.wsSendActivePush(msg.ChatID, 0, msg.Content); err != nil { + logger.WarnCF("wecom_aibot", "Proactive push failed", + map[string]any{"chat_id": msg.ChatID, "error": err.Error()}) + return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) + } + return nil + } + + if task == nil { + if time.Now().Before(route.ReadyAt) { + // Keep using aibot_respond_msg within stream window; do not proactively + // push unless wsStreamMaxDuration has elapsed. + logger.WarnCF("wecom_aibot", "Send: stream window still open, skip proactive push", + map[string]any{"req_id": msg.ChatID, "ready_at": route.ReadyAt.Format(time.RFC3339)}) + return nil + } + + if err := c.wsSendActivePush(route.ChatID, route.ChatType, msg.Content); err != nil { + logger.WarnCF("wecom_aibot", "Late reply proactive push failed", + map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "error": err.Error()}) + return fmt.Errorf("websocket delivery failed: %w", channels.ErrSendFailed) + } + logger.InfoCF("wecom_aibot", "Late reply delivered via proactive push", + map[string]any{"req_id": msg.ChatID, "chat_id": route.ChatID, "chat_type": route.ChatType}) + c.deleteReqState(msg.ChatID) + return nil + } + + // Non-blocking fast path: when answerCh has space, deliver without racing + // against task.ctx.Done() (which fires when the task is canceled by a new + // incoming message, but the response must still be sent). + select { + case task.answerCh <- msg.Content: + return nil + default: + } + // answerCh was full; block with cancellation guards. + select { + case task.answerCh <- msg.Content: + case <-task.ctx.Done(): + return nil + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +// ---- Connection management ---- + +// wsBackoffResetDuration is the minimum duration a WebSocket connection must +// stay up before we reset the reconnect backoff to its initial value. This +// prevents a short burst of failures from causing long waits after later, +// stable connection periods. +const wsBackoffResetDuration = time.Minute + +// connectLoop maintains the WebSocket connection, reconnecting on failure with +// exponential backoff. +func (c *WeComAIBotWSChannel) connectLoop() { + backoff := wsInitialReconnect + for { + select { + case <-c.ctx.Done(): + return + default: + } + + logger.InfoC("wecom_aibot", "Connecting to WeCom WebSocket endpoint...") + start := time.Now() + if err := c.runConnection(); err != nil { + elapsed := time.Since(start) + // If the connection was stable for long enough, reset backoff so that + // a previous burst of failures does not keep us at the maximum delay. + if elapsed >= wsBackoffResetDuration { + backoff = wsInitialReconnect + } + select { + case <-c.ctx.Done(): + return + default: + logger.WarnCF("wecom_aibot", "WebSocket connection lost, reconnecting", + map[string]any{"error": err.Error(), "backoff": backoff.String()}) + select { + case <-time.After(backoff): + case <-c.ctx.Done(): + return + } + if backoff < wsMaxReconnectWait { + backoff *= 2 + if backoff > wsMaxReconnectWait { + backoff = wsMaxReconnectWait + } + } + } + } else { + // Clean exit (context canceled); stop reconnecting. + return + } + } +} + +// runConnection dials, subscribes, and runs the read/heartbeat loops until the +// connection closes or the channel context is canceled. +func (c *WeComAIBotWSChannel) runConnection() error { + dialCtx, dialCancel := context.WithTimeout(c.ctx, wsConnectTimeout) + conn, httpResp, err := websocket.DefaultDialer.DialContext(dialCtx, wsEndpoint, nil) + dialCancel() + if httpResp != nil { + httpResp.Body.Close() + } + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + + c.connMu.Lock() + c.conn = conn + c.connMu.Unlock() + + defer func() { + c.connMu.Lock() + if c.conn == conn { + c.conn = nil + } + c.connMu.Unlock() + // Cancel any tasks that were started over this connection so their + // agent goroutines do not keep running after the connection is gone. + c.cancelAllTasks() + }() + + // ---- Read loop (must start BEFORE subscribing) ---- + // sendAndWait blocks waiting for the subscribe response on reqPending; + // readLoop is the only goroutine that delivers messages to reqPending. + // Starting readLoop first avoids a deadlock where sendAndWait times out + // because no one reads the server's reply. + readErrCh := make(chan error, 1) + go func() { readErrCh <- c.readLoop(conn) }() + + // ---- Subscribe ---- + reqID := wsGenerateID() + resp, err := c.sendAndWait(conn, reqID, wsCommand{ + Cmd: "aibot_subscribe", + Headers: wsHeaders{ReqID: reqID}, + Body: map[string]string{ + "bot_id": c.config.BotID, + "secret": c.config.Secret, + }, + }, wsSubscribeTimeout) + if err != nil { + conn.Close() // stop readLoop + <-readErrCh + return fmt.Errorf("subscribe failed: %w", err) + } + if resp.ErrCode != 0 { + conn.Close() + <-readErrCh + return fmt.Errorf("subscribe rejected (errcode=%d): %s", resp.ErrCode, resp.ErrMsg) + } + + logger.InfoC("wecom_aibot", "WebSocket subscription successful") + + // ---- Heartbeat goroutine ---- + hbDone := make(chan struct{}) + go func() { + defer close(hbDone) + c.heartbeatLoop(conn) + }() + + // Wait for the read loop to exit, then tear down the heartbeat. + readErr := <-readErrCh + conn.Close() // signal heartbeat to stop (idempotent) + <-hbDone + return readErr +} + +// sendAndWait registers a pending-response slot, sends cmd, and blocks until +// the matching response arrives or the timeout/context fires. +func (c *WeComAIBotWSChannel) sendAndWait( + conn *websocket.Conn, + reqID string, + cmd wsCommand, + timeout time.Duration, +) (wsEnvelope, error) { + ch := make(chan wsEnvelope, 1) + c.reqPendingMu.Lock() + c.reqPending[reqID] = ch + c.reqPendingMu.Unlock() + + cleanup := func() { + c.reqPendingMu.Lock() + delete(c.reqPending, reqID) + c.reqPendingMu.Unlock() + } + + data, err := json.Marshal(cmd) + if err != nil { + cleanup() + return wsEnvelope{}, fmt.Errorf("marshal command: %w", err) + } + c.connMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.connMu.Unlock() + if err != nil { + cleanup() + return wsEnvelope{}, fmt.Errorf("write command: %w", err) + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case env := <-ch: + return env, nil + case <-timer.C: + cleanup() + return wsEnvelope{}, fmt.Errorf("timeout waiting for response (req_id=%s)", reqID) + case <-c.ctx.Done(): + cleanup() + return wsEnvelope{}, c.ctx.Err() + } +} + +// heartbeatLoop sends a ping every wsHeartbeatInterval until conn is closed. +// It validates the server's pong response via sendAndWait; a failed pong +// triggers a reconnection by closing the connection. +func (c *WeComAIBotWSChannel) heartbeatLoop(conn *websocket.Conn) { + ticker := time.NewTicker(wsHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + reqID := wsGenerateID() + resp, err := c.sendAndWait(conn, reqID, wsCommand{ + Cmd: "ping", + Headers: wsHeaders{ReqID: reqID}, + }, wsHeartbeatInterval) + if err != nil { + logger.WarnCF("wecom_aibot", "Heartbeat failed, closing connection", + map[string]any{"error": err.Error()}) + conn.Close() + return + } + if resp.ErrCode != 0 { + logger.WarnCF("wecom_aibot", "Heartbeat rejected", + map[string]any{"errcode": resp.ErrCode, "errmsg": resp.ErrMsg}) + conn.Close() + return + } + logger.DebugCF("wecom_aibot", "Heartbeat pong received", map[string]any{"req_id": reqID}) + case <-c.ctx.Done(): + return + } + } +} + +// readLoop reads WebSocket messages and dispatches them until the connection +// closes or the channel is stopped. +func (c *WeComAIBotWSChannel) readLoop(conn *websocket.Conn) error { + for { + _, raw, err := conn.ReadMessage() + if err != nil { + select { + case <-c.ctx.Done(): + return nil // clean shutdown + default: + return fmt.Errorf("read error: %w", err) + } + } + + var env wsEnvelope + if err := json.Unmarshal(raw, &env); err != nil { + logger.WarnCF("wecom_aibot", "Failed to parse WebSocket message", + map[string]any{"error": err.Error(), "raw": string(raw)}) + continue + } + + // Command responses have an empty Cmd field; forward to any waiting + // sendAndWait() call, or silently drop if no one is waiting (e.g. + // late responses after timeout). + if env.Cmd == "" && env.Headers.ReqID != "" { + c.reqPendingMu.Lock() + ch, ok := c.reqPending[env.Headers.ReqID] + if ok { + delete(c.reqPending, env.Headers.ReqID) + } + c.reqPendingMu.Unlock() + if ok { + ch <- env + } + continue + } + + // Dispatch to appropriate handler in a separate goroutine so the + // read loop is never blocked by a slow agent. + go c.handleEnvelope(env) + } +} + +// ---- Message / event handlers ---- + +// handleEnvelope routes a WebSocket envelope to the right handler. +func (c *WeComAIBotWSChannel) handleEnvelope(env wsEnvelope) { + switch env.Cmd { + case "aibot_msg_callback": + c.handleMsgCallback(env) + case "aibot_event_callback": + c.handleEventCallback(env) + default: + logger.DebugCF("wecom_aibot", "Unhandled WebSocket command", + map[string]any{"cmd": env.Cmd}) + } +} + +// handleMsgCallback processes aibot_msg_callback. +func (c *WeComAIBotWSChannel) handleMsgCallback(env wsEnvelope) { + var msg WeComAIBotWSMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom_aibot", "Failed to parse msg callback body", + map[string]any{"error": err.Error()}) + return + } + + // Deduplicate by msgid (WeCom may re-deliver on network issues). + if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { + logger.DebugCF("wecom_aibot", "Duplicate message ignored", + map[string]any{"msgid": msg.MsgID}) + return + } + + reqID := env.Headers.ReqID + switch msg.MsgType { + case "text": + c.handleWSTextMessage(reqID, msg) + case "image": + c.handleWSImageMessage(reqID, msg) + case "voice": + c.handleWSVoiceMessage(reqID, msg) + case "mixed": + c.handleWSMixedMessage(reqID, msg) + case "file": + c.handleWSFileMessage(reqID, msg) + case "video": + c.handleWSVideoMessage(reqID, msg) + default: + logger.WarnCF("wecom_aibot", "Unsupported message type", + map[string]any{"msgtype": msg.MsgType}) + c.wsSendStreamFinish(reqID, wsGenerateID(), + "Unsupported message type: "+msg.MsgType) + } +} + +// handleEventCallback processes aibot_event_callback. +func (c *WeComAIBotWSChannel) handleEventCallback(env wsEnvelope) { + var msg WeComAIBotWSMessage + if err := json.Unmarshal(env.Body, &msg); err != nil { + logger.WarnCF("wecom_aibot", "Failed to parse event callback body", + map[string]any{"error": err.Error()}) + return + } + + // Deduplicate by msgid. + if msg.MsgID != "" && !c.dedupe.MarkMessageProcessed(msg.MsgID) { + logger.DebugCF("wecom_aibot", "Duplicate event ignored", + map[string]any{"msgid": msg.MsgID}) + return + } + + var eventType string + if msg.Event != nil { + eventType = msg.Event.EventType + } + logger.DebugCF("wecom_aibot", "Received event callback", + map[string]any{"event_type": eventType}) + + switch eventType { + case "enter_chat": + if c.config.WelcomeMessage != "" { + c.wsSendWelcomeMsg(env.Headers.ReqID, c.config.WelcomeMessage) + } + case "disconnected_event": + // The server will close this connection after sending this event. + // connectLoop will detect the closure and reconnect automatically. + logger.WarnC("wecom_aibot", + "Received disconnected_event: this connection is being replaced by a newer one") + default: + logger.DebugCF("wecom_aibot", "Unhandled event type", + map[string]any{"event_type": eventType}) + } +} + +// handleWSTextMessage dispatches a plain-text message to the agent and streams +// the reply back over the WebSocket connection. +func (c *WeComAIBotWSChannel) handleWSTextMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.Text == nil { + logger.ErrorC("wecom_aibot", "text message missing text field") + return + } + c.dispatchWSAgentTask(reqID, msg, msg.Text.Content, nil) +} + +// handleWSImageMessage downloads and stores the inbound image, then dispatches +// it to the agent as a media-tagged message. +func (c *WeComAIBotWSChannel) handleWSImageMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.Image == nil { + logger.WarnC("wecom_aibot", "Image message missing image field") + c.wsSendStreamFinish(reqID, wsGenerateID(), "Image message could not be processed.") + return + } + c.wsHandleMediaMessage(reqID, msg, msg.Image.URL, msg.Image.AESKey, "image") +} + +// wsHandleMediaMessage is a shared helper for image, file and video messages. +// It downloads the resource, stores it in MediaStore, and dispatches to the agent. +func (c *WeComAIBotWSChannel) wsHandleMediaMessage( + reqID string, msg WeComAIBotWSMessage, + resourceURL, aesKey, label string, +) { + chatID := wsChatID(msg) + + ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) + defer cancel() + + ref, err := c.storeWSMedia(ctx, chatID, msg.MsgID, resourceURL, aesKey, wsLabelToDefaultExt(label)) + if err != nil { + logger.WarnCF("wecom_aibot", "Failed to download/store WS "+label, + map[string]any{"error": err.Error(), "url": resourceURL}) + c.wsSendStreamFinish(reqID, wsGenerateID(), + strings.ToUpper(label[:1])+label[1:]+" message could not be processed.") + return + } + + c.dispatchWSAgentTask(reqID, msg, "["+label+"]", []string{ref}) +} + +// handleWSMixedMessage handles mixed text+image messages. +// All text parts are collected into the content string; all image parts are +// downloaded and stored in MediaStore before dispatching to the agent. +func (c *WeComAIBotWSChannel) handleWSMixedMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.Mixed == nil { + logger.WarnC("wecom_aibot", "Mixed message has no content") + c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") + return + } + + chatID := wsChatID(msg) + + ctx, cancel := context.WithTimeout(c.ctx, wsImageDownloadTimeout) + defer cancel() + + var textParts []string + var mediaRefs []string + for _, item := range msg.Mixed.MsgItem { + switch item.MsgType { + case "text": + if item.Text != nil && item.Text.Content != "" { + textParts = append(textParts, item.Text.Content) + } + case "image": + if item.Image != nil { + ref, err := c.storeWSMedia(ctx, chatID, + msg.MsgID+"-"+wsGenerateID(), item.Image.URL, item.Image.AESKey, ".jpg") + if err != nil { + logger.WarnCF("wecom_aibot", "Failed to download/store mixed image", + map[string]any{"error": err.Error()}) + } else { + mediaRefs = append(mediaRefs, ref) + } + } + default: + logger.WarnCF("wecom_aibot", "Unsupported item type in mixed message", + map[string]any{"msgtype": item.MsgType}) + } + } + + if len(textParts) == 0 && len(mediaRefs) == 0 { + logger.WarnC("wecom_aibot", "Mixed message has no usable content") + c.wsSendStreamFinish(reqID, wsGenerateID(), "Mixed message type is not yet fully supported.") + return + } + + content := strings.Join(textParts, "\n") + if content == "" { + content = "[images]" + } + c.dispatchWSAgentTask(reqID, msg, content, mediaRefs) +} + +// dispatchWSAgentTask registers a new agent task, sends the opening stream frame, +// and starts a goroutine that runs the agent and streams the reply back. +// content is the text forwarded to the agent; mediaRefs are optional media +// store references attached to the inbound message. +func (c *WeComAIBotWSChannel) dispatchWSAgentTask( + reqID string, + msg WeComAIBotWSMessage, + content string, + mediaRefs []string, +) { + userID := msg.From.UserID + if userID == "" { + userID = "unknown" + } + // actualChatID is the real WeCom chat/user ID used for peer identification. + // reqID is used as the routing chatID so each turn is independently addressable. + actualChatID := wsChatID(msg) + + streamID := wsGenerateID() + chatType := wsChatTypeValue(msg.ChatType) + taskCtx, taskCancel := context.WithCancel(c.ctx) + + task := &wsTask{ + ReqID: reqID, + ChatID: actualChatID, + ChatType: chatType, + StreamID: streamID, + answerCh: make(chan string, 1), + ctx: taskCtx, + cancel: taskCancel, + } + // Each req_id is unique per WeCom turn; tasks run concurrently, no cancellation. + c.setReqState(reqID, &wsReqState{ + Task: task, + Route: wsLateReplyRoute{ + ChatID: actualChatID, + ChatType: chatType, + ReadyAt: time.Now().Add(wsStreamMaxDuration), + ExpiresAt: time.Now().Add(wsLateReplyRouteTTL), + }, + }) + + logger.DebugCF("wecom_aibot", "Registered new agent task", + map[string]any{"chat_id": actualChatID, "req_id": reqID, "stream_id": streamID}) + + // Send an empty stream opening frame (finish=false) immediately. + c.wsSendStreamChunk(reqID, streamID, false, "") + + go func() { + defer func() { + taskCancel() + c.clearReqTask(reqID, task) + }() + + sender := bus.SenderInfo{ + Platform: "wecom_aibot", + PlatformID: userID, + CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), + DisplayName: userID, + } + peerKind := "direct" + if msg.ChatType == "group" { + peerKind = "group" + } + peer := bus.Peer{Kind: peerKind, ID: actualChatID} + metadata := map[string]string{ + "channel": "wecom_aibot", + "chat_id": actualChatID, + "chat_type": msg.ChatType, + "msg_type": msg.MsgType, + "msgid": msg.MsgID, + "aibotid": msg.AIBotID, + "stream_id": streamID, + } + // Pass reqID as chatID: OutboundMessage.ChatID = reqID → Send() finds tasks[reqID]. + c.HandleMessage(taskCtx, peer, reqID, userID, reqID, + content, mediaRefs, metadata, sender) + + // Wait for the agent reply. While waiting, send periodic finish=false + // hints so the user knows processing is still in progress. + // WeCom requires finish=true within 6 minutes of the first stream frame; + // wsStreamMaxDuration enforces that limit with a safety margin. + waitHints := []string{ + "⏳ Processing, please wait...", + "⏳ Still processing, please wait...", + "⏳ Almost there, please wait...", + } + ticker := time.NewTicker(wsStreamTickInterval) + defer ticker.Stop() + deadlineTimer := time.NewTimer(wsStreamMaxDuration) + defer deadlineTimer.Stop() + tickCount := 0 + for { + select { + case answer := <-task.answerCh: + // Split the answer into byte-bounded chunks and send as stream frames. + // All but the last carry finish=false; the final frame closes the stream. + chunks := splitWSContent(answer, wsStreamMaxContentBytes) + for i, chunk := range chunks { + c.wsSendStreamChunk(reqID, streamID, i == len(chunks)-1, chunk) + } + c.deleteReqState(reqID) + return + case <-ticker.C: + hint := waitHints[tickCount%len(waitHints)] + tickCount++ + logger.DebugCF("wecom_aibot", "Sending stream progress hint", + map[string]any{"chat_id": actualChatID, "tick": tickCount}) + c.wsSendStreamChunk(reqID, streamID, false, hint) + case <-deadlineTimer.C: + logger.WarnCF("wecom_aibot", + "Stream response deadline reached, closing stream; late reply will be pushed", + map[string]any{"chat_id": actualChatID}) + c.wsSendStreamFinish(reqID, streamID, + "⏳ Processing is taking longer than expected, the response will be sent as a follow-up message.") + return + case <-taskCtx.Done(): + // Give a short grace period so that a response queued in the bus + // just before cancellation can still be delivered. This closes a + // race where a rapid second message cancels this task after the + // agent already published but before Send() wrote to answerCh. + // + // The connection is gone at this point, so we cannot use + // wsSendStreamFinish. Try wsSendActivePush on the (possibly + // already-restored) connection; if that also fails, leave the + // route intact so Send() can push the reply once reconnected. + select { + case answer := <-task.answerCh: + if err := c.wsSendActivePush(task.ChatID, task.ChatType, answer); err != nil { + logger.WarnCF("wecom_aibot", + "Grace-period push failed after task cancellation; reply may be lost", + map[string]any{"req_id": reqID, "chat_id": task.ChatID, "error": err.Error()}) + } else { + c.deleteReqState(reqID) + } + case <-time.After(100 * time.Millisecond): + } + return + } + } + }() +} + +// handleWSVoiceMessage handles voice messages. +// WeCom transcribes voice to text in the callback; if the transcription is +// present it is dispatched as plain text to the agent. +func (c *WeComAIBotWSChannel) handleWSVoiceMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.Voice != nil && msg.Voice.Content != "" { + c.dispatchWSAgentTask(reqID, msg, msg.Voice.Content, nil) + return + } + c.wsSendStreamFinish(reqID, wsGenerateID(), "Voice messages are not yet supported.") +} + +// handleWSFileMessage handles file messages. +func (c *WeComAIBotWSChannel) handleWSFileMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.File == nil { + logger.WarnC("wecom_aibot", "File message missing file field") + c.wsSendStreamFinish(reqID, wsGenerateID(), "File message could not be processed.") + return + } + c.wsHandleMediaMessage(reqID, msg, msg.File.URL, msg.File.AESKey, "file") +} + +// handleWSVideoMessage handles video messages. +func (c *WeComAIBotWSChannel) handleWSVideoMessage(reqID string, msg WeComAIBotWSMessage) { + if msg.Video == nil { + logger.WarnC("wecom_aibot", "Video message missing video field") + c.wsSendStreamFinish(reqID, wsGenerateID(), "Video message could not be processed.") + return + } + c.wsHandleMediaMessage(reqID, msg, msg.Video.URL, msg.Video.AESKey, "video") +} + +// ---- WebSocket write helpers ---- + +// wsSendStreamChunk sends an aibot_respond_msg stream frame. +func (c *WeComAIBotWSChannel) wsSendStreamChunk(reqID, streamID string, finish bool, content string) { + logger.DebugCF("wecom_aibot", "Sending stream chunk", map[string]any{ + "stream_id": streamID, + "finish": finish, + "preview": utils.Truncate(content, 100), + }) + cmd := wsCommand{ + Cmd: "aibot_respond_msg", + Headers: wsHeaders{ReqID: reqID}, + Body: wsRespondMsgBody{ + MsgType: "stream", + Stream: &wsStreamContent{ + ID: streamID, + Finish: finish, + Content: content, + }, + }, + } + if err := c.writeWSAndWait(cmd, wsRespondMsgTimeout); err != nil { + logger.WarnCF("wecom_aibot", "Stream chunk ack failed", map[string]any{ + "req_id": reqID, + "stream_id": streamID, + "finish": finish, + "error": err, + }) + } +} + +// wsSendStreamFinish sends the final aibot_respond_msg frame (finish=true, no images). +func (c *WeComAIBotWSChannel) wsSendStreamFinish(reqID, streamID, content string) { + c.wsSendStreamChunk(reqID, streamID, true, content) +} + +// wsSendWelcomeMsg sends a text welcome message via aibot_respond_welcome_msg. +func (c *WeComAIBotWSChannel) wsSendWelcomeMsg(reqID, content string) { + logger.DebugCF("wecom_aibot", "Sending welcome message", map[string]any{"req_id": reqID}) + cmd := wsCommand{ + Cmd: "aibot_respond_welcome_msg", + Headers: wsHeaders{ReqID: reqID}, + Body: wsRespondMsgBody{ + MsgType: "text", + Text: &wsTextContent{Content: content}, + }, + } + if err := c.writeWSAndWait(cmd, wsWelcomeMsgTimeout); err != nil { + logger.WarnCF("wecom_aibot", "Welcome message ack failed", + map[string]any{"req_id": reqID, "error": err.Error()}) + } +} + +// wsSendActivePush sends a proactive markdown message using aibot_send_msg. +// Long content is automatically split into byte-bounded chunks (≤ wsStreamMaxContentBytes +// each) and delivered as consecutive messages. +// It is used as a fallback for late replies after stream response window expires. +func (c *WeComAIBotWSChannel) wsSendActivePush(chatID string, chatType uint32, content string) error { + if chatID == "" { + return fmt.Errorf("chatid is empty") + } + for _, chunk := range splitWSContent(content, wsStreamMaxContentBytes) { + reqID := wsGenerateID() + if err := c.writeWSAndWait(wsCommand{ + Cmd: "aibot_send_msg", + Headers: wsHeaders{ReqID: reqID}, + Body: wsSendMsgBody{ + ChatID: chatID, + ChatType: chatType, + MsgType: "markdown", + Markdown: &wsMarkdownContent{Content: chunk}, + }, + }, wsSendMsgTimeout); err != nil { + return err + } + } + return nil +} + +// writeWSAndWait writes cmd to the active connection and validates the command response. +func (c *WeComAIBotWSChannel) writeWSAndWait(cmd wsCommand, timeout time.Duration) error { + if cmd.Headers.ReqID == "" { + return fmt.Errorf("req_id is empty") + } + + c.connMu.Lock() + conn := c.conn + c.connMu.Unlock() + if conn == nil { + return fmt.Errorf("websocket not connected") + } + + resp, err := c.sendAndWait(conn, cmd.Headers.ReqID, cmd, timeout) + if err != nil { + return err + } + if resp.ErrCode != 0 { + return fmt.Errorf("%s rejected (errcode=%d): %s", cmd.Cmd, resp.ErrCode, resp.ErrMsg) + } + return nil +} + +// cancelAllTasks cancels every pending agent task; called when the connection drops. +// It also expires each task's stream window (ReadyAt = now) so that when the agent +// eventually delivers its reply via Send(), the message is forwarded via +// wsSendActivePush on the restored connection instead of being silently discarded. +func (c *WeComAIBotWSChannel) cancelAllTasks() { + c.reqStatesMu.Lock() + defer c.reqStatesMu.Unlock() + now := time.Now() + for _, state := range c.reqStates { + if state != nil && state.Task != nil { + state.Task.cancel() + state.Task = nil + // Expire the stream window immediately so Send() uses wsSendActivePush. + state.Route.ReadyAt = now + } + } +} + +func (c *WeComAIBotWSChannel) setReqState(reqID string, state *wsReqState) { + c.reqStatesMu.Lock() + defer c.reqStatesMu.Unlock() + now := time.Now() + for k, v := range c.reqStates { + if v == nil || now.After(v.Route.ExpiresAt) { + delete(c.reqStates, k) + } + } + c.reqStates[reqID] = state +} + +func (c *WeComAIBotWSChannel) getReqState(reqID string) (*wsTask, wsLateReplyRoute, bool) { + c.reqStatesMu.Lock() + defer c.reqStatesMu.Unlock() + state, ok := c.reqStates[reqID] + if !ok || state == nil { + return nil, wsLateReplyRoute{}, false + } + if time.Now().After(state.Route.ExpiresAt) { + delete(c.reqStates, reqID) + return nil, wsLateReplyRoute{}, false + } + return state.Task, state.Route, true +} + +func (c *WeComAIBotWSChannel) deleteReqState(reqID string) { + c.reqStatesMu.Lock() + delete(c.reqStates, reqID) + c.reqStatesMu.Unlock() +} + +func (c *WeComAIBotWSChannel) clearReqTask(reqID string, task *wsTask) { + c.reqStatesMu.Lock() + defer c.reqStatesMu.Unlock() + state, ok := c.reqStates[reqID] + if !ok || state == nil { + return + } + if state.Task == task { + state.Task = nil + } +} + +func wsChatTypeValue(chatType string) uint32 { + if chatType == "group" { + return 2 + } + return 1 +} + +// wsChatID returns the effective chat ID from a WS message. +// For group messages it is msg.ChatID; for single chats it falls back to the sender's UserID. +func wsChatID(msg WeComAIBotWSMessage) string { + if msg.ChatID != "" { + return msg.ChatID + } + return msg.From.UserID +} + +// wsGenerateID generates a random 10-character alphanumeric ID. +// It is package-level (not a method) so it can be shared by both channel modes. +func wsGenerateID() string { + return generateRandomID(10) +} + +// ---- Inbound media download helpers ---- + +// storeWSMedia downloads the resource at resourceURL (with optional AES-CBC +// decryption) and stores it in the MediaStore. The file extension is inferred +// from the HTTP Content-Type response header; defaultExt is used as a fallback +// when the content type is absent or unrecognized. +func (c *WeComAIBotWSChannel) storeWSMedia( + ctx context.Context, + chatID, msgID, resourceURL, aesKey, defaultExt string, +) (string, error) { + store := c.GetMediaStore() + if store == nil { + return "", fmt.Errorf("no media store available") + } + + const maxSize = 20 << 20 // 20 MB + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + resp, err := wsImageHTTPClient.Do(req) + if err != nil { + return "", fmt.Errorf("download: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download HTTP %d", resp.StatusCode) + } + + // Infer file extension from the Content-Type response header. + ext := wsMediaExtFromContentType(resp.Header.Get("Content-Type")) + if ext == "" { + ext = defaultExt + } + + // Buffer the media in memory, bounded to maxSize. + data, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxSize)+1)) + if err != nil { + return "", fmt.Errorf("read media: %w", err) + } + if len(data) > maxSize { + return "", fmt.Errorf("media too large (> %d MB)", maxSize>>20) + } + + // AES-CBC decryption if a key is present. + if aesKey != "" { + key, decErr := base64.StdEncoding.DecodeString(aesKey) + if decErr != nil || len(key) != 32 { + key, decErr = decodeWeComAESKey(aesKey) + if decErr != nil { + return "", fmt.Errorf("decode media AES key: %w", decErr) + } + } + data, err = decryptAESCBC(key, data) + if err != nil { + return "", fmt.Errorf("decrypt media: %w", err) + } + } + + // Write to a temp file. The file is owned by the MediaStore and deleted by + // store.ReleaseAll — no caller-side cleanup needed. + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err = os.MkdirAll(mediaDir, 0o700); err != nil { + return "", fmt.Errorf("mkdir: %w", err) + } + tmpFile, err := os.CreateTemp(mediaDir, msgID+"-*"+ext) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + _, writeErr := tmpFile.Write(data) + closeErr := tmpFile.Close() + if writeErr != nil { + os.Remove(tmpPath) + return "", fmt.Errorf("write media: %w", writeErr) + } + if closeErr != nil { + os.Remove(tmpPath) + return "", fmt.Errorf("close media: %w", closeErr) + } + + scope := channels.BuildMediaScope("wecom_aibot", chatID, msgID) + ref, err := store.Store(tmpPath, media.MediaMeta{ + Filename: msgID + ext, + Source: "wecom_aibot", + }, scope) + if err != nil { + os.Remove(tmpPath) + return "", fmt.Errorf("store: %w", err) + } + return ref, nil +} + +// wsMediaExtFromContentType returns the lowercase file extension (with leading +// dot) for the given Content-Type value, or "" when the type is unrecognized. +func wsMediaExtFromContentType(contentType string) string { + if contentType == "" { + return "" + } + // Strip parameters (e.g. "image/jpeg; charset=utf-8" → "image/jpeg"). + mt := strings.ToLower(strings.TrimSpace(strings.SplitN(contentType, ";", 2)[0])) + switch mt { + case "image/jpeg", "image/jpg": + return ".jpg" + case "image/png": + return ".png" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "video/mp4": + return ".mp4" + case "video/mpeg", "video/x-mpeg": + return ".mpeg" + case "video/quicktime": + return ".mov" + case "video/webm": + return ".webm" + case "audio/mpeg", "audio/mp3": + return ".mp3" + case "audio/ogg": + return ".ogg" + case "audio/wav": + return ".wav" + case "application/pdf": + return ".pdf" + case "application/zip": + return ".zip" + case "application/x-rar-compressed", "application/vnd.rar": + return ".rar" + case "text/plain": + return ".txt" + case "application/msword": + return ".doc" + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + return ".docx" + case "application/vnd.ms-excel": + return ".xls" + case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": + return ".xlsx" + case "application/vnd.ms-powerpoint": + return ".ppt" + case "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return ".pptx" + } + return "" +} + +// wsLabelToDefaultExt returns the default file extension for the given media label +// used in wsHandleMediaMessage. It is the fallback when Content-Type detection fails. +func wsLabelToDefaultExt(label string) string { + switch label { + case "image": + return ".jpg" + case "video": + return ".mp4" + default: // "file" and any future labels + return ".bin" + } +} + +// ---- Content length helpers ---- + +// splitWSContent splits content into chunks each fitting within maxBytes UTF-8 +// bytes, preserving code block integrity via channels.SplitMessage. +// When SplitMessage still produces an oversized chunk (e.g. dense CJK content), +// splitAtByteBoundary is applied as a last-resort byte-level fallback. +func splitWSContent(content string, maxBytes int) []string { + if len(content) <= maxBytes { + return []string{content} + } + // SplitMessage works in runes. Use maxBytes as the rune limit: for pure ASCII + // this is exact; for multibyte content the byte verification below catches + // any chunk that still overflows. + chunks := channels.SplitMessage(content, maxBytes) + var result []string + for _, chunk := range chunks { + if len(chunk) <= maxBytes { + result = append(result, chunk) + } else { + // Still too large in bytes (e.g. dense CJK); force-split at UTF-8 boundaries. + result = append(result, splitAtByteBoundary(chunk, maxBytes)...) + } + } + return result +} + +// splitAtByteBoundary splits s into parts each ≤ maxBytes bytes by walking back +// from the hard byte limit to find a valid UTF-8 rune start boundary. +// This is a last-resort fallback; it does not try to preserve code blocks. +func splitAtByteBoundary(s string, maxBytes int) []string { + var parts []string + for len(s) > maxBytes { + end := maxBytes + // Walk back past any UTF-8 continuation bytes (high two bits == 10). + for end > 0 && s[end]>>6 == 0b10 { + end-- + } + if end == 0 { + end = maxBytes // shouldn't happen with valid UTF-8 + } + parts = append(parts, s[:end]) + s = strings.TrimLeft(s[end:], " \t\n\r") + } + if s != "" { + parts = append(parts, s) + } + return parts +} diff --git a/pkg/channels/wecom/aibot_ws_test.go b/pkg/channels/wecom/aibot_ws_test.go new file mode 100644 index 000000000..0a533da5d --- /dev/null +++ b/pkg/channels/wecom/aibot_ws_test.go @@ -0,0 +1,295 @@ +package wecom + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" +) + +// newTestWSChannel creates a WeComAIBotWSChannel ready for unit testing. +func newTestWSChannel(t *testing.T) *WeComAIBotWSChannel { + t.Helper() + cfg := config.WeComAIBotConfig{ + Enabled: true, + BotID: "test_bot_id", + Secret: "test_secret", + } + ch, err := newWeComAIBotWSChannel(cfg, bus.NewMessageBus()) + if err != nil { + t.Fatalf("create WS channel: %v", err) + } + return ch +} + +// TestStoreWSMedia_NilStore verifies that storeWSMedia returns an error when no +// MediaStore has been injected. +func TestStoreWSMedia_NilStore(t *testing.T) { + ch := newTestWSChannel(t) + _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://any", "", ".jpg") + if err == nil { + t.Fatal("expected error when no MediaStore is set") + } +} + +// TestStoreWSMedia_HTTPError verifies that storeWSMedia propagates HTTP errors +// from the media server. +func TestStoreWSMedia_HTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "not found", http.StatusNotFound) + })) + defer srv.Close() + + ch := newTestWSChannel(t) + ch.SetMediaStore(media.NewFileMediaStore()) + + _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") + if err == nil { + t.Fatal("expected error for HTTP 404") + } +} + +// TestStoreWSMedia_ServerUnavailable verifies that storeWSMedia returns a clear +// error when the media server cannot be reached. +func TestStoreWSMedia_ServerUnavailable(t *testing.T) { + ch := newTestWSChannel(t) + ch.SetMediaStore(media.NewFileMediaStore()) + + // Port 1 is reserved and will refuse the connection immediately. + _, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", "http://127.0.0.1:1", "", ".jpg") + if err == nil { + t.Fatal("expected error for unreachable server") + } +} + +// TestStoreWSMedia_Success_NoAES verifies the happy path: the media is downloaded, +// a media ref is returned, and the file persists and is readable via Resolve until +// ReleaseAll is called. The server returns no Content-Type, so the defaultExt is used. +func TestStoreWSMedia_Success_NoAES(t *testing.T) { + imageData := bytes.Repeat([]byte("x"), 256) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(imageData) + })) + defer srv.Close() + + ch := newTestWSChannel(t) + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + ref, err := ch.storeWSMedia(context.Background(), "chat1", "msg1", srv.URL, "", ".jpg") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if ref == "" { + t.Fatal("expected non-empty ref") + } + + // File must be accessible after storeWSMedia returns (no premature deletion). + path, err := store.Resolve(ref) + if err != nil { + t.Fatalf("ref should resolve: %v", err) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("file should exist at %s: %v", path, err) + } + if !bytes.Equal(got, imageData) { + t.Errorf("content mismatch: got len=%d, want len=%d", len(got), len(imageData)) + } + + // ReleaseAll must delete the file (store owns lifecycle). + scope := channels.BuildMediaScope("wecom_aibot", "chat1", "msg1") + if err := store.ReleaseAll(scope); err != nil { + t.Fatalf("ReleaseAll failed: %v", err) + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Errorf("file should have been deleted by ReleaseAll, stat err: %v", err) + } +} + +// TestStoreWSMedia_MultipleMessages verifies that concurrent media messages with +// different msgIDs do not collide and each resolve to distinct files. +func TestStoreWSMedia_MultipleMessages(t *testing.T) { + imageA := bytes.Repeat([]byte("a"), 64) + imageB := bytes.Repeat([]byte("b"), 64) + + srvA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(imageA) + })) + defer srvA.Close() + srvB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(imageB) + })) + defer srvB.Close() + + ch := newTestWSChannel(t) + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + refA, err := ch.storeWSMedia(context.Background(), "chat1", "msgA", srvA.URL, "", ".jpg") + if err != nil { + t.Fatalf("storeWSMedia A: %v", err) + } + refB, err := ch.storeWSMedia(context.Background(), "chat1", "msgB", srvB.URL, "", ".jpg") + if err != nil { + t.Fatalf("storeWSMedia B: %v", err) + } + if refA == refB { + t.Fatal("distinct messages must produce distinct refs") + } + + pathA, _ := store.Resolve(refA) + pathB, _ := store.Resolve(refB) + if pathA == pathB { + t.Fatal("distinct messages must be stored at distinct paths") + } + + gotA, _ := os.ReadFile(pathA) + gotB, _ := os.ReadFile(pathB) + if !bytes.Equal(gotA, imageA) { + t.Errorf("content mismatch for message A") + } + if !bytes.Equal(gotB, imageB) { + t.Errorf("content mismatch for message B") + } +} + +// TestStoreWSMedia_ContentTypeExt verifies that the file extension is inferred +// from the HTTP Content-Type header and the defaultExt fallback is used when the +// type is absent or unrecognized. +func TestStoreWSMedia_ContentTypeExt(t *testing.T) { + tests := []struct { + contentType string + wantExt string + }{ + {"image/jpeg", ".jpg"}, + {"image/png", ".png"}, + {"video/mp4", ".mp4"}, + {"application/pdf", ".pdf"}, + {"application/zip", ".zip"}, + // With parameters stripped. + {"video/mp4; codecs=avc1", ".mp4"}, + // Unknown type → falls back to defaultExt. + {"", ""}, + {"application/octet-stream", ""}, + } + for _, tc := range tests { + got := wsMediaExtFromContentType(tc.contentType) + if got != tc.wantExt { + t.Errorf("wsMediaExtFromContentType(%q) = %q, want %q", tc.contentType, got, tc.wantExt) + } + } + + // End-to-end: server returns Content-Type: video/mp4, defaultExt is .bin. + // The stored file should carry the .mp4 extension, not .bin. + payload := bytes.Repeat([]byte("v"), 128) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(payload) + })) + defer srv.Close() + + ch := newTestWSChannel(t) + store := media.NewFileMediaStore() + ch.SetMediaStore(store) + + ref, err := ch.storeWSMedia(context.Background(), "chat1", "vid1", srv.URL, "", ".bin") + if err != nil { + t.Fatalf("storeWSMedia: %v", err) + } + path, err := store.Resolve(ref) + if err != nil { + t.Fatalf("resolve: %v", err) + } + if ext := path[len(path)-4:]; ext != ".mp4" { + t.Errorf("expected .mp4 extension from Content-Type, got %q", ext) + } +} + +// TestSplitWSContent verifies byte-aware splitting of stream content. +func TestSplitWSContent(t *testing.T) { + t.Run("short content is not split", func(t *testing.T) { + chunks := splitWSContent("hello", 20480) + if len(chunks) != 1 || chunks[0] != "hello" { + t.Fatalf("unexpected chunks: %v", chunks) + } + }) + + t.Run("ASCII content split at byte boundary", func(t *testing.T) { + // Build a string just over the limit. + content := strings.Repeat("a", 20481) + chunks := splitWSContent(content, 20480) + if len(chunks) < 2 { + t.Fatalf("expected >= 2 chunks, got %d", len(chunks)) + } + for i, c := range chunks { + if len(c) > 20480 { + t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) + } + } + // Reassembled content must equal the original (possibly without leading + // whitespace that splitWSContent trims between chunks). + joined := strings.Join(chunks, "") + if len(joined) < len(content)-len(chunks) { + t.Errorf("joined length %d too short (original %d)", len(joined), len(content)) + } + }) + + t.Run("CJK content split within byte limit", func(t *testing.T) { + // Each CJK rune is 3 bytes in UTF-8. + // 7000 CJK chars = 21000 bytes, which exceeds 20480. + content := strings.Repeat("\u4e2d", 7000) + chunks := splitWSContent(content, 20480) + if len(chunks) < 2 { + t.Fatalf("expected >= 2 chunks for 21000-byte CJK content, got %d", len(chunks)) + } + for i, c := range chunks { + if len(c) > 20480 { + t.Errorf("chunk %d has %d bytes, want <= 20480", i, len(c)) + } + // Every chunk must be valid UTF-8. + if !strings.ContainsRune(c, '\u4e2d') && len(c) > 0 { + // quick plausibility check — content was pure CJK + } + } + }) +} + +// TestSplitAtByteBoundary verifies the last-resort byte-boundary splitter. +func TestSplitAtByteBoundary(t *testing.T) { + t.Run("ASCII fits in one chunk", func(t *testing.T) { + parts := splitAtByteBoundary("hello world", 100) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d", len(parts)) + } + }) + + t.Run("splits at byte boundary, never mid-rune", func(t *testing.T) { + // 10 CJK characters = 30 bytes; split at 20 bytes. + s := strings.Repeat("\u6587", 10) // 10 × 3 bytes = 30 bytes + parts := splitAtByteBoundary(s, 20) + for i, p := range parts { + if len(p) > 20 { + t.Errorf("part %d has %d bytes, want <= 20", i, len(p)) + } + // Must be valid UTF-8 (no torn multi-byte sequences). + for j, r := range p { + if r == '\uFFFD' { + t.Errorf("part %d has replacement rune at position %d: torn UTF-8", i, j) + } + } + } + }) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index ece2a7dbf..70e89ea9f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -228,25 +228,31 @@ type SubTurnConfig struct { ConcurrencyTimeoutSec int `json:"concurrency_timeout_sec" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_CONCURRENCY_TIMEOUT_SEC"` } +type ToolFeedbackConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_ENABLED"` + MaxArgsLength int `json:"max_args_length" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH"` +} + type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead - ModelFallbacks []string `json:"model_fallbacks,omitempty"` - ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` - ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` - SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` - SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` - MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` - Routing *RoutingConfig `json:"routing,omitempty"` - SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all" - SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + ModelName string `json:"model_name" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Routing *RoutingConfig `json:"routing,omitempty"` + SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all" + SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"` + ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"` } const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB @@ -258,6 +264,19 @@ func (d *AgentDefaults) GetMaxMediaSize() int { return DefaultMaxMediaSize } +// GetToolFeedbackMaxArgsLength returns the max args preview length for tool feedback messages. +func (d *AgentDefaults) GetToolFeedbackMaxArgsLength() int { + if d.ToolFeedback.MaxArgsLength > 0 { + return d.ToolFeedback.MaxArgsLength + } + return 300 +} + +// IsToolFeedbackEnabled returns true when tool feedback messages should be sent to the chat. +func (d *AgentDefaults) IsToolFeedbackEnabled() bool { + return d.ToolFeedback.Enabled +} + // GetModelName returns the effective model name for the agent defaults. // It prefers the new "model_name" field but falls back to "model" for backward compatibility. func (d *AgentDefaults) GetModelName() string { @@ -463,15 +482,17 @@ type WeComAppConfig struct { } type WeComAIBotConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` - MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps - WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` + BotID string `json:"bot_id,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_BOT_ID"` + Secret string `json:"secret,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_SECRET"` + Token string `json:"token,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` + WebhookPath string `json:"webhook_path,omitempty" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` + MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` + WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` } type PicoConfig struct { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 7b38d8463..26ade9c1a 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -36,6 +36,10 @@ func DefaultConfig() *Config { SummarizeMessageThreshold: 20, SummarizeTokenPercent: 75, SteeringMode: "one-at-a-time", + ToolFeedback: ToolFeedbackConfig{ + Enabled: true, + MaxArgsLength: 300, + }, }, }, Bindings: []AgentBinding{},