/skills` (builtin)
+
+For advanced/test setups, you can override the builtin skills root with:
+
+```bash
+export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
+```
+
### 🔒 Security Sandbox
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
@@ -846,7 +939,7 @@ This design also enables **multi-agent support** with flexible provider selectio
#### 📋 All Supported Vendors
| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
-| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- |
+| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- |
| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
@@ -858,6 +951,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) |
| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) |
| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) |
+| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1 | OpenAI | Your LiteLLM proxy key |
| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local |
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) |
| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
@@ -959,6 +1053,19 @@ This design also enables **multi-agent support** with flexible provider selectio
}
```
+**LiteLLM Proxy**
+
+```json
+{
+ "model_name": "lite-gpt4",
+ "model": "litellm/lite-gpt4",
+ "api_base": "http://localhost:4000/v1",
+ "api_key": "sk-..."
+}
+```
+
+PicoClaw strips only the outer `litellm/` prefix before sending the request, so proxy aliases like `litellm/lite-gpt4` send `lite-gpt4`, while `litellm/openai/gpt-4o` sends `openai/gpt-4o`.
+
#### Load Balancing
Configure multiple endpoints for the same model name—PicoClaw will automatically round-robin between them:
diff --git a/README.pt-br.md b/README.pt-br.md
index 61663e363..bfe655770 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -282,7 +282,7 @@ Converse com seu PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom.
| **QQ** | Fácil (AppID + AppSecret) |
| **DingTalk** | Médio (credenciais do app) |
| **LINE** | Médio (credenciais + webhook URL) |
-| **WeCom** | Médio (CorpID + configuração webhook) |
+| **WeCom AI Bot** | Médio (Token + chave AES) |
Telegram (Recomendado)
@@ -450,8 +450,6 @@ picoclaw gateway
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -465,11 +463,13 @@ O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel:
```bash
# Exemplo com ngrok
-ngrok http 18791
+ngrok http 18790
```
Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**.
+> **Nota**: O webhook do LINE é servido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Use um proxy reverso/HTTPS ou túnel (como ngrok) para expor o Gateway de forma segura quando necessário.
+
**4. Executar**
```bash
@@ -478,19 +478,20 @@ picoclaw gateway
> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original.
-> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook.
+> **Docker Compose**: Se você usa Docker Compose, exponha o Gateway (padrão 127.0.0.1:18790) se precisar acessar o webhook LINE externamente, por exemplo `ports: ["18790:18790"]`.
WeCom (WeChat Work)
-O PicoClaw suporta dois tipos de integração WeCom:
+O PicoClaw suporta três tipos de integração WeCom:
-**Opção 1: WeCom Bot (Robô Inteligente)** - Configuração mais fácil, suporta chats em grupo
-**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas
+**Opção 1: WeCom Bot (Robô)** - Configuração mais fácil, suporta chats em grupo
+**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas, somente chat privado
+**Opção 3: WeCom AI Bot (Robô Inteligente)** - Bot IA oficial, respostas em streaming, suporta grupo e privado
-Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para instruções detalhadas.
+Veja o [Guia de Configuração WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) para instruções detalhadas.
**Configuração Rápida - WeCom Bot:**
@@ -509,8 +510,6 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": []
}
@@ -518,6 +517,8 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
}
```
+> **Nota**: O webhook do WeCom Bot é atendido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Use um proxy reverso/HTTPS ou túnel para expor o Gateway em produção.
+
**Configuração Rápida - WeCom App:**
**1. Criar um aplicativo**
@@ -529,7 +530,7 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
**2. Configurar recebimento de mensagens**
* Nos detalhes do aplicativo, clique em "Receber Mensagens" → "Configurar API"
-* Defina a URL como `http://your-server:18792/webhook/wecom-app`
+* Defina a URL como `http://your-server:18790/webhook/wecom-app`
* Gere o **Token** e o **EncodingAESKey**
**3. Configurar**
@@ -544,8 +545,6 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": []
}
@@ -559,7 +558,40 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
picoclaw gateway
```
-> **Nota**: O WeCom App requer a abertura da porta 18792 para callbacks de webhook. Use um proxy reverso para HTTPS em produção.
+> **Nota**: O WeCom App (callbacks de webhook) é servido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Em produção use um proxy reverso HTTPS para expor a porta do Gateway, ou atualize `PICOCLAW_GATEWAY_HOST` para `0.0.0.0` se necessário.
+
+**Configuração Rápida - WeCom AI Bot:**
+
+**1. Criar um AI Bot**
+
+* Acesse o Console de Administração WeCom → Gerenciamento de Aplicativos → AI Bot
+* Configure a URL de callback: `http://your-server:18791/webhook/wecom-aibot`
+* Copie o **Token** e gere o **EncodingAESKey**
+
+**2. Configurar**
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "Olá! Como posso ajudá-lo?"
+ }
+ }
+}
+```
+
+**3. Executar**
+
+```bash
+picoclaw gateway
+```
+
+> **Nota**: O WeCom AI Bot usa protocolo de pull em streaming — sem preocupações com timeout de resposta. Tarefas longas (>5,5 min) alternam automaticamente para entrega via `response_url`.
@@ -573,6 +605,31 @@ Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única men
Arquivo de configuração: `~/.picoclaw/config.json`
+### Variáveis de Ambiente
+
+Você pode substituir os caminhos padrão usando variáveis de ambiente. Isso é útil para instalações portáteis, implantações em contêineres ou para executar o picoclaw como um serviço do sistema. Essas variáveis são independentes e controlam caminhos diferentes.
+
+| Variável | Descrição | Caminho Padrão |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | Substitui o caminho para o arquivo de configuração. Isso informa diretamente ao picoclaw qual `config.json` carregar, ignorando todos os outros locais. | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | Substitui o diretório raiz dos dados do picoclaw. Isso altera o local padrão do `workspace` e de outros diretórios de dados. | `~/.picoclaw` |
+
+**Exemplos:**
+
+```bash
+# Executar o picoclaw usando um arquivo de configuração específico
+# O caminho do workspace será lido de dentro desse arquivo de configuração
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# Executar o picoclaw com todos os seus dados armazenados em /opt/picoclaw
+# A configuração será carregada do ~/.picoclaw/config.json padrão
+# O workspace será criado em /opt/picoclaw/workspace
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# Use ambos para uma configuração totalmente personalizada
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### Estrutura do Workspace
O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`):
diff --git a/README.vi.md b/README.vi.md
index f8ece7eda..b30659614 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -3,7 +3,7 @@
PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go
-Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!
+Phần cứng $10 · RAM 10MB · Khởi động 1 giây · Nào, xuất phát!
@@ -256,7 +256,7 @@ Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk, LINE hoặc WeCom.
| **QQ** | Dễ (AppID + AppSecret) |
| **DingTalk** | Trung bình (app credentials) |
| **LINE** | Trung bình (credentials + webhook URL) |
-| **WeCom** | Trung bình (CorpID + cấu hình webhook) |
+| **WeCom AI Bot** | Trung bình (Token + khóa AES) |
Telegram (Khuyên dùng)
@@ -424,8 +424,6 @@ picoclaw gateway
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -439,7 +437,7 @@ LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel:
```bash
# Ví dụ với ngrok
-ngrok http 18791
+ngrok http 18790
```
Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**.
@@ -452,19 +450,20 @@ picoclaw gateway
> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc.
-> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook.
+> **Docker Compose**: Nếu bạn cần mở port webhook cục bộ, hãy thêm một rule chuyển tiếp từ port Gateway (mặc định 18790) tới host. Lưu ý: LINE webhook được phục vụ bởi Gateway HTTP chung (mặc định 127.0.0.1:18790).
WeCom (WeChat Work)
-PicoClaw hỗ trợ hai loại tích hợp WeCom:
+PicoClaw hỗ trợ ba loại tích hợp WeCom:
-**Tùy chọn 1: WeCom Bot (Robot Thông minh)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm
-**Tùy chọn 2: WeCom App (Ứng dụng Tự xây dựng)** - Nhiều tính năng hơn, nhắn tin chủ động
+**Tùy chọn 1: WeCom Bot (Robot)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm
+**Tùy chọn 2: WeCom App (Ứng dụng Tùy chỉnh)** - Nhiều tính năng hơn, nhắn tin chủ động, chỉ chat riêng tư
+**Tùy chọn 3: WeCom AI Bot (Bot Thông Minh)** - Bot AI chính thức, phản hồi streaming, hỗ trợ nhóm và riêng tư
-Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) để biết hướng dẫn chi tiết.
+Xem [Hướng dẫn Cấu hình WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) để biết hướng dẫn chi tiết.
**Thiết lập Nhanh - WeCom Bot:**
@@ -483,8 +482,6 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": []
}
@@ -492,6 +489,8 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
}
```
+> **Lưu ý:** Các endpoint webhook của WeCom Bot được phục vụ bởi máy chủ Gateway HTTP dùng chung (mặc định 127.0.0.1:18790). Nếu bạn cần truy cập từ bên ngoài, hãy cấu hình reverse proxy hoặc mở cổng Gateway tương ứng.
+
**Thiết lập Nhanh - WeCom App:**
**1. Tạo ứng dụng**
@@ -503,7 +502,7 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
**2. Cấu hình nhận tin nhắn**
* Trong chi tiết ứng dụng, nhấp vào "Nhận Tin nhắn" → "Thiết lập API"
-* Đặt URL thành `http://your-server:18792/webhook/wecom-app`
+* Đặt URL thành `http://your-server:18790/webhook/wecom-app`
* Tạo **Token** và **EncodingAESKey**
**3. Cấu hình**
@@ -518,8 +517,6 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": []
}
@@ -533,7 +530,40 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
picoclaw gateway
```
-> **Lưu ý**: WeCom App yêu cầu mở cổng 18792 cho callback webhook. Sử dụng proxy ngược cho HTTPS trong môi trường sản xuất.
+> **Lưu ý**: WeCom App callback webhook được phục vụ bởi Gateway HTTP chung (mặc định 127.0.0.1:18790). Sử dụng proxy ngược để cung cấp HTTPS trong môi trường production nếu cần.
+
+**Thiết lập Nhanh - WeCom AI Bot:**
+
+**1. Tạo AI Bot**
+
+* Truy cập Bảng điều khiển Quản trị WeCom → Quản lý Ứng dụng → AI Bot
+* Cấu hình URL callback: `http://your-server:18791/webhook/wecom-aibot`
+* Sao chép **Token** và tạo **EncodingAESKey**
+
+**2. Cấu hình**
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "Xin chào! Tôi có thể giúp gì cho bạn?"
+ }
+ }
+}
+```
+
+**3. Chạy**
+
+```bash
+picoclaw gateway
+```
+
+> **Lưu ý**: WeCom AI Bot sử dụng giao thức pull streaming — không lo timeout phản hồi. Tác vụ dài (>5,5 phút) tự động chuyển sang gửi qua `response_url`.
@@ -547,6 +577,31 @@ Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một
File cấu hình: `~/.picoclaw/config.json`
+### Biến môi trường
+
+Bạn có thể ghi đè các đường dẫn mặc định bằng cách sử dụng các biến môi trường. Điều này hữu ích cho việc cài đặt di động, triển khai container hóa hoặc chạy picoclaw như một dịch vụ hệ thống. Các biến này độc lập và kiểm soát các đường dẫn khác nhau.
+
+| Biến | Mô tả | Đường dẫn mặc định |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | Ghi đè đường dẫn đến file cấu hình. Điều này trực tiếp yêu cầu picoclaw tải file `config.json` nào, bỏ qua tất cả các vị trí khác. | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | Ghi đè thư mục gốc cho dữ liệu picoclaw. Điều này thay đổi vị trí mặc định của `workspace` và các thư mục dữ liệu khác. | `~/.picoclaw` |
+
+**Ví dụ:**
+
+```bash
+# Chạy picoclaw bằng một file cấu hình cụ thể
+# Đường dẫn workspace sẽ được đọc từ trong file cấu hình đó
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# Chạy picoclaw với tất cả dữ liệu được lưu trữ trong /opt/picoclaw
+# Cấu hình sẽ được tải từ ~/.picoclaw/config.json mặc định
+# Workspace sẽ được tạo tại /opt/picoclaw/workspace
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# Sử dụng cả hai để có thiết lập tùy chỉnh hoàn toàn
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### Cấu trúc Workspace
PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`):
diff --git a/README.zh.md b/README.zh.md
index 7c9351cb4..db96ba555 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -290,6 +290,8 @@ picoclaw agent -m "2+2 等于几?"
PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方。
+> **注意**: 所有 Webhook 类渠道(LINE、WeCom 等)均挂载在同一个 Gateway HTTP 服务器上(`gateway.host`:`gateway.port`,默认 `127.0.0.1:18790`),无需为每个渠道单独配置端口。注意:飞书(Feishu)使用 WebSocket/SDK 模式,不通过该共享 HTTP webhook 服务器接收消息。
+
### 核心渠道
| 渠道 | 设置难度 | 特性说明 | 文档链接 |
@@ -299,7 +301,7 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方
| **Slack** | ⭐ 简单 | **Socket Mode** (无需公网 IP),企业级支持 | [查看文档](docs/channels/slack/README.zh.md) |
| **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](docs/channels/qq/README.zh.md) |
| **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](docs/channels/dingtalk/README.zh.md) |
-| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)和自建应用(API) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) |
+| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)、自建应用(API)和智能机器人(AI Bot) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) / [AI Bot 文档](docs/channels/wecom/wecom_aibot/README.zh.md) |
| **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](docs/channels/feishu/README.zh.md) |
| **Line** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](docs/channels/line/README.zh.md) |
| **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) |
@@ -315,6 +317,31 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方
配置文件路径: `~/.picoclaw/config.json`
+### 环境变量
+
+你可以使用环境变量覆盖默认路径。这对于便携安装、容器化部署或将 picoclaw 作为系统服务运行非常有用。这些变量是独立的,控制不同的路径。
+
+| 变量 | 描述 | 默认路径 |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | 覆盖配置文件的路径。这直接告诉 picoclaw 加载哪个 `config.json`,忽略所有其他位置。 | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | 覆盖 picoclaw 数据根目录。这会更改 `workspace` 和其他数据目录的默认位置。 | `~/.picoclaw` |
+
+**示例:**
+
+```bash
+# 使用特定的配置文件运行 picoclaw
+# 工作区路径将从该配置文件中读取
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# 在 /opt/picoclaw 中存储所有数据运行 picoclaw
+# 配置将从默认的 ~/.picoclaw/config.json 加载
+# 工作区将在 /opt/picoclaw/workspace 创建
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# 同时使用两者进行完全自定义设置
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### 工作区布局 (Workspace Layout)
PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`):
@@ -335,6 +362,20 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work
```
+### 技能来源 (Skill Sources)
+
+默认情况下,技能会按以下顺序加载:
+
+1. `~/.picoclaw/workspace/skills`(工作区)
+2. `~/.picoclaw/skills`(全局)
+3. `/skills`(内置)
+
+在高级/测试场景下,可通过以下环境变量覆盖内置技能目录:
+
+```bash
+export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
+```
+
### 心跳 / 周期性任务 (Heartbeat)
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
diff --git a/assets/wechat.png b/assets/wechat.png
index 1900c7556..1c0b88295 100644
Binary files a/assets/wechat.png and b/assets/wechat.png differ
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/channel.go b/cmd/picoclaw-launcher-tui/internal/ui/channel.go
index ad9171424..49a6ccc5d 100644
--- a/cmd/picoclaw-launcher-tui/internal/ui/channel.go
+++ b/cmd/picoclaw-launcher-tui/internal/ui/channel.go
@@ -10,8 +10,8 @@ import (
picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
)
-func (s *appState) channelMenu() tview.Primitive {
- items := []MenuItem{
+func (s *appState) buildChannelMenuItems() []MenuItem {
+ return []MenuItem{
{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
channelItem(
"Telegram",
@@ -86,8 +86,10 @@ func (s *appState) channelMenu() tview.Primitive {
func() { s.push("channel-wecomapp", s.wecomAppForm()) },
),
}
+}
- menu := NewMenu("Channels", items)
+func (s *appState) channelMenu() tview.Primitive {
+ menu := NewMenu("Channels", s.buildChannelMenuItems())
menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
if event.Key() == tcell.KeyEsc {
s.pop()
@@ -103,199 +105,72 @@ func (s *appState) channelMenu() tview.Primitive {
}
func refreshChannelMenuFromState(menu *Menu, s *appState) {
- items := []MenuItem{
- {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
- channelItem(
- "Telegram",
- "Telegram bot settings",
- s.config.Channels.Telegram.Enabled,
- func() { s.push("channel-telegram", s.telegramForm()) },
- ),
- channelItem(
- "Discord",
- "Discord bot settings",
- s.config.Channels.Discord.Enabled,
- func() { s.push("channel-discord", s.discordForm()) },
- ),
- channelItem(
- "QQ",
- "QQ bot settings",
- s.config.Channels.QQ.Enabled,
- func() { s.push("channel-qq", s.qqForm()) },
- ),
- channelItem(
- "MaixCam",
- "MaixCam gateway",
- s.config.Channels.MaixCam.Enabled,
- func() { s.push("channel-maixcam", s.maixcamForm()) },
- ),
- channelItem(
- "WhatsApp",
- "WhatsApp bridge",
- s.config.Channels.WhatsApp.Enabled,
- func() { s.push("channel-whatsapp", s.whatsappForm()) },
- ),
- channelItem(
- "Feishu",
- "Feishu bot settings",
- s.config.Channels.Feishu.Enabled,
- func() { s.push("channel-feishu", s.feishuForm()) },
- ),
- channelItem(
- "DingTalk",
- "DingTalk bot settings",
- s.config.Channels.DingTalk.Enabled,
- func() { s.push("channel-dingtalk", s.dingtalkForm()) },
- ),
- channelItem(
- "Slack",
- "Slack bot settings",
- s.config.Channels.Slack.Enabled,
- func() { s.push("channel-slack", s.slackForm()) },
- ),
- channelItem(
- "LINE",
- "LINE bot settings",
- s.config.Channels.LINE.Enabled,
- func() { s.push("channel-line", s.lineForm()) },
- ),
- channelItem(
- "OneBot",
- "OneBot settings",
- s.config.Channels.OneBot.Enabled,
- func() { s.push("channel-onebot", s.onebotForm()) },
- ),
- channelItem(
- "WeCom",
- "WeCom bot settings",
- s.config.Channels.WeCom.Enabled,
- func() { s.push("channel-wecom", s.wecomForm()) },
- ),
- channelItem(
- "WeCom App",
- "WeCom App settings",
- s.config.Channels.WeComApp.Enabled,
- func() { s.push("channel-wecomapp", s.wecomAppForm()) },
- ),
- }
- menu.applyItems(items)
+ menu.applyItems(s.buildChannelMenuItems())
}
func (s *appState) telegramForm() tview.Primitive {
cfg := &s.config.Channels.Telegram
- form := baseChannelForm("Telegram", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) {
cfg.Proxy = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) discordForm() tview.Primitive {
cfg := &s.config.Channels.Discord
- form := baseChannelForm("Discord", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) {
cfg.MentionOnly = checked
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) qqForm() tview.Primitive {
cfg := &s.config.Channels.QQ
- form := baseChannelForm("QQ", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
cfg.AppID = strings.TrimSpace(text)
})
form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) {
cfg.AppSecret = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) maixcamForm() tview.Primitive {
cfg := &s.config.Channels.MaixCam
- form := baseChannelForm("MaixCam", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Host", cfg.Host, 64, nil, func(text string) {
cfg.Host = strings.TrimSpace(text)
})
addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value })
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) whatsappForm() tview.Primitive {
cfg := &s.config.Channels.WhatsApp
- form := baseChannelForm("WhatsApp", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) {
cfg.BridgeURL = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) feishuForm() tview.Primitive {
cfg := &s.config.Channels.Feishu
- form := baseChannelForm("Feishu", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
cfg.AppID = strings.TrimSpace(text)
})
@@ -308,66 +183,39 @@ func (s *appState) feishuForm() tview.Primitive {
form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) {
cfg.VerificationToken = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) dingtalkForm() tview.Primitive {
cfg := &s.config.Channels.DingTalk
- form := baseChannelForm("DingTalk", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) {
cfg.ClientID = strings.TrimSpace(text)
})
form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) {
cfg.ClientSecret = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) slackForm() tview.Primitive {
cfg := &s.config.Channels.Slack
- form := baseChannelForm("Slack", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Bot Token", cfg.BotToken, 128, nil, func(text string) {
cfg.BotToken = strings.TrimSpace(text)
})
form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) {
cfg.AppToken = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) lineForm() tview.Primitive {
cfg := &s.config.Channels.LINE
- form := baseChannelForm("LINE", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) {
cfg.ChannelSecret = strings.TrimSpace(text)
})
@@ -381,22 +229,13 @@ func (s *appState) lineForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) onebotForm() tview.Primitive {
cfg := &s.config.Channels.OneBot
- form := baseChannelForm("OneBot", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) {
cfg.WSUrl = strings.TrimSpace(text)
})
@@ -418,22 +257,13 @@ func (s *appState) onebotForm() tview.Primitive {
cfg.GroupTriggerPrefix = splitCSV(text)
},
)
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
return wrapWithBack(form, s)
}
func (s *appState) wecomForm() tview.Primitive {
cfg := &s.config.Channels.WeCom
- form := baseChannelForm("WeCom", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
cfg.Token = strings.TrimSpace(text)
})
@@ -450,9 +280,7 @@ func (s *appState) wecomForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
addIntField(
form,
"Reply Timeout",
@@ -464,14 +292,7 @@ func (s *appState) wecomForm() tview.Primitive {
func (s *appState) wecomAppForm() tview.Primitive {
cfg := &s.config.Channels.WeComApp
- form := baseChannelForm("WeCom App", cfg.Enabled, func(v bool) {
- cfg.Enabled = v
- s.dirty = true
- refreshMainMenuIfPresent(s)
- if menu, ok := s.menus["channel"]; ok {
- refreshChannelMenuFromState(menu, s)
- }
- })
+ form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) {
cfg.CorpID = strings.TrimSpace(text)
})
@@ -492,9 +313,7 @@ func (s *appState) wecomAppForm() tview.Primitive {
form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
cfg.WebhookPath = strings.TrimSpace(text)
})
- form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) {
- cfg.AllowFrom = splitCSV(text)
- })
+ addAllowFromField(form, &cfg.AllowFrom)
addIntField(
form,
"Reply Timeout",
@@ -504,6 +323,23 @@ func (s *appState) wecomAppForm() tview.Primitive {
return wrapWithBack(form, s)
}
+func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) {
+ return func(v bool) {
+ *enabledPtr = v
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["channel"]; ok {
+ refreshChannelMenuFromState(menu, s)
+ }
+ }
+}
+
+func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) {
+ form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) {
+ *allowFrom = splitCSV(text)
+ })
+}
+
func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form {
form := tview.NewForm()
form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title))
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/style.go b/cmd/picoclaw-launcher-tui/internal/ui/style.go
index ff4f8b1a8..68cdd60b9 100644
--- a/cmd/picoclaw-launcher-tui/internal/ui/style.go
+++ b/cmd/picoclaw-launcher-tui/internal/ui/style.go
@@ -5,6 +5,19 @@ import (
"github.com/rivo/tview"
)
+const (
+ colorBlue = "[#3e5db9]"
+ colorRed = "[#d54646]"
+ banner = "\r\n[::b]" +
+ colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
+ colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
+ colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
+ colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
+ colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
+ colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
+ "[:]"
+)
+
func applyStyles() {
tview.Styles.PrimitiveBackgroundColor = tcell.NewRGBColor(12, 13, 22)
tview.Styles.ContrastBackgroundColor = tcell.NewRGBColor(34, 19, 53)
@@ -24,14 +37,7 @@ func bannerView() *tview.TextView {
text.SetDynamicColors(true)
text.SetTextAlign(tview.AlignCenter)
text.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor)
- text.SetText(
- "[::b][#84aaff]██████╗ ██╗ ██████╗ ██████╗ ██████╗██╗ █████╗ ██╗ ██╗\n" +
- "[#84aaff]██╔══██╗██║██╔════╝██╔═══██╗██╔════╝██║ ██╔══██╗██║ ██║\n" +
- "[#84aaff]██████╔╝██║██║ ██║ ██║██║ ██║ ███████║██║ █╗ ██║\n" +
- "[#84aaff]██╔═══╝ ██║██║ ██║ ██║██║ ██║ ██╔══██║██║███╗██║\n" +
- "[#84aaff]██║ ██║╚██████╗╚██████╔╝╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
- "[#84aaff]╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝",
- )
+ text.SetText(banner)
text.SetBorder(false)
return text
}
diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go
index 1f52df5dd..9655d3c08 100644
--- a/cmd/picoclaw/internal/helpers.go
+++ b/cmd/picoclaw/internal/helpers.go
@@ -19,6 +19,9 @@ var (
)
func GetConfigPath() string {
+ if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" {
+ return configPath
+ }
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "config.json")
}
diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go
index 9342d141d..47e2f8c07 100644
--- a/cmd/picoclaw/internal/helpers_test.go
+++ b/cmd/picoclaw/internal/helpers_test.go
@@ -95,3 +95,13 @@ func TestGetConfigPath_Windows(t *testing.T) {
func TestGetVersion(t *testing.T) {
assert.Equal(t, "dev", GetVersion())
}
+
+func TestGetConfigPath_WithEnv(t *testing.T) {
+ t.Setenv("PICOCLAW_CONFIG", "/tmp/custom/config.json")
+ t.Setenv("HOME", "/tmp/home") // Also set home to ensure env is preferred
+
+ got := GetConfigPath()
+ want := "/tmp/custom/config.json"
+
+ assert.Equal(t, want, got)
+}
diff --git a/cmd/picoclaw/internal/onboard/helpers_test.go b/cmd/picoclaw/internal/onboard/helpers_test.go
new file mode 100644
index 000000000..f3e0c92e0
--- /dev/null
+++ b/cmd/picoclaw/internal/onboard/helpers_test.go
@@ -0,0 +1,25 @@
+package onboard
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestCopyEmbeddedToTargetUsesAgentsMarkdown(t *testing.T) {
+ targetDir := t.TempDir()
+
+ if err := copyEmbeddedToTarget(targetDir); err != nil {
+ t.Fatalf("copyEmbeddedToTarget() error = %v", err)
+ }
+
+ agentsPath := filepath.Join(targetDir, "AGENTS.md")
+ if _, err := os.Stat(agentsPath); err != nil {
+ t.Fatalf("expected %s to exist: %v", agentsPath, err)
+ }
+
+ legacyPath := filepath.Join(targetDir, "AGENT.md")
+ if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
+ t.Fatalf("expected legacy file %s to be absent, got err=%v", legacyPath, err)
+ }
+}
diff --git a/cmd/picoclaw/internal/skills/command.go b/cmd/picoclaw/internal/skills/command.go
index 7f8bd011d..65eb127b9 100644
--- a/cmd/picoclaw/internal/skills/command.go
+++ b/cmd/picoclaw/internal/skills/command.go
@@ -71,7 +71,7 @@ func NewSkillsCommand() *cobra.Command {
newInstallBuiltinCommand(workspaceFn),
newListBuiltinCommand(),
newRemoveCommand(installerFn),
- newSearchCommand(installerFn),
+ newSearchCommand(),
newShowCommand(loaderFn),
)
diff --git a/cmd/picoclaw/internal/skills/helpers.go b/cmd/picoclaw/internal/skills/helpers.go
index 439b81a4f..a59a2013a 100644
--- a/cmd/picoclaw/internal/skills/helpers.go
+++ b/cmd/picoclaw/internal/skills/helpers.go
@@ -15,6 +15,8 @@ import (
"github.com/sipeed/picoclaw/pkg/utils"
)
+const skillsSearchMaxResults = 20
+
func skillsListCmd(loader *skills.SkillsLoader) {
allSkills := loader.ListSkills()
@@ -215,34 +217,43 @@ func skillsListBuiltinCmd() {
}
}
-func skillsSearchCmd(installer *skills.SkillInstaller) {
+func skillsSearchCmd(query string) {
fmt.Println("Searching for available skills...")
+ cfg, err := internal.LoadConfig()
+ if err != nil {
+ fmt.Printf("✗ Failed to load config: %v\n", err)
+ return
+ }
+
+ registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
+ MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
+ ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
+ })
+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
- availableSkills, err := installer.ListAvailableSkills(ctx)
+ results, err := registryMgr.SearchAll(ctx, query, skillsSearchMaxResults)
if err != nil {
fmt.Printf("✗ Failed to fetch skills list: %v\n", err)
return
}
- if len(availableSkills) == 0 {
+ if len(results) == 0 {
fmt.Println("No skills available.")
return
}
- fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills))
+ fmt.Printf("\nAvailable Skills (%d):\n", len(results))
fmt.Println("--------------------")
- for _, skill := range availableSkills {
- fmt.Printf(" 📦 %s\n", skill.Name)
- fmt.Printf(" %s\n", skill.Description)
- fmt.Printf(" Repo: %s\n", skill.Repository)
- if skill.Author != "" {
- fmt.Printf(" Author: %s\n", skill.Author)
- }
- if len(skill.Tags) > 0 {
- fmt.Printf(" Tags: %v\n", skill.Tags)
+ for _, result := range results {
+ fmt.Printf(" 📦 %s\n", result.DisplayName)
+ fmt.Printf(" %s\n", result.Summary)
+ fmt.Printf(" Slug: %s\n", result.Slug)
+ fmt.Printf(" Registry: %s\n", result.RegistryName)
+ if result.Version != "" {
+ fmt.Printf(" Version: %s\n", result.Version)
}
fmt.Println()
}
diff --git a/cmd/picoclaw/internal/skills/search.go b/cmd/picoclaw/internal/skills/search.go
index 53bc99109..54f72259f 100644
--- a/cmd/picoclaw/internal/skills/search.go
+++ b/cmd/picoclaw/internal/skills/search.go
@@ -2,20 +2,19 @@ package skills
import (
"github.com/spf13/cobra"
-
- "github.com/sipeed/picoclaw/pkg/skills"
)
-func newSearchCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command {
+func newSearchCommand() *cobra.Command {
cmd := &cobra.Command{
- Use: "search",
+ Use: "search [query]",
Short: "Search available skills",
- RunE: func(_ *cobra.Command, _ []string) error {
- installer, err := installerFn()
- if err != nil {
- return err
+ Args: cobra.MaximumNArgs(1),
+ RunE: func(_ *cobra.Command, args []string) error {
+ query := ""
+ if len(args) == 1 {
+ query = args[0]
}
- skillsSearchCmd(installer)
+ skillsSearchCmd(query)
return nil
},
}
diff --git a/cmd/picoclaw/internal/skills/search_test.go b/cmd/picoclaw/internal/skills/search_test.go
index 19f63a9ff..ed92e25cc 100644
--- a/cmd/picoclaw/internal/skills/search_test.go
+++ b/cmd/picoclaw/internal/skills/search_test.go
@@ -8,11 +8,11 @@ import (
)
func TestNewSearchSubcommand(t *testing.T) {
- cmd := newSearchCommand(nil)
+ cmd := newSearchCommand()
require.NotNil(t, cmd)
- assert.Equal(t, "search", cmd.Use)
+ assert.Equal(t, "search [query]", cmd.Use)
assert.Equal(t, "Search available skills", cmd.Short)
assert.Nil(t, cmd.Run)
diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go
index 6db69c990..d9263462e 100644
--- a/cmd/picoclaw/main.go
+++ b/cmd/picoclaw/main.go
@@ -48,7 +48,21 @@ func NewPicoclawCommand() *cobra.Command {
return cmd
}
+const (
+ colorBlue = "\033[1;38;2;62;93;185m"
+ colorRed = "\033[1;38;2;213;70;70m"
+ banner = "\r\n" +
+ colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
+ colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
+ colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
+ colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
+ colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
+ colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
+ "\033[0m\r\n"
+)
+
func main() {
+ fmt.Printf("%s", banner)
cmd := NewPicoclawCommand()
if err := cmd.Execute(); err != nil {
os.Exit(1)
diff --git a/config/config.example.json b/config/config.example.json
index 55a823009..3c84cfa9f 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -49,6 +49,7 @@
"telegram": {
"enabled": false,
"token": "YOUR_TELEGRAM_BOT_TOKEN",
+ "base_url": "",
"proxy": "",
"allow_from": [
"YOUR_USER_ID"
@@ -59,7 +60,9 @@
"enabled": false,
"token": "YOUR_DISCORD_BOT_TOKEN",
"allow_from": [],
- "mention_only": false,
+ "group_trigger": {
+ "mention_only": false
+ },
"reasoning_channel_id": ""
},
"qq": {
@@ -111,8 +114,6 @@
"enabled": false,
"channel_secret": "YOUR_LINE_CHANNEL_SECRET",
"channel_access_token": "YOUR_LINE_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": [],
"reasoning_channel_id": ""
@@ -127,32 +128,38 @@
"reasoning_channel_id": ""
},
"wecom": {
- "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats",
+ "_comment": "WeCom Bot - Easier setup, supports group chats",
"enabled": false,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": [],
"reply_timeout": 5,
"reasoning_channel_id": ""
},
"wecom_app": {
- "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md",
+ "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.",
"enabled": false,
"corp_id": "YOUR_CORP_ID",
"corp_secret": "YOUR_CORP_SECRET",
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": [],
"reply_timeout": 5,
"reasoning_channel_id": ""
+ },
+ "wecom_aibot": {
+ "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.",
+ "enabled": false,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "max_steps": 10,
+ "welcome_message": "Hello! I'm your AI assistant. How can I help you today?",
+ "reasoning_channel_id": ""
}
},
"providers": {
@@ -237,6 +244,71 @@
"cron": {
"exec_timeout_minutes": 5
},
+ "mcp": {
+ "enabled": false,
+ "servers": {
+ "context7": {
+ "enabled": false,
+ "type": "http",
+ "url": "https://mcp.context7.com/mcp",
+ "headers": {
+ "CONTEXT7_API_KEY": "ctx7sk-xx"
+ }
+ },
+ "filesystem": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-filesystem",
+ "/tmp"
+ ]
+ },
+ "github": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-github"
+ ],
+ "env": {
+ "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
+ }
+ },
+ "brave-search": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-brave-search"
+ ],
+ "env": {
+ "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
+ }
+ },
+ "postgres": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-postgres",
+ "postgresql://user:password@localhost/dbname"
+ ]
+ },
+ "slack": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-slack"
+ ],
+ "env": {
+ "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN",
+ "SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID"
+ }
+ }
+ }
+ },
"exec": {
"enable_deny_patterns": false,
"custom_deny_patterns": []
@@ -265,4 +337,4 @@
"host": "127.0.0.1",
"port": 18790
}
-}
+}
\ No newline at end of file
diff --git a/docker/Dockerfile.full b/docker/Dockerfile.full
new file mode 100644
index 000000000..30e1680d5
--- /dev/null
+++ b/docker/Dockerfile.full
@@ -0,0 +1,44 @@
+# ============================================================
+# Stage 1: Build the picoclaw binary
+# ============================================================
+FROM golang:1.26.0-alpine AS builder
+
+RUN apk add --no-cache git make
+
+WORKDIR /src
+
+# Cache dependencies
+COPY go.mod go.sum ./
+RUN go mod download
+
+# Copy source and build
+COPY . .
+RUN make build
+
+# ============================================================
+# Stage 2: Node.js-based runtime with full MCP support
+# ============================================================
+FROM node:24-alpine3.23
+
+# Install runtime dependencies
+RUN apk add --no-cache \
+ ca-certificates \
+ curl \
+ git \
+ python3 \
+ py3-pip
+
+# Install uv and symlink to system path
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+ ln -s /root/.local/bin/uv /usr/local/bin/uv && \
+ ln -s /root/.local/bin/uvx /usr/local/bin/uvx && \
+ uv --version
+
+# Copy binary
+COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
+
+# Create picoclaw home directory
+RUN /usr/local/bin/picoclaw onboard
+
+ENTRYPOINT ["picoclaw"]
+CMD ["gateway"]
diff --git a/docker/docker-compose.full.yml b/docker/docker-compose.full.yml
new file mode 100644
index 000000000..6f34448c4
--- /dev/null
+++ b/docker/docker-compose.full.yml
@@ -0,0 +1,44 @@
+services:
+ # ─────────────────────────────────────────────
+ # PicoClaw Agent (one-shot query) - Full MCP Support
+ # docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent -m "Hello"
+ # ─────────────────────────────────────────────
+ picoclaw-agent:
+ build:
+ context: ..
+ dockerfile: docker/Dockerfile.full
+ container_name: picoclaw-agent-full
+ profiles:
+ - agent
+ volumes:
+ - ../config/config.json:/root/.picoclaw/config.json:ro
+ - picoclaw-workspace:/root/.picoclaw/workspace
+ - picoclaw-npm-cache:/root/.npm # npm cache for faster MCP server installs
+ entrypoint: ["picoclaw", "agent"]
+ stdin_open: true
+ tty: true
+
+ # ─────────────────────────────────────────────
+ # PicoClaw Gateway (Long-running Bot) - Full MCP Support
+ # docker compose -f docker/docker-compose.full.yml --profile gateway up
+ # ─────────────────────────────────────────────
+ picoclaw-gateway:
+ build:
+ context: ..
+ dockerfile: docker/Dockerfile.full
+ container_name: picoclaw-gateway-full
+ restart: unless-stopped
+ profiles:
+ - gateway
+ volumes:
+ # Configuration file
+ - ../config/config.json:/root/.picoclaw/config.json:ro
+ # Persistent workspace (sessions, memory, logs)
+ - picoclaw-workspace:/root/.picoclaw/workspace
+ # NPM cache for faster MCP server installs
+ - picoclaw-npm-cache:/root/.npm
+ command: ["gateway"]
+
+volumes:
+ picoclaw-workspace:
+ picoclaw-npm-cache: # Cache npm packages to speed up MCP server installations
diff --git a/docs/channels/discord/README.zh.md b/docs/channels/discord/README.zh.md
index 5b597eced..6d3c502cf 100644
--- a/docs/channels/discord/README.zh.md
+++ b/docs/channels/discord/README.zh.md
@@ -11,7 +11,9 @@ Discord 是一个专为社区设计的免费语音、视频和文本聊天应用
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allow_from": ["YOUR_USER_ID"],
- "mention_only": false
+ "group_trigger": {
+ "mention_only": false
+ }
}
}
}
@@ -22,7 +24,7 @@ Discord 是一个专为社区设计的免费语音、视频和文本聊天应用
| enabled | bool | 是 | 是否启用 Discord 频道 |
| token | string | 是 | Discord 机器人 Token |
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
-| mention_only | bool | 否 | 是否仅响应提及机器人的消息 |
+| group_trigger | object | 否 | 群组触发设置(示例: { "mention_only": false }) |
## 设置流程
diff --git a/docs/channels/line/README.zh.md b/docs/channels/line/README.zh.md
index fd3aa80da..a36f622c2 100644
--- a/docs/channels/line/README.zh.md
+++ b/docs/channels/line/README.zh.md
@@ -11,8 +11,6 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -25,9 +23,7 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
| enabled | bool | 是 | 是否启用 LINE Channel |
| channel_secret | string | 是 | LINE Messaging API 的 Channel Secret |
| channel_access_token | string | 是 | LINE Messaging API 的 Channel Access Token |
-| webhook_host | string | 是 | Webhook 监听的主机地址 (通常为 0.0.0.0) |
-| webhook_port | int | 是 | Webhook 监听的端口 (默认为 18791) |
-| webhook_path | string | 是 | Webhook 的路径 (默认为 /webhook/line) |
+| webhook_path | string | 否 | Webhook 的路径 (默认为 /webhook/line) |
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
## 设置流程
@@ -35,7 +31,8 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
1. 前往 [LINE Developers Console](https://developers.line.biz/console/) 创建一个服务提供商和一个 Messaging API Channel
2. 获取 Channel Secret 和 Channel Access Token
3. 配置Webhook:
- - Line要求Webhook必须使用HTTPS协议,因此需要部署一个支持HTTPS的服务器,或者使用反向代理工具如ngrok将本地服务器暴露到公网
- - 将 Webhook URL 设置为 `https://your-domain.com/webhook/line`
+ - LINE 要求 Webhook 必须使用 HTTPS 协议,因此需要部署一个支持 HTTPS 的服务器,或者使用反向代理工具如 ngrok 将本地服务器暴露到公网
+ - PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790
+ - 将 Webhook URL 设置为 `https://your-domain.com/webhook/line`,然后将外部域名反向代理到本机的 Gateway(默认端口 18790)
- 启用 Webhook 并验证 URL
4. 将 Channel Secret 和 Channel Access Token 填入配置文件中
diff --git a/docs/channels/wecom/wecom_aibot/README.zh.md b/docs/channels/wecom/wecom_aibot/README.zh.md
new file mode 100644
index 000000000..d210528af
--- /dev/null
+++ b/docs/channels/wecom/wecom_aibot/README.zh.md
@@ -0,0 +1,116 @@
+# 企业微信智能机器人 (AI Bot)
+
+企业微信智能机器人(AI Bot)是企业微信官方提供的 AI 对话接入方式,支持私聊与群聊,内置流式响应协议,并支持超时后通过 `response_url` 主动推送最终回复。
+
+## 与其他 WeCom 通道的对比
+
+| 特性 | WeCom Bot | WeCom App | **WeCom AI Bot** |
+|------|-----------|-----------|-----------------|
+| 私聊 | ✅ | ✅ | ✅ |
+| 群聊 | ✅ | ❌ | ✅ |
+| 流式输出 | ❌ | ❌ | ✅ |
+| 超时主动推送 | ❌ | ✅ | ✅ |
+| 配置复杂度 | 低 | 高 | 中 |
+
+## 配置
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "你好!有什么可以帮助你的吗?",
+ "max_steps": 10
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------------- | ------ | ---- | -------------------------------------------------- |
+| token | string | 是 | 回调验证令牌,在 AI Bot 管理页面配置 |
+| encoding_aes_key | string | 是 | 43 字符 AES 密钥,在 AI Bot 管理页面随机生成 |
+| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-aibot) |
+| allow_from | array | 否 | 用户 ID 白名单,空数组表示允许所有用户 |
+| welcome_message | string | 否 | 用户进入聊天时发送的欢迎语,留空则不发送 |
+| reply_timeout | int | 否 | 回复超时时间(秒,默认:5) |
+| max_steps | int | 否 | Agent 最大执行步骤数(默认:10) |
+
+## 设置流程
+
+1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin)
+2. 进入"应用管理" → "智能机器人",创建或选择一个 AI Bot
+3. 在 AI Bot 配置页面,填写"消息接收"信息:
+ - **URL**:`http://:18791/webhook/wecom-aibot`
+ - **Token**:随机生成或自定义
+ - **EncodingAESKey**:点击"随机生成",得到 43 字符密钥
+4. 将 Token 和 EncodingAESKey 填入 PicoClaw 配置文件,启动服务后回到管理后台保存(企业微信会发送验证请求)
+
+> [!TIP]
+> 服务器需要能被企业微信服务器访问。如在内网/本地开发,可使用 [ngrok](https://ngrok.com) 或 frp 做内网穿透。
+
+## 流式响应协议
+
+WeCom AI Bot 使用"流式拉取"协议,区别于普通 Webhook 的一次性回复:
+
+```
+用户发消息
+ │
+ ▼
+PicoClaw 立即返回 {finish: false}(Agent 开始处理)
+ │
+ ▼
+企业微信每隔约 1 秒拉取一次 {msgtype: "stream", stream: {id: "..."}}
+ │
+ ├─ Agent 未完成 → 返回 {finish: false}(继续等待)
+ │
+ └─ Agent 完成 → 返回 {finish: true, content: "回答内容"}
+```
+
+**超时处理**(任务超过 30 秒):
+
+若 Agent 处理时间超过约 30 秒(企业微信最大轮询窗口为 6 分钟),PicoClaw 会:
+
+1. 立即关闭流,向用户显示「⏳ 正在处理中,请稍候,结果将稍后发送。」
+2. Agent 继续在后台运行
+3. Agent 完成后,通过消息中携带的 `response_url` 将最终回复主动推送给用户
+
+> `response_url` 由企业微信颁发,有效期 1 小时,只可使用一次,无需加密,直接 POST markdown 消息体即可。
+
+## 欢迎语
+
+配置 `welcome_message` 后,当用户打开与 AI Bot 的聊天窗口时(`enter_chat` 事件),PicoClaw 会自动回复该欢迎语。留空则静默忽略。
+
+```json
+"welcome_message": "你好!我是 PicoClaw AI 助手,有什么可以帮你?"
+```
+
+## 常见问题
+
+### 回调 URL 验证失败
+
+- 确认服务器防火墙已开放对应端口(默认 18791)
+- 确认 `token` 与 `encoding_aes_key` 填写正确
+- 检查 PicoClaw 日志是否收到了来自企业微信的 GET 请求
+
+### 消息没有回复
+
+- 检查 `allow_from` 是否意外限制了发送者
+- 查看日志中是否出现 `context canceled` 或 Agent 错误
+- 确认 Agent 配置(`model_name` 等)正确
+
+### 超长任务没有收到最终推送
+
+- 确认消息回调中携带了 `response_url`(仅企业微信新版 AI Bot 支持)
+- 确认服务器能主动访问外网(需向 `response_url` POST 请求)
+- 查看日志关键词 `response_url mode` 和 `Sending reply via response_url`
+
+## 参考文档
+
+- [企业微信 AI Bot 接入文档](https://developer.work.weixin.qq.com/document/path/100719)
+- [流式响应协议说明](https://developer.work.weixin.qq.com/document/path/100719)
+- [response_url 主动回复](https://developer.work.weixin.qq.com/document/path/101138)
diff --git a/docs/channels/wecom/wecom_app/README.zh.md b/docs/channels/wecom/wecom_app/README.zh.md
index 1e6a0e2b3..0a9858107 100644
--- a/docs/channels/wecom/wecom_app/README.zh.md
+++ b/docs/channels/wecom/wecom_app/README.zh.md
@@ -14,8 +14,6 @@
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": [],
"reply_timeout": 5
@@ -31,8 +29,6 @@
| agent_id | int | 是 | 应用程序代理 ID |
| token | string | 是 | 回调验证令牌 |
| encoding_aes_key | string | 是 | 43 字符 AES 密钥 |
-| webhook_host | string | 否 | HTTP 服务器绑定地址 |
-| webhook_port | int | 否 | HTTP 服务器端口(默认:18792) |
| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-app) |
| allow_from | array | 否 | 用户 ID 白名单 |
| reply_timeout | int | 否 | 回复超时时间(秒) |
@@ -45,3 +41,5 @@
4. 在应用设置中配置“接收消息”,获取 Token 和 EncodingAESKey
5. 设置回调 URL 为 `http://:/webhook/wecom-app`
6. 将 CorpID, Secret, AgentID 等信息填入配置文件
+
+ 注意: PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790。如需从公网接收回调,请把外部域名反向代理到 Gateway(默认端口 18790)。
diff --git a/docs/channels/wecom/wecom_bot/README.zh.md b/docs/channels/wecom/wecom_bot/README.zh.md
index c4bb1c87e..63d9b84d6 100644
--- a/docs/channels/wecom/wecom_bot/README.zh.md
+++ b/docs/channels/wecom/wecom_bot/README.zh.md
@@ -12,8 +12,6 @@
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": [],
"reply_timeout": 5
@@ -27,8 +25,6 @@
| token | string | 是 | 签名验证代币 |
| encoding_aes_key | string | 是 | 用于解密的 43 字符 AES 密钥 |
| webhook_url | string | 是 | 用于发送回复的企业微信群聊机器人 Webhook URL |
-| webhook_host | string | 否 | HTTP 服务器绑定地址(默认:0.0.0.0) |
-| webhook_port | int | 否 | HTTP 服务器端口(默认:18793) |
| webhook_path | string | 否 | Webhook 端点路径(默认:/webhook/wecom) |
| allow_from | array | 否 | 用户 ID 白名单(空值 = 允许所有用户) |
| reply_timeout | int | 否 | 回复超时时间(单位:秒,默认值:5) |
@@ -39,3 +35,5 @@
2. 获取 Webhook URL
3. (如需接收消息) 在机器人配置页面设置接收消息的 API 地址(回调地址)以及 Token 和 EncodingAESKey
4. 将相关信息填入配置文件
+
+ 注意: PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790。如需从公网接收回调,请把外部域名反向代理到 Gateway(默认端口 18790)。
diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md
index 8aba1aa91..6204fb0c8 100644
--- a/docs/tools_configuration.md
+++ b/docs/tools_configuration.md
@@ -8,6 +8,7 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`.
{
"tools": {
"web": { ... },
+ "mcp": { ... },
"exec": { ... },
"cron": { ... },
"skills": { ... }
@@ -21,35 +22,35 @@ Web tools are used for web search and fetching.
### Brave
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | false | Enable Brave search |
-| `api_key` | string | - | Brave Search API key |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ------ | ------- | ------------------------- |
+| `enabled` | bool | false | Enable Brave search |
+| `api_key` | string | - | Brave Search API key |
+| `max_results` | int | 5 | Maximum number of results |
### DuckDuckGo
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | true | Enable DuckDuckGo search |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ---- | ------- | ------------------------- |
+| `enabled` | bool | true | Enable DuckDuckGo search |
+| `max_results` | int | 5 | Maximum number of results |
### Perplexity
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | false | Enable Perplexity search |
-| `api_key` | string | - | Perplexity API key |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ------ | ------- | ------------------------- |
+| `enabled` | bool | false | Enable Perplexity search |
+| `api_key` | string | - | Perplexity API key |
+| `max_results` | int | 5 | Maximum number of results |
## Exec Tool
The exec tool is used to execute shell commands.
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
-| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
+| Config | Type | Default | Description |
+| ---------------------- | ----- | ------- | ------------------------------------------ |
+| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
+| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
### Functionality
@@ -80,10 +81,7 @@ By default, PicoClaw blocks the following dangerous commands:
"tools": {
"exec": {
"enable_deny_patterns": true,
- "custom_deny_patterns": [
- "\\brm\\s+-r\\b",
- "\\bkillall\\s+python"
- ]
+ "custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"]
}
}
}
@@ -93,9 +91,84 @@ By default, PicoClaw blocks the following dangerous commands:
The cron tool is used for scheduling periodic tasks.
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
+| Config | Type | Default | Description |
+| ---------------------- | ---- | ------- | ---------------------------------------------- |
+| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
+
+## MCP Tool
+
+The MCP tool enables integration with external Model Context Protocol servers.
+
+### Global Config
+
+| Config | Type | Default | Description |
+| --------- | ------ | ------- | ----------------------------------- |
+| `enabled` | bool | false | Enable MCP integration globally |
+| `servers` | object | `{}` | Map of server name to server config |
+
+### Per-Server Config
+
+| Config | Type | Required | Description |
+| ---------- | ------ | -------- | ------------------------------------------ |
+| `enabled` | bool | yes | Enable this MCP server |
+| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
+| `command` | string | stdio | Executable command for stdio transport |
+| `args` | array | no | Command arguments for stdio transport |
+| `env` | object | no | Environment variables for stdio process |
+| `env_file` | string | no | Path to environment file for stdio process |
+| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport |
+| `headers` | object | no | HTTP headers for `sse`/`http` transport |
+
+### Transport Behavior
+
+- If `type` is omitted, transport is auto-detected:
+ - `url` is set → `sse`
+ - `command` is set → `stdio`
+- `http` and `sse` both use `url` + optional `headers`.
+- `env` and `env_file` are only applied to `stdio` servers.
+
+### Configuration Examples
+
+#### 1) Stdio MCP server
+
+```json
+{
+ "tools": {
+ "mcp": {
+ "enabled": true,
+ "servers": {
+ "filesystem": {
+ "enabled": true,
+ "command": "npx",
+ "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
+ }
+ }
+ }
+ }
+}
+```
+
+#### 2) Remote SSE/HTTP MCP server
+
+```json
+{
+ "tools": {
+ "mcp": {
+ "enabled": true,
+ "servers": {
+ "remote-mcp": {
+ "enabled": true,
+ "type": "sse",
+ "url": "https://example.com/mcp",
+ "headers": {
+ "Authorization": "Bearer YOUR_TOKEN"
+ }
+ }
+ }
+ }
+ }
+}
+```
## Skills Tool
@@ -103,13 +176,13 @@ The skills tool configures skill discovery and installation via registries like
### Registries
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
-| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
-| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
-| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
-| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
+| Config | Type | Default | Description |
+| ---------------------------------- | ------ | -------------------- | ----------------------- |
+| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
+| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
+| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
+| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
+| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
### Configuration Example
@@ -136,8 +209,10 @@ The skills tool configures skill discovery and installation via registries like
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS__`:
For example:
+
- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true`
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
+- `PICOCLAW_TOOLS_MCP_ENABLED=true`
-Note: Array-type environment variables are not currently supported and must be set via the config file.
+Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than environment variables.
diff --git a/docs/wecom-app-configuration.md b/docs/wecom-app-configuration.md
deleted file mode 100644
index 3b17d37a7..000000000
--- a/docs/wecom-app-configuration.md
+++ /dev/null
@@ -1,117 +0,0 @@
-# 企业微信自建应用 (WeCom App) 配置指南
-
-本文档介绍如何在 PicoClaw 中配置企业微信自建应用 (wecom-app) 通道。
-
-## 功能特性
-
-| 功能 | 支持状态 |
-|------|---------|
-| 被动接收消息 | ✅ |
-| 主动发送消息 | ✅ |
-| 私聊 | ✅ |
-| 群聊 | ❌ |
-
-## 配置步骤
-
-### 1. 企业微信后台配置
-
-1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin)
-2. 进入"应用管理" → 选择自建应用
-3. 记录以下信息:
- - **AgentId**: 应用详情页显示
- - **Secret**: 点击"查看"获取
-4. 进入"我的企业"页面,记录 **企业ID** (CorpID)
-
-### 2. 接收消息配置
-
-1. 在应用详情页,点击"接收消息"的"设置API接收"
-2. 填写以下信息:
- - **URL**: `http://your-server:18792/webhook/wecom-app`
- - **Token**: 随机生成或自定义(用于签名验证)
- - **EncodingAESKey**: 点击"随机生成"生成43字符的密钥
-3. 点击"保存"时,企业微信会发送验证请求
-
-### 3. PicoClaw 配置
-
-在 `config.json` 中添加以下配置:
-
-```json
-{
- "channels": {
- "wecom_app": {
- "enabled": true,
- "corp_id": "wwxxxxxxxxxxxxxxxx", // 企业ID
- "corp_secret": "xxxxxxxxxxxxxxxxxxxxxxxx", // 应用Secret
- "agent_id": 1000002, // 应用AgentId
- "token": "your_token", // 接收消息配置的Token
- "encoding_aes_key": "your_encoding_aes_key", // 接收消息配置的EncodingAESKey
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
- "webhook_path": "/webhook/wecom-app",
- "allow_from": [],
- "reply_timeout": 5
- }
- }
-}
-```
-
-## 常见问题
-
-### 1. 回调URL验证失败
-
-**症状**: 企业微信保存API接收消息时提示验证失败
-
-**检查项**:
-- 确认服务器防火墙已开放 18792 端口
-- 确认 `corp_id`、`token`、`encoding_aes_key` 配置正确
-- 查看 PicoClaw 日志是否有请求到达
-
-### 2. 中文消息解密失败
-
-**症状**: 发送中文消息时出现 `invalid padding size` 错误
-
-**原因**: 企业微信使用非标准的 PKCS7 填充(32字节块大小)
-
-**解决**: 确保使用最新版本的 PicoClaw,已修复此问题。
-
-### 3. 端口冲突
-
-**症状**: 启动时提示端口已被占用
-
-**解决**: 修改 `webhook_port` 为其他端口,如 18794
-
-## 技术细节
-
-### 加密算法
-
-- **算法**: AES-256-CBC
-- **密钥**: EncodingAESKey Base64解码后的32字节
-- **IV**: AESKey的前16字节
-- **填充**: PKCS7(块大小为32字节,非标准16字节)
-- **消息格式**: XML
-
-### 消息结构
-
-解密后的消息格式:
-```
-random(16B) + msg_len(4B) + msg + receiveid
-```
-
-其中 `receiveid` 对于自建应用是 `corp_id`。
-
-## 调试
-
-启用调试模式查看详细日志:
-
-```bash
-picoclaw gateway --debug
-```
-
-关键日志标识:
-- `wecom_app`: WeCom App 通道相关日志
-- `wecom_common`: 加密解密相关日志
-
-## 参考文档
-
-- [企业微信官方文档 - 接收消息](https://developer.work.weixin.qq.com/document/path/96211)
-- [企业微信官方加解密库](https://github.com/sbzhu/weworkapi_golang)
diff --git a/go.mod b/go.mod
index 7892cade6..c1172937c 100644
--- a/go.mod
+++ b/go.mod
@@ -8,13 +8,16 @@ require (
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
+ github.com/gdamore/tcell/v2 v2.13.8
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mdp/qrterminal/v3 v3.2.1
+ github.com/modelcontextprotocol/go-sdk v1.3.0
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
+ github.com/rivo/tview v0.42.0
github.com/slack-go/slack v0.17.3
github.com/spf13/cobra v1.10.2
github.com/stretchr/testify v1.11.1
@@ -35,6 +38,7 @@ require (
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/gdamore/tcell/v2 v2.13.8 // indirect
+ github.com/h2non/filetype v1.1.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
@@ -43,7 +47,6 @@ require (
github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
- github.com/rivo/tview v0.42.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
@@ -81,6 +84,7 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
+ github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
diff --git a/go.sum b/go.sum
index d1ee1d629..060594d06 100644
--- a/go.sum
+++ b/go.sum
@@ -66,6 +66,8 @@ github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncV
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
+github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
+github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
@@ -96,6 +98,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
+github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg=
+github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
@@ -130,6 +134,8 @@ github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
+github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs=
+github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
@@ -212,6 +218,8 @@ github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTd
github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
diff --git a/pkg/agent/context.go b/pkg/agent/context.go
index b7c6e1108..3aa903b3f 100644
--- a/pkg/agent/context.go
+++ b/pkg/agent/context.go
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strings"
"sync"
"time"
@@ -33,6 +34,11 @@ type ContextBuilder struct {
// created (didn't exist at cache time, now exist) or deleted (existed at
// cache time, now gone) — both of which should trigger a cache rebuild.
existedAtCache map[string]bool
+
+ // skillFilesAtCache snapshots the skill tree file set and mtimes at cache
+ // build time. This catches nested file creations/deletions/mtime changes
+ // that may not update the top-level skill root directory mtime.
+ skillFilesAtCache map[string]time.Time
}
func getGlobalConfigDir() string {
@@ -46,8 +52,11 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
- wd, _ := os.Getwd()
- builtinSkillsDir := filepath.Join(wd, "skills")
+ builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
+ if builtinSkillsDir == "" {
+ wd, _ := os.Getwd()
+ builtinSkillsDir = filepath.Join(wd, "skills")
+ }
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
return &ContextBuilder{
@@ -147,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
cb.cachedSystemPrompt = prompt
cb.cachedAt = baseline.maxMtime
cb.existedAtCache = baseline.existed
+ cb.skillFilesAtCache = baseline.skillFiles
logger.DebugCF("agent", "System prompt cached",
map[string]any{
@@ -166,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
cb.cachedSystemPrompt = ""
cb.cachedAt = time.Time{}
cb.existedAtCache = nil
+ cb.skillFilesAtCache = nil
logger.DebugCF("agent", "System prompt cache invalidated", nil)
}
-// sourcePaths returns the workspace source file paths tracked for cache
-// invalidation (bootstrap files + memory). The skills directory is handled
-// separately in sourceFilesChangedLocked because it requires both directory-
-// level and recursive file-level mtime checks.
+// sourcePaths returns non-skill workspace source files tracked for cache
+// invalidation (bootstrap files + memory). Skill roots are handled separately
+// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
return []string{
filepath.Join(cb.workspace, "AGENTS.md"),
@@ -184,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
}
}
+// skillRoots returns all skill root directories that can affect
+// BuildSkillsSummary output (workspace/global/builtin).
+func (cb *ContextBuilder) skillRoots() []string {
+ if cb.skillsLoader == nil {
+ return []string{filepath.Join(cb.workspace, "skills")}
+ }
+
+ roots := cb.skillsLoader.SkillRoots()
+ if len(roots) == 0 {
+ return []string{filepath.Join(cb.workspace, "skills")}
+ }
+ return roots
+}
+
// cacheBaseline holds the file existence snapshot and the latest observed
// mtime across all tracked paths. Used as the cache reference point.
type cacheBaseline struct {
- existed map[string]bool
- maxMtime time.Time
+ existed map[string]bool
+ skillFiles map[string]time.Time
+ maxMtime time.Time
}
// buildCacheBaseline records which tracked paths currently exist and computes
// the latest mtime across all tracked files + skills directory contents.
// Called under write lock when the cache is built.
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
- skillsDir := filepath.Join(cb.workspace, "skills")
+ skillRoots := cb.skillRoots()
- // All paths whose existence we track: source files + skills dir.
- allPaths := append(cb.sourcePaths(), skillsDir)
+ // All paths whose existence we track: source files + all skill roots.
+ allPaths := append(cb.sourcePaths(), skillRoots...)
existed := make(map[string]bool, len(allPaths))
+ skillFiles := make(map[string]time.Time)
var maxMtime time.Time
for _, p := range allPaths {
@@ -211,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
}
}
- // Walk skills files to capture their mtimes too.
- // Use os.Stat (not d.Info) to match the stat method used in
- // fileChangedSince / skillFilesModifiedSince for consistency.
- _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
- if walkErr == nil && !d.IsDir() {
- if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
- maxMtime = info.ModTime()
+ // Walk all skill roots recursively to snapshot skill files and mtimes.
+ // Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
+ for _, root := range skillRoots {
+ _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr == nil && !d.IsDir() {
+ if info, err := os.Stat(path); err == nil {
+ skillFiles[path] = info.ModTime()
+ if info.ModTime().After(maxMtime) {
+ maxMtime = info.ModTime()
+ }
+ }
}
- }
- return nil
- })
+ return nil
+ })
+ }
// If no tracked files exist yet (empty workspace), maxMtime is zero.
// Use a very old non-zero time so that:
@@ -233,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
maxMtime = time.Unix(1, 0)
}
- return cacheBaseline{existed: existed, maxMtime: maxMtime}
+ return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
}
// sourceFilesChangedLocked checks whether any workspace source file has been
@@ -249,27 +279,21 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
}
// Check tracked source files (bootstrap + memory).
- for _, p := range cb.sourcePaths() {
- if cb.fileChangedSince(p) {
- return true
- }
- }
-
- // --- Skills directory (handled separately from sourcePaths) ---
- //
- // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
- skillsDir := filepath.Join(cb.workspace, "skills")
- if cb.fileChangedSince(skillsDir) {
+ if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
return true
}
- // 2. Structural changes (add/remove entries inside the dir) are reflected
- // in the directory's own mtime, which fileChangedSince already checks.
+ // --- Skill roots (workspace/global/builtin) ---
//
- // 3. Content-only edits to files inside skills/ do NOT update the parent
- // directory mtime on most filesystems, so we recursively walk to check
- // individual file mtimes at any nesting depth.
- if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
+ // For each root:
+ // 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
+ // 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
+ for _, root := range cb.skillRoots() {
+ if cb.fileChangedSince(root) {
+ return true
+ }
+ }
+ if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
return true
}
@@ -310,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
// if the callback returned nil when its err parameter is non-nil.
var errWalkStop = errors.New("walk stop")
-// skillFilesModifiedSince recursively walks the skills directory and checks
-// whether any file was modified after t. This catches content-only edits at
-// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
-// parent directory mtimes.
-func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
- changed := false
- err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
- if walkErr == nil && !d.IsDir() {
- if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
- changed = true
- return errWalkStop // stop walking
- }
- }
- return nil
- })
- // errWalkStop is expected (early exit on first changed file).
- // os.IsNotExist means the skills dir doesn't exist yet — not an error.
- // Any other error is unexpected and worth logging.
- if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
- logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
+// skillFilesChangedSince compares the current recursive skill file tree
+// against the cache-time snapshot. Any create/delete/mtime drift invalidates
+// the cache.
+func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
+ // Defensive: if the snapshot was never initialized, force rebuild.
+ if filesAtCache == nil {
+ return true
}
- return changed
+
+ // Check cached files still exist and keep the same mtime.
+ for path, cachedMtime := range filesAtCache {
+ info, err := os.Stat(path)
+ if err != nil {
+ // A previously tracked file disappeared (or became inaccessible):
+ // either way, cached skill summary may now be stale.
+ return true
+ }
+ if !info.ModTime().Equal(cachedMtime) {
+ return true
+ }
+ }
+
+ // Check no new files appeared under any skill root.
+ changed := false
+ for _, root := range skillRoots {
+ if strings.TrimSpace(root) == "" {
+ continue
+ }
+
+ err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ // Treat unexpected walk errors as changed to avoid stale cache.
+ if !os.IsNotExist(walkErr) {
+ changed = true
+ return errWalkStop
+ }
+ return nil
+ }
+ if d.IsDir() {
+ return nil
+ }
+ if _, ok := filesAtCache[path]; !ok {
+ changed = true
+ return errWalkStop
+ }
+ return nil
+ })
+
+ if changed {
+ return true
+ }
+ if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
+ logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
+ return true
+ }
+ }
+
+ return false
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
@@ -467,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
// Add current user message
if strings.TrimSpace(currentMessage) != "" {
- messages = append(messages, providers.Message{
+ msg := providers.Message{
Role: "user",
Content: currentMessage,
- })
+ }
+ if len(media) > 0 {
+ msg.Media = media
+ }
+ messages = append(messages, msg)
}
return messages
diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go
index ba70d4c0d..707510820 100644
--- a/pkg/agent/context_cache_test.go
+++ b/pkg/agent/context_cache_test.go
@@ -383,6 +383,162 @@ Updated content.`
}
}
+// TestGlobalSkillFileContentChange verifies that modifying a global skill
+// (~/.picoclaw/skills) invalidates the cached system prompt.
+func TestGlobalSkillFileContentChange(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
+ if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ v1 := `---
+name: global-skill
+description: global-v1
+---
+# Global Skill v1`
+ if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "global-v1") {
+ t.Fatal("expected initial prompt to contain global skill description")
+ }
+
+ v2 := `---
+name: global-skill
+description: global-v2
+---
+# Global Skill v2`
+ if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(globalSkillPath, future, future); err != nil {
+ t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, "global-v2") {
+ t.Error("rebuilt prompt should contain updated global skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when global skill file content changes")
+ }
+}
+
+// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
+// invalidates the cached system prompt.
+func TestBuiltinSkillFileContentChange(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ builtinRoot := t.TempDir()
+ t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
+
+ builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
+ if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ v1 := `---
+name: builtin-skill
+description: builtin-v1
+---
+# Builtin Skill v1`
+ if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "builtin-v1") {
+ t.Fatal("expected initial prompt to contain builtin skill description")
+ }
+
+ v2 := `---
+name: builtin-skill
+description: builtin-v2
+---
+# Builtin Skill v2`
+ if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
+ t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, "builtin-v2") {
+ t.Error("rebuilt prompt should contain updated builtin skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when builtin skill file content changes")
+ }
+}
+
+// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
+// file invalidates the cached system prompt.
+func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "skills/delete-me/SKILL.md": `---
+name: delete-me
+description: delete-me-v1
+---
+# Delete Me`,
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "delete-me-v1") {
+ t.Fatal("expected initial prompt to contain skill description")
+ }
+
+ skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
+ if err := os.Remove(skillPath); err != nil {
+ t.Fatal(err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if strings.Contains(sp2, "delete-me-v1") {
+ t.Error("rebuilt prompt should not contain deleted skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when skill file is deleted")
+ }
+}
+
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
// can safely call BuildSystemPromptWithCache concurrently without producing
// empty results, panics, or data races.
@@ -404,11 +560,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
var wg sync.WaitGroup
errs := make(chan string, goroutines*iterations)
- for g := 0; g < goroutines; g++ {
+ for g := range goroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
- for i := 0; i < iterations; i++ {
+ for i := range iterations {
result := cb.BuildSystemPromptWithCache()
if result == "" {
errs <- "empty prompt returned"
diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go
index e597542e4..ed438059f 100644
--- a/pkg/agent/instance.go
+++ b/pkg/agent/instance.go
@@ -1,9 +1,11 @@
package agent
import (
+ "fmt"
"log"
"os"
"path/filepath"
+ "regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
@@ -48,18 +50,24 @@ func NewAgentInstance(
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
restrict := defaults.RestrictToWorkspace
+ readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
+
+ // Compile path whitelist patterns from config.
+ allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
+ allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
+
toolsRegistry := tools.NewToolRegistry()
- toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
+ toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
+ toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
+ toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
toolsRegistry.Register(execTool)
- toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
+ toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
+ toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -189,6 +197,19 @@ func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentD
return defaults.ModelFallbacks
}
+func compilePatterns(patterns []string) []*regexp.Regexp {
+ compiled := make([]*regexp.Regexp, 0, len(patterns))
+ for _, p := range patterns {
+ re, err := regexp.Compile(p)
+ if err != nil {
+ fmt.Printf("Warning: invalid path pattern %q: %v\n", p, err)
+ continue
+ }
+ compiled = append(compiled, re)
+ }
+ return compiled
+}
+
func expandHome(path string) string {
if path == "" {
return path
diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go
index af1bf2ead..4f41ecd1c 100644
--- a/pkg/agent/instance_test.go
+++ b/pkg/agent/instance_test.go
@@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
}
func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Agents: config.AgentsConfig{
- Defaults: config.AgentDefaults{
- Workspace: tmpDir,
- Model: "step-3.5-flash",
- },
+ tests := []struct {
+ name string
+ aliasName string
+ modelName string
+ apiBase string
+ wantProvider string
+ wantModel string
+ }{
+ {
+ name: "alias with provider prefix",
+ aliasName: "step-3.5-flash",
+ modelName: "openrouter/stepfun/step-3.5-flash:free",
+ apiBase: "https://openrouter.ai/api/v1",
+ wantProvider: "openrouter",
+ wantModel: "stepfun/step-3.5-flash:free",
},
- ModelList: []config.ModelConfig{
- {
- ModelName: "step-3.5-flash",
- Model: "openrouter/stepfun/step-3.5-flash:free",
- APIBase: "https://openrouter.ai/api/v1",
- },
+ {
+ name: "alias without provider prefix",
+ aliasName: "glm-5",
+ modelName: "glm-5",
+ apiBase: "https://api.z.ai/api/coding/paas/v4",
+ wantProvider: "openai",
+ wantModel: "glm-5",
},
}
- provider := &mockProvider{}
- agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
- if len(agent.Candidates) != 1 {
- t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
- }
- if agent.Candidates[0].Provider != "openrouter" {
- t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter")
- }
- if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" {
- t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free")
- }
-}
-
-func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Agents: config.AgentsConfig{
- Defaults: config.AgentDefaults{
- Workspace: tmpDir,
- Model: "glm-5",
- },
- },
- ModelList: []config.ModelConfig{
- {
- ModelName: "glm-5",
- Model: "glm-5",
- APIBase: "https://api.z.ai/api/coding/paas/v4",
- },
- },
- }
-
- provider := &mockProvider{}
- agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
-
- if len(agent.Candidates) != 1 {
- t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
- }
- if agent.Candidates[0].Provider != "openai" {
- t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai")
- }
- if agent.Candidates[0].Model != "glm-5" {
- t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5")
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: tt.aliasName,
+ },
+ },
+ ModelList: []config.ModelConfig{
+ {
+ ModelName: tt.aliasName,
+ Model: tt.modelName,
+ APIBase: tt.apiBase,
+ },
+ },
+ }
+
+ provider := &mockProvider{}
+ agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
+
+ if len(agent.Candidates) != 1 {
+ t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
+ }
+ if agent.Candidates[0].Provider != tt.wantProvider {
+ t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
+ }
+ if agent.Candidates[0].Model != tt.wantModel {
+ t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
+ }
+ })
}
}
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 1150b5ab3..3e29e8884 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/mcp"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -46,19 +47,24 @@ type AgentLoop struct {
// processOptions configures how a message is processed
type processOptions struct {
- SessionKey string // Session identifier for history/context
- Channel string // Target channel for tool execution
- ChatID string // Target chat ID for tool execution
- UserMessage string // User message content (may include prefix)
- DefaultResponse string // Response when LLM returns empty
- EnableSummary bool // Whether to trigger summarization
- SendResponse bool // Whether to send response via bus
- NoHistory bool // If true, don't load session history (for heartbeat)
+ SessionKey string // Session identifier for history/context
+ Channel string // Target channel for tool execution
+ ChatID string // Target chat ID for tool execution
+ UserMessage string // User message content (may include prefix)
+ Media []string // media:// refs from inbound message
+ DefaultResponse string // Response when LLM returns empty
+ EnableSummary bool // Whether to trigger summarization
+ SendResponse bool // Whether to send response via bus
+ NoHistory bool // If true, don't load session history (for heartbeat)
}
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
-func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
+func NewAgentLoop(
+ cfg *config.Config,
+ msgBus *bus.MessageBus,
+ provider providers.LLMProvider,
+) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Register shared tools to all agents
@@ -99,7 +105,7 @@ func registerSharedTools(
}
// Web tools
- if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
+ searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
@@ -113,10 +119,18 @@ func registerSharedTools(
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
Proxy: cfg.Tools.Web.Proxy,
- }); searchTool != nil {
+ })
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
+ } else if searchTool != nil {
agent.Tools.Register(searchTool)
}
- agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
+ fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ } else {
+ agent.Tools.Register(fetchTool)
+ }
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
@@ -162,6 +176,72 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
+ // Initialize MCP servers for all agents
+ if al.cfg.Tools.MCP.Enabled {
+ mcpManager := mcp.NewManager()
+ // Ensure MCP connections are cleaned up on exit, regardless of initialization success
+ // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
+ defer func() {
+ if err := mcpManager.Close(); err != nil {
+ logger.ErrorCF("agent", "Failed to close MCP manager",
+ map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ var workspacePath string
+ if defaultAgent != nil && defaultAgent.Workspace != "" {
+ workspacePath = defaultAgent.Workspace
+ } else {
+ workspacePath = al.cfg.WorkspacePath()
+ }
+
+ if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
+ logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
+ map[string]any{
+ "error": err.Error(),
+ })
+ } else {
+ // Register MCP tools for all agents
+ servers := mcpManager.GetServers()
+ uniqueTools := 0
+ totalRegistrations := 0
+ agentIDs := al.registry.ListAgentIDs()
+ agentCount := len(agentIDs)
+
+ for serverName, conn := range servers {
+ uniqueTools += len(conn.Tools)
+ for _, tool := range conn.Tools {
+ for _, agentID := range agentIDs {
+ agent, ok := al.registry.GetAgent(agentID)
+ if !ok {
+ continue
+ }
+ mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
+ agent.Tools.Register(mcpTool)
+ totalRegistrations++
+ logger.DebugCF("agent", "Registered MCP tool",
+ map[string]any{
+ "agent_id": agentID,
+ "server": serverName,
+ "tool": tool.Name,
+ "name": mcpTool.Name(),
+ })
+ }
+ }
+ }
+ logger.InfoCF("agent", "MCP tools registered successfully",
+ map[string]any{
+ "server_count": len(servers),
+ "unique_tools": uniqueTools,
+ "total_registrations": totalRegistrations,
+ "agent_count": agentCount,
+ })
+ }
+ }
+
for al.running.Load() {
select {
case <-ctx.Done():
@@ -302,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
-func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
+func (al *AgentLoop) ProcessDirect(
+ ctx context.Context,
+ content, sessionKey string,
+) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -323,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
-func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
+func (al *AgentLoop) ProcessHeartbeat(
+ ctx context.Context,
+ content, channel, chatID string,
+) (string, error) {
agent := al.registry.GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
@@ -348,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
} else {
logContent = utils.Truncate(msg.Content, 80)
}
- logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
+ logger.InfoCF(
+ "agent",
+ fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"sender_id": msg.SenderID,
"session_key": msg.SessionKey,
- })
+ },
+ )
// Route system messages to processSystemMessage
if msg.Channel == "system" {
@@ -409,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
+ Media: msg.Media,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
})
}
-func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
+func (al *AgentLoop) processSystemMessage(
+ ctx context.Context,
+ msg bus.InboundMessage,
+) (string, error) {
if msg.Channel != "system" {
- return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
+ return "", fmt.Errorf(
+ "processSystemMessage called with non-system message channel: %s",
+ msg.Channel,
+ )
}
logger.InfoCF("agent", "Processing system message",
@@ -475,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
// runAgentLoop is the core message processing logic.
-func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
+func (al *AgentLoop) runAgentLoop(
+ ctx context.Context,
+ agent *AgentInstance,
+ opts processOptions,
+) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
- logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
+ logger.WarnCF(
+ "agent",
+ "Failed to record last channel",
+ map[string]any{"error": err.Error()},
+ )
}
}
}
@@ -501,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
history,
summary,
opts.UserMessage,
- nil,
+ opts.Media,
opts.Channel,
opts.ChatID,
)
+ // Resolve media:// refs to base64 data URLs (streaming)
+ maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+
// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
@@ -564,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
return ""
}
-func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
+func (al *AgentLoop) handleReasoning(
+ ctx context.Context,
+ reasoningContent, channelName, channelID string,
+) {
if reasoningContent == "" || channelName == "" || channelID == "" {
return
}
@@ -657,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
- fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
+ fbResult, fbErr := al.fallback.Execute(
+ ctx,
+ agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
- return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
- "prompt_cache_key": agent.ID,
- })
+ return agent.Provider.Chat(
+ ctx,
+ messages,
+ providerToolDefs,
+ model,
+ map[string]any{
+ "max_tokens": agent.MaxTokens,
+ "temperature": agent.Temperature,
+ "prompt_cache_key": agent.ID,
+ },
+ )
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
- logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
- fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
- map[string]any{"agent_id": agent.ID, "iteration": iteration})
+ logger.InfoCF(
+ "agent",
+ fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
+ fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
+ map[string]any{"agent_id": agent.ID, "iteration": iteration},
+ )
}
return fbResult.Response, nil
}
@@ -723,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
}
if isContextError && retry < maxRetries {
- logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
- "error": err.Error(),
- "retry": retry,
- })
+ logger.WarnCF(
+ "agent",
+ "Context window error detected, attempting compression",
+ map[string]any{
+ "error": err.Error(),
+ "retry": retry,
+ },
+ )
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
@@ -758,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
- go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
+ go al.handleReasoning(
+ ctx,
+ response.Reasoning,
+ opts.Channel,
+ al.targetReasoningChannelID(opts.Channel),
+ )
logger.DebugCF("agent", "LLM response",
map[string]any{
@@ -1067,7 +1198,11 @@ func formatMessagesForLog(messages []providers.Message) string {
for _, tc := range msg.ToolCalls {
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
- fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
+ fmt.Fprintf(
+ &sb,
+ " Arguments: %s\n",
+ utils.Truncate(tc.Function.Arguments, 200),
+ )
}
}
}
@@ -1096,7 +1231,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
- fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
+ fmt.Fprintf(
+ &sb,
+ " Parameters: %s\n",
+ utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
+ )
}
}
sb.WriteString("]")
@@ -1193,7 +1332,9 @@ func (al *AgentLoop) summarizeBatch(
existingSummary string,
) (string, error) {
var sb strings.Builder
- sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
+ sb.WriteString(
+ "Provide a concise summary of this conversation segment, preserving core context and key points.\n",
+ )
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go
new file mode 100644
index 000000000..82547a008
--- /dev/null
+++ b/pkg/agent/loop_media.go
@@ -0,0 +1,122 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package agent
+
+import (
+ "bytes"
+ "encoding/base64"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/h2non/filetype"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
+// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
+// both raw bytes and encoded string in memory simultaneously.
+// Returns a new slice; original messages are not mutated.
+func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
+ if store == nil {
+ return messages
+ }
+
+ result := make([]providers.Message, len(messages))
+ copy(result, messages)
+
+ for i, m := range result {
+ if len(m.Media) == 0 {
+ continue
+ }
+
+ resolved := make([]string, 0, len(m.Media))
+ for _, ref := range m.Media {
+ if !strings.HasPrefix(ref, "media://") {
+ resolved = append(resolved, ref)
+ continue
+ }
+
+ localPath, meta, err := store.ResolveWithMeta(ref)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
+ "ref": ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ info, err := os.Stat(localPath)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to stat media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+ if info.Size() > int64(maxSize) {
+ logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
+ "path": localPath,
+ "size": info.Size(),
+ "max_size": maxSize,
+ })
+ continue
+ }
+
+ // Determine MIME type: prefer metadata, fallback to magic-bytes detection
+ mime := meta.ContentType
+ if mime == "" {
+ kind, ftErr := filetype.MatchFile(localPath)
+ if ftErr != nil || kind == filetype.Unknown {
+ logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
+ "path": localPath,
+ })
+ continue
+ }
+ mime = kind.MIME.Value
+ }
+
+ // Streaming base64: open file → base64 encoder → buffer
+ // Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
+ f, err := os.Open(localPath)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ prefix := "data:" + mime + ";base64,"
+ encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
+ var buf bytes.Buffer
+ buf.Grow(len(prefix) + encodedLen)
+ buf.WriteString(prefix)
+
+ encoder := base64.NewEncoder(base64.StdEncoding, &buf)
+ if _, err := io.Copy(encoder, f); err != nil {
+ f.Close()
+ logger.WarnCF("agent", "Failed to encode media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+ encoder.Close()
+ f.Close()
+
+ resolved = append(resolved, buf.String())
+ }
+
+ result[i].Media = resolved
+ }
+
+ return result
+}
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index 6915f07bd..55098fa61 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "slices"
"strings"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool {
func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
func (f *fakeChannel) ReasoningChannelID() string { return f.id }
-func TestRecordLastChannel(t *testing.T) {
- // Create temp workspace
+func newTestAgentLoop(
+ t *testing.T,
+) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
+ t.Helper()
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
- defer os.RemoveAll(tmpDir)
-
- // Create test config
- cfg := &config.Config{
+ cfg = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
@@ -46,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
},
},
}
+ msgBus = bus.NewMessageBus()
+ provider = &mockProvider{}
+ al = NewAgentLoop(cfg, msgBus, provider)
+ return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
+}
- // Create agent loop
- msgBus := bus.NewMessageBus()
- provider := &mockProvider{}
- al := NewAgentLoop(cfg, msgBus, provider)
+func TestRecordLastChannel(t *testing.T) {
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
- // Test RecordLastChannel
testChannel := "test-channel"
- err = al.RecordLastChannel(testChannel)
- if err != nil {
+ if err := al.RecordLastChannel(testChannel); err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
-
- // Verify channel was saved
- lastChannel := al.state.GetLastChannel()
- if lastChannel != testChannel {
- t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
+ if got := al.state.GetLastChannel(); got != testChannel {
+ t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
}
-
- // Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
- if al2.state.GetLastChannel() != testChannel {
- t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
+ if got := al2.state.GetLastChannel(); got != testChannel {
+ t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
}
}
func TestRecordLastChatID(t *testing.T) {
- // Create temp workspace
- tmpDir, err := os.MkdirTemp("", "agent-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
- // Create test config
- cfg := &config.Config{
- Agents: config.AgentsConfig{
- Defaults: config.AgentDefaults{
- Workspace: tmpDir,
- Model: "test-model",
- MaxTokens: 4096,
- MaxToolIterations: 10,
- },
- },
- }
-
- // Create agent loop
- msgBus := bus.NewMessageBus()
- provider := &mockProvider{}
- al := NewAgentLoop(cfg, msgBus, provider)
-
- // Test RecordLastChatID
testChatID := "test-chat-id-123"
- err = al.RecordLastChatID(testChatID)
- if err != nil {
+ if err := al.RecordLastChatID(testChatID); err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
-
- // Verify chat ID was saved
- lastChatID := al.state.GetLastChatID()
- if lastChatID != testChatID {
- t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
+ if got := al.state.GetLastChatID(); got != testChatID {
+ t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
}
-
- // Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
- if al2.state.GetLastChatID() != testChatID {
- t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
+ if got := al2.state.GetLastChatID(); got != testChatID {
+ t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
}
}
@@ -188,13 +158,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
- found := false
- for _, name := range toolsList {
- if name == "mock_custom" {
- found = true
- break
- }
- }
+ found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -263,13 +227,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
- found := false
- for _, name := range toolsList {
- if name == "mock_custom" {
- found = true
- break
- }
- }
+ found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -931,3 +889,142 @@ func TestHandleReasoning(t *testing.T) {
}
})
}
+
+func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ // Create a minimal valid PNG (8-byte header is enough for filetype detection)
+ pngPath := filepath.Join(dir, "test.png")
+ // PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
+ pngHeader := []byte{
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
+ 0x00, 0x00, 0x00, 0x0D, // IHDR length
+ 0x49, 0x48, 0x44, 0x52, // "IHDR"
+ 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
+ 0x00, 0x00, 0x00, // no interlace
+ 0x90, 0x77, 0x53, 0xDE, // CRC
+ }
+ if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ messages := []providers.Message{
+ {Role: "user", Content: "describe this", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 {
+ t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
+ }
+ if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
+ t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
+ }
+}
+
+func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ bigPath := filepath.Join(dir, "big.png")
+ // Write PNG header + padding to exceed limit
+ data := make([]byte, 1024+1) // 1KB + 1 byte
+ copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
+ if err := os.WriteFile(bigPath, data, 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ // Use a tiny limit (1KB) so the file is oversized
+ result := resolveMediaRefs(messages, store, 1024)
+
+ if len(result[0].Media) != 0 {
+ t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
+ }
+}
+
+func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ txtPath := filepath.Join(dir, "readme.txt")
+ if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 0 {
+ t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
+ }
+}
+
+func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
+ }
+ result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
+ t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
+ }
+}
+
+func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+ pngPath := filepath.Join(dir, "test.png")
+ pngHeader := []byte{
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
+ 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
+ 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
+ 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
+ }
+ os.WriteFile(pngPath, pngHeader, 0o644)
+ ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
+
+ original := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ originalRef := original[0].Media[0]
+
+ resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
+
+ if original[0].Media[0] != originalRef {
+ t.Fatal("resolveMediaRefs mutated original message slice")
+ }
+}
+
+func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ // File with JPEG content but stored with explicit content type
+ jpegPath := filepath.Join(dir, "photo")
+ jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
+ os.WriteFile(jpegPath, jpegHeader, 0o644)
+ ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 {
+ t.Fatalf("expected 1 media, got %d", len(result[0].Media))
+ }
+ if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
+ t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
+ }
+}
diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go
index 87a687479..01e682f3b 100644
--- a/pkg/agent/memory.go
+++ b/pkg/agent/memory.go
@@ -111,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
var sb strings.Builder
first := true
- for i := 0; i < days; i++ {
+ for i := range days {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go
index a50586df1..e07b8c7fe 100644
--- a/pkg/bus/bus_test.go
+++ b/pkg/bus/bus_test.go
@@ -67,7 +67,7 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
// Fill the buffer
ctx := context.Background()
- for i := 0; i < defaultBusBufferSize; i++ {
+ for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
@@ -154,7 +154,7 @@ func TestConcurrentPublishClose(t *testing.T) {
wg.Add(numGoroutines + 1)
// Spawn many goroutines trying to publish
- for i := 0; i < numGoroutines; i++ {
+ for range numGoroutines {
go func() {
defer wg.Done()
// Use a short timeout context so we don't block forever after close
@@ -194,7 +194,7 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
ctx := context.Background()
// Fill the buffer
- for i := 0; i < defaultBusBufferSize; i++ {
+ for i := range defaultBusBufferSize {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
diff --git a/pkg/channels/README.md b/pkg/channels/README.md
index 52b9f98f4..b7c56660b 100644
--- a/pkg/channels/README.md
+++ b/pkg/channels/README.md
@@ -1,7 +1,5 @@
-# PicoClaw Channel System Refactor: Complete Development Guide
+# PicoClaw Channel System: Complete Development Guide
-> **Branch**: `refactor/channel-system`
-> **Status**: Active development (~40 commits)
> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
---
@@ -46,6 +44,8 @@ pkg/channels/
pkg/channels/
├── base.go # BaseChannel shared abstraction layer
├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
+├── README.md # English documentation
+├── README.zh.md # Chinese documentation
├── media.go # MediaSender optional interface
├── webhook.go # WebhookHandler, HealthChecker optional interfaces
├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
@@ -60,7 +60,7 @@ pkg/channels/
├── discord/
│ ├── init.go
│ └── discord.go
-├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
+├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
│ └── ...
pkg/bus/
@@ -111,7 +111,7 @@ pkg/identity/
|-----------|-------------|
| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package |
| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling |
-| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`), discovered by Manager via runtime type assertions |
+| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions |
| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage |
| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy |
| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send |
@@ -145,6 +145,7 @@ After refactoring, these files have been removed and code moved to corresponding
| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces |
| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface |
| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker |
+| _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) |
| _(did not exist)_ | `pkg/channels/split.go` | New message splitting (migrated from utils) |
| _(did not exist)_ | `pkg/bus/types.go` | New structured message types |
| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management |
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
cfg.Channels.Telegram.AllowFrom, // Allow list
channels.WithMaxMessageLength(4096), // Platform message length limit
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config
+ channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing
)
return &TelegramChannel{
BaseChannel: base,
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
matrixCfg.AllowFrom, // Allow list
channels.WithMaxMessageLength(65536), // Matrix message length limit
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
+ channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional)
)
return &MatrixChannel{
@@ -666,6 +669,32 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
}
```
+#### PlaceholderCapable — Placeholder Messages
+
+```go
+// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
+// and the channel also implements MessageEditor, then Manager's preSend will
+// automatically edit the placeholder into the final response on outbound.
+// SendPlaceholder checks PlaceholderConfig.Enabled internally;
+// returning ("", nil) means skip.
+func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ cfg := c.config.Channels.Matrix.Placeholder
+ if !cfg.Enabled {
+ return "", nil
+ }
+ text := cfg.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+ // Call Matrix API to send placeholder message
+ msg, err := c.sendText(ctx, chatID, text)
+ if err != nil {
+ return "", err
+ }
+ return msg.ID, nil
+}
+```
+
#### WebhookHandler — HTTP Webhook Reception
```go
@@ -746,15 +775,17 @@ When the Agent finishes processing a message, Manager's `preSend` automatically:
```go
type ChannelsConfig struct {
// ... existing channels
- Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
+ Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
- Enabled bool `yaml:"enabled" json:"enabled"`
- HomeServer string `yaml:"home_server" json:"home_server"`
- Token string `yaml:"token" json:"token"`
- AllowFrom []string `yaml:"allow_from" json:"allow_from"`
- GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
+ Enabled bool `json:"enabled"`
+ HomeServer string `json:"home_server"`
+ Token string `json:"token"`
+ AllowFrom []string `json:"allow_from"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger"`
+ Placeholder PlaceholderConfig `json:"placeholder"`
+ ReasoningChannelID string `json:"reasoning_channel_id"`
}
```
@@ -767,6 +798,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
}
```
+> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
+> ```go
+> if cfg.UseNative {
+> m.initChannel("whatsapp_native", "WhatsApp Native")
+> } else {
+> m.initChannel("whatsapp", "WhatsApp")
+> }
+> ```
+
#### Add blank import in Gateway
```go
@@ -882,19 +922,21 @@ BaseChannel is the shared abstraction layer for all channels, providing the foll
| `IsRunning() bool` | Atomically read running state |
| `SetRunning(bool)` | Atomically set running state |
| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
+| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
-| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction → publish to Bus |
+| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
-| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction type assertions in HandleMessage) |
+| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |
**Functional Options**:
```go
channels.WithMaxMessageLength(4096) // Set platform message length limit
channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
+channels.WithReasoningChannelID(id) // Set reasoning chain routing target channel
```
### 4.4 Factory Registry
@@ -998,7 +1040,7 @@ StartAll:
- runMediaWorker (per-channel outbound media)
- dispatchOutbound (route from bus to worker queues)
- dispatchOutboundMedia (route from bus to media worker queues)
- - runTTLJanitor (every 10s clean up expired typing/placeholder)
+ - runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
4. Start shared HTTP server (if configured)
StopAll:
@@ -1206,18 +1248,20 @@ make test # Full test suite
| Sub-package | Registered Name | Optional Interfaces |
|-------------|----------------|-------------------|
-| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
-| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
-| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
-| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
-| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
-| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
-| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (architecture-specific build tags) |
-| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
+| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
+| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
+| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
+| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
+| `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
+| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
+| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
| `pkg/channels/qq/` | `"qq"` | — |
-| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
+| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
+| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
| `pkg/channels/maixcam/` | `"maixcam"` | — |
-| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
+| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
### A.3 Interface Quick Reference
@@ -1231,6 +1275,7 @@ type Channel interface {
IsRunning() bool
IsAllowed(senderID string) bool
IsAllowedSender(sender bus.SenderInfo) bool
+ ReasoningChannelID() string
}
// ===== Optional =====
@@ -1324,8 +1369,16 @@ agentLoop.Stop() // Stop Agent
1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active.
-2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`).
+2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
-3. **WeCom has two factories**: `"wecom"` (Bot mode) and `"wecom_app"` (App mode) are registered separately.
+3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
-4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via webhook.
\ No newline at end of file
+4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
+
+5. **WhatsApp has two modes**: `"whatsapp"` (Bridge mode, communicates via external bridge URL) and `"whatsapp_native"` (native whatsmeow mode, connects directly to WhatsApp). Manager selects which to initialize based on `WhatsAppConfig.UseNative`.
+
+6. **DingTalk uses Stream mode**: DingTalk uses the SDK's Stream/WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
+
+7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
+
+8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
\ No newline at end of file
diff --git a/pkg/channels/README.zh.md b/pkg/channels/README.zh.md
index 0a9487cd0..2c5e7356e 100644
--- a/pkg/channels/README.zh.md
+++ b/pkg/channels/README.zh.md
@@ -1,7 +1,5 @@
-# PicoClaw Channel System 重构:完整开发指南
+# PicoClaw Channel System:完整开发指南
-> **分支**: `refactor/channel-system`
-> **状态**: 活跃开发中(约 40 commits)
> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
---
@@ -46,6 +44,8 @@ pkg/channels/
pkg/channels/
├── base.go # BaseChannel 共享抽象层
├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
+├── README.md # 英文文档
+├── README.zh.md # 中文文档
├── media.go # MediaSender 可选接口
├── webhook.go # WebhookHandler, HealthChecker 可选接口
├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
@@ -60,7 +60,7 @@ pkg/channels/
├── discord/
│ ├── init.go
│ └── discord.go
-├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ maixcam/ pico/
+├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
│ └── ...
pkg/bus/
@@ -111,7 +111,7 @@ pkg/identity/
|------|------|
| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 |
| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 |
-| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`)声明,Manager 运行时类型断言发现 |
+| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`)声明,Manager 运行时类型断言发现 |
| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 |
| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 |
| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send |
@@ -145,6 +145,7 @@ pkg/identity/
| _(不存在)_ | `pkg/channels/interfaces.go` | 新增可选能力接口 |
| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 |
| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker |
+| _(不存在)_ | `pkg/channels/whatsapp_native/` | 新增 WhatsApp 原生模式(whatsmeow) |
| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) |
| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 |
| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 |
@@ -220,6 +221,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
cfg.Channels.Telegram.AllowFrom, // 允许列表
channels.WithMaxMessageLength(4096), // 平台消息长度上限
channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置
+ channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // 思维链路由
)
return &TelegramChannel{
BaseChannel: base,
@@ -466,6 +468,7 @@ func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChanne
matrixCfg.AllowFrom, // 允许列表
channels.WithMaxMessageLength(65536), // Matrix 消息长度限制
channels.WithGroupTrigger(matrixCfg.GroupTrigger),
+ channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // 思维链路由(可选)
)
return &MatrixChannel{
@@ -666,6 +669,31 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, cont
}
```
+#### PlaceholderCapable — 占位消息
+
+```go
+// 如果平台支持发送占位消息(如 "Thinking... 💭"),并且实现了 MessageEditor,
+// 则 Manager 的 preSend 会在出站时自动将占位消息编辑为最终回复。
+// SendPlaceholder 内部根据 PlaceholderConfig.Enabled 决定是否发送;
+// 返回 ("", nil) 表示跳过。
+func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ cfg := c.config.Channels.Matrix.Placeholder
+ if !cfg.Enabled {
+ return "", nil
+ }
+ text := cfg.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+ // 调用 Matrix API 发送占位消息
+ msg, err := c.sendText(ctx, chatID, text)
+ if err != nil {
+ return "", err
+ }
+ return msg.ID, nil
+}
+```
+
#### WebhookHandler — HTTP Webhook 接收
```go
@@ -746,15 +774,17 @@ if c.owner != nil && c.placeholderRecorder != nil {
```go
type ChannelsConfig struct {
// ... 现有 channels
- Matrix MatrixChannelConfig `yaml:"matrix" json:"matrix"`
+ Matrix MatrixChannelConfig `json:"matrix"`
}
type MatrixChannelConfig struct {
- Enabled bool `yaml:"enabled" json:"enabled"`
- HomeServer string `yaml:"home_server" json:"home_server"`
- Token string `yaml:"token" json:"token"`
- AllowFrom []string `yaml:"allow_from" json:"allow_from"`
- GroupTrigger GroupTriggerConfig `yaml:"group_trigger" json:"group_trigger"`
+ Enabled bool `json:"enabled"`
+ HomeServer string `json:"home_server"`
+ Token string `json:"token"`
+ AllowFrom []string `json:"allow_from"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger"`
+ Placeholder PlaceholderConfig `json:"placeholder"`
+ ReasoningChannelID string `json:"reasoning_channel_id"`
}
```
@@ -767,6 +797,15 @@ if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
}
```
+> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
+> ```go
+> if cfg.UseNative {
+> m.initChannel("whatsapp_native", "WhatsApp Native")
+> } else {
+> m.initChannel("whatsapp", "WhatsApp")
+> }
+> ```
+
#### 在 Gateway 中添加 blank import
```go
@@ -882,19 +921,21 @@ BaseChannel 是所有 channel 的共享抽象层,提供以下能力:
| `IsRunning() bool` | 原子读取运行状态 |
| `SetRunning(bool)` | 原子设置运行状态 |
| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 |
+| `ReasoningChannelID() string` | 思维链路由目标 channel ID(空 = 不路由) |
| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"` 和 `"@username"` 格式) |
| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed`) |
| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 |
-| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction → 发布到 Bus |
+| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction/Placeholder → 发布到 Bus |
| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 |
| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 |
-| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction 类型断言) |
+| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction/Placeholder 类型断言) |
**功能选项**:
```go
channels.WithMaxMessageLength(4096) // 设置平台消息长度限制
channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置
+channels.WithReasoningChannelID(id) // 设置思维链路由目标 channel
```
### 4.4 工厂注册表
@@ -998,7 +1039,7 @@ StartAll:
- runMediaWorker (per-channel 出站媒体)
- dispatchOutbound (从 bus 路由到 worker 队列)
- dispatchOutboundMedia (从 bus 路由到 media worker 队列)
- - runTTLJanitor (每 10s 清理过期 typing/placeholder)
+ - runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder)
4. 启动共享 HTTP 服务器(如已配置)
StopAll:
@@ -1206,18 +1247,20 @@ make test # 全量测试
| 子包 | 注册名 | 可选接口 |
|------|--------|----------|
-| `pkg/channels/telegram/` | `"telegram"` | MessageEditor, MediaSender, TypingCapable, PlaceholderCapable |
-| `pkg/channels/discord/` | `"discord"` | MessageEditor, TypingCapable, PlaceholderCapable |
-| `pkg/channels/slack/` | `"slack"` | ReactionCapable |
-| `pkg/channels/line/` | `"line"` | WebhookHandler, HealthChecker, TypingCapable |
-| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable |
-| `pkg/channels/dingtalk/` | `"dingtalk"` | WebhookHandler |
-| `pkg/channels/feishu/` | `"feishu"` | WebhookHandler (架构特定 build tags) |
-| `pkg/channels/wecom/` | `"wecom"` + `"wecom_app"` | WebhookHandler |
+| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
+| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
+| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
+| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
+| `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
+| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
+| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
| `pkg/channels/qq/` | `"qq"` | — |
-| `pkg/channels/whatsapp/` | `"whatsapp"` | — |
+| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
+| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
| `pkg/channels/maixcam/` | `"maixcam"` | — |
-| `pkg/channels/pico/` | `"pico"` | WebhookHandler (Pico Protocol), TypingCapable, PlaceholderCapable |
+| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
### A.3 接口速查表
@@ -1231,6 +1274,7 @@ type Channel interface {
IsRunning() bool
IsAllowed(senderID string) bool
IsAllowedSender(sender bus.SenderInfo) bool
+ ReasoningChannelID() string
}
// ===== 可选实现 =====
@@ -1324,8 +1368,16 @@ agentLoop.Stop() // 停止 Agent
1. **媒体清理暂时禁用**:Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。
-2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。
+2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
-3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式)和 `"wecom_app"`(应用模式)分别注册。
+3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。
-4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 webhook 接收消息。
\ No newline at end of file
+4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
+
+5. **WhatsApp 有两种模式**:`"whatsapp"`(Bridge 模式,通过外部 bridge URL 通信)和 `"whatsapp_native"`(原生 whatsmeow 模式,直接连接 WhatsApp)。Manager 根据 `WhatsAppConfig.UseNative` 决定初始化哪个。
+
+6. **DingTalk 使用 Stream 模式**:DingTalk 使用 SDK 的 Stream/WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
+
+7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
+
+8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
\ No newline at end of file
diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go
index e8a057741..fbe085b73 100644
--- a/pkg/channels/feishu/common.go
+++ b/pkg/channels/feishu/common.go
@@ -1,5 +1,16 @@
package feishu
+import (
+ "encoding/json"
+ "regexp"
+ "strings"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
+var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`)
+
// stringValue safely dereferences a *string pointer.
func stringValue(v *string) string {
if v == nil {
@@ -7,3 +18,69 @@ func stringValue(v *string) string {
}
return *v
}
+
+// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content.
+// JSON 2.0 cards support full CommonMark standard markdown syntax.
+func buildMarkdownCard(content string) (string, error) {
+ card := map[string]any{
+ "schema": "2.0",
+ "body": map[string]any{
+ "elements": []map[string]any{
+ {
+ "tag": "markdown",
+ "content": content,
+ },
+ },
+ },
+ }
+ data, err := json.Marshal(card)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+// extractJSONStringField unmarshals content as JSON and returns the value of the given string field.
+// Returns "" if the content is invalid JSON or the field is missing/empty.
+func extractJSONStringField(content, field string) string {
+ var m map[string]json.RawMessage
+ if err := json.Unmarshal([]byte(content), &m); err != nil {
+ return ""
+ }
+ raw, ok := m[field]
+ if !ok {
+ return ""
+ }
+ var s string
+ if err := json.Unmarshal(raw, &s); err != nil {
+ return ""
+ }
+ return s
+}
+
+// extractImageKey extracts the image_key from a Feishu image message content JSON.
+// Format: {"image_key": "img_xxx"}
+func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
+
+// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
+// Format: {"file_key": "file_xxx", "file_name": "...", ...}
+func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
+
+// extractFileName extracts the file_name from a Feishu file message content JSON.
+func extractFileName(content string) string { return extractJSONStringField(content, "file_name") }
+
+// stripMentionPlaceholders removes @_user_N placeholders from the text content.
+// These are inserted by Feishu when users @mention someone in a message.
+func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string {
+ if len(mentions) == 0 {
+ return content
+ }
+ for _, m := range mentions {
+ if m.Key != nil && *m.Key != "" {
+ content = strings.ReplaceAll(content, *m.Key, "")
+ }
+ }
+ // Also clean up any remaining @_user_N patterns
+ content = mentionPlaceholderRegex.ReplaceAllString(content, "")
+ return strings.TrimSpace(content)
+}
diff --git a/pkg/channels/feishu/common_test.go b/pkg/channels/feishu/common_test.go
new file mode 100644
index 000000000..fefc9f7c1
--- /dev/null
+++ b/pkg/channels/feishu/common_test.go
@@ -0,0 +1,292 @@
+package feishu
+
+import (
+ "encoding/json"
+ "testing"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+func TestExtractJSONStringField(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ field string
+ want string
+ }{
+ {
+ name: "valid field",
+ content: `{"image_key": "img_v2_xxx"}`,
+ field: "image_key",
+ want: "img_v2_xxx",
+ },
+ {
+ name: "missing field",
+ content: `{"image_key": "img_v2_xxx"}`,
+ field: "file_key",
+ want: "",
+ },
+ {
+ name: "invalid JSON",
+ content: `not json at all`,
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "empty content",
+ content: "",
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "non-string field value",
+ content: `{"count": 42}`,
+ field: "count",
+ want: "",
+ },
+ {
+ name: "empty string value",
+ content: `{"image_key": ""}`,
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "multiple fields",
+ content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`,
+ field: "file_name",
+ want: "test.pdf",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractJSONStringField(tt.content, tt.field)
+ if got != tt.want {
+ t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractImageKey(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"image_key": "img_v2_abc123"}`,
+ want: "img_v2_abc123",
+ },
+ {
+ name: "missing key",
+ content: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `{broken`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractImageKey(tt.content)
+ if got != tt.want {
+ t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractFileKey(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`,
+ want: "file_v2_abc123",
+ },
+ {
+ name: "missing key",
+ content: `{"image_key": "img_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `not json`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFileKey(tt.content)
+ if got != tt.want {
+ t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractFileName(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
+ want: "report.pdf",
+ },
+ {
+ name: "missing name",
+ content: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `{bad`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFileName(tt.content)
+ if got != tt.want {
+ t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestBuildMarkdownCard(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ }{
+ {
+ name: "normal content",
+ content: "Hello **world**",
+ },
+ {
+ name: "empty content",
+ content: "",
+ },
+ {
+ name: "special characters",
+ content: `Code: "foo" & 'baz'`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := buildMarkdownCard(tt.content)
+ if err != nil {
+ t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err)
+ }
+
+ // Verify valid JSON
+ var parsed map[string]any
+ if err := json.Unmarshal([]byte(result), &parsed); err != nil {
+ t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err)
+ }
+
+ // Verify schema
+ if parsed["schema"] != "2.0" {
+ t.Errorf("schema = %v, want %q", parsed["schema"], "2.0")
+ }
+
+ // Verify body.elements[0].content == input
+ body, ok := parsed["body"].(map[string]any)
+ if !ok {
+ t.Fatal("missing body in card JSON")
+ }
+ elements, ok := body["elements"].([]any)
+ if !ok || len(elements) == 0 {
+ t.Fatal("missing or empty elements in card JSON")
+ }
+ elem, ok := elements[0].(map[string]any)
+ if !ok {
+ t.Fatal("first element is not an object")
+ }
+ if elem["tag"] != "markdown" {
+ t.Errorf("tag = %v, want %q", elem["tag"], "markdown")
+ }
+ if elem["content"] != tt.content {
+ t.Errorf("content = %v, want %q", elem["content"], tt.content)
+ }
+ })
+ }
+}
+
+func TestStripMentionPlaceholders(t *testing.T) {
+ strPtr := func(s string) *string { return &s }
+
+ tests := []struct {
+ name string
+ content string
+ mentions []*larkim.MentionEvent
+ want string
+ }{
+ {
+ name: "no mentions",
+ content: "Hello world",
+ mentions: nil,
+ want: "Hello world",
+ },
+ {
+ name: "single mention",
+ content: "@_user_1 hello",
+ mentions: []*larkim.MentionEvent{
+ {Key: strPtr("@_user_1")},
+ },
+ want: "hello",
+ },
+ {
+ name: "multiple mentions",
+ content: "@_user_1 @_user_2 hey",
+ mentions: []*larkim.MentionEvent{
+ {Key: strPtr("@_user_1")},
+ {Key: strPtr("@_user_2")},
+ },
+ want: "hey",
+ },
+ {
+ name: "empty content",
+ content: "",
+ mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}},
+ want: "",
+ },
+ {
+ name: "empty mentions slice",
+ content: "@_user_1 test",
+ mentions: []*larkim.MentionEvent{},
+ want: "@_user_1 test",
+ },
+ {
+ name: "mention with nil key",
+ content: "@_user_1 test",
+ mentions: []*larkim.MentionEvent{
+ {Key: nil},
+ },
+ want: "test",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := stripMentionPlaceholders(tt.content, tt.mentions)
+ if got != tt.want {
+ t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go
index d0ec758c6..f5e3aa224 100644
--- a/pkg/channels/feishu/feishu_32.go
+++ b/pkg/channels/feishu/feishu_32.go
@@ -16,6 +16,8 @@ type FeishuChannel struct {
*channels.BaseChannel
}
+var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
+
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
return nil, errors.New(
@@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
// Start is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Start(ctx context.Context) error {
- return nil
+ return errUnsupported
}
// Stop is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Stop(ctx context.Context) error {
- return nil
+ return errUnsupported
}
// Send is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
- return errors.New("feishu channel is not supported on 32-bit architectures")
+ return errUnsupported
+}
+
+// EditMessage is a stub method to satisfy MessageEditor
+func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ return errUnsupported
+}
+
+// SendPlaceholder is a stub method to satisfy PlaceholderCapable
+func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ return "", errUnsupported
+}
+
+// ReactToMessage is a stub method to satisfy ReactionCapable
+func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ return func() {}, errUnsupported
+}
+
+// SendMedia is a stub method to satisfy MediaSender
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ return errUnsupported
}
diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go
index 1db1bf669..00f73064d 100644
--- a/pkg/channels/feishu/feishu_64.go
+++ b/pkg/channels/feishu/feishu_64.go
@@ -6,10 +6,15 @@ import (
"context"
"encoding/json"
"fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
"sync"
- "time"
+ "sync/atomic"
lark "github.com/larksuite/oapi-sdk-go/v3"
+ larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
@@ -19,6 +24,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -28,6 +34,8 @@ type FeishuChannel struct {
client *lark.Client
wsClient *larkws.Client
+ botOpenID atomic.Value // stores string; populated lazily for @mention detection
+
mu sync.Mutex
cancel context.CancelFunc
}
@@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
- return &FeishuChannel{
+ ch := &FeishuChannel{
BaseChannel: base,
config: cfg,
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
- }, nil
+ }
+ ch.SetOwner(ch)
+ return ch, nil
}
func (c *FeishuChannel) Start(ctx context.Context) error {
@@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
return fmt.Errorf("feishu app_id or app_secret is empty")
}
+ // Fetch bot open_id via API for reliable @mention detection.
+ if err := c.fetchBotOpenID(ctx); err != nil {
+ logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{
+ "error": err.Error(),
+ })
+ }
+
dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
OnP2MessageReceiveV1(c.handleMessageReceive)
@@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error {
return nil
}
+// Send sends a message using Interactive Card format for markdown rendering.
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
}
if msg.ChatID == "" {
- return fmt.Errorf("chat ID is empty")
+ return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
}
- payload, err := json.Marshal(map[string]string{"text": msg.Content})
+ // Build interactive card with markdown content
+ cardContent, err := buildMarkdownCard(msg.Content)
if err != nil {
- return fmt.Errorf("failed to marshal feishu content: %w", err)
+ return fmt.Errorf("feishu send: card build failed: %w", err)
+ }
+ return c.sendCard(ctx, msg.ChatID, cardContent)
+}
+
+// EditMessage implements channels.MessageEditor.
+// Uses Message.Patch to update an interactive card message.
+func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ cardContent, err := buildMarkdownCard(content)
+ if err != nil {
+ return fmt.Errorf("feishu edit: card build failed: %w", err)
+ }
+
+ req := larkim.NewPatchMessageReqBuilder().
+ MessageId(messageID).
+ Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Patch(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu edit: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+ return nil
+}
+
+// SendPlaceholder implements channels.PlaceholderCapable.
+// Sends an interactive card with placeholder text and returns its message ID.
+func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ if !c.config.Placeholder.Enabled {
+ logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
+ "chat_id": chatID,
+ })
+ return "", nil
+ }
+
+ text := c.config.Placeholder.Text
+ if text == "" {
+ text = "Thinking..."
+ }
+
+ cardContent, err := buildMarkdownCard(text)
+ if err != nil {
+ return "", fmt.Errorf("feishu placeholder: card build failed: %w", err)
}
req := larkim.NewCreateMessageReqBuilder().
ReceiveIdType(larkim.ReceiveIdTypeChatId).
Body(larkim.NewCreateMessageReqBodyBuilder().
- ReceiveId(msg.ChatID).
- MsgType(larkim.MsgTypeText).
- Content(string(payload)).
- Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeInteractive).
+ Content(cardContent).
Build()).
Build()
resp, err := c.client.Im.V1.Message.Create(ctx, req)
if err != nil {
- return fmt.Errorf("feishu send: %w", channels.ErrTemporary)
+ return "", fmt.Errorf("feishu placeholder send: %w", err)
}
-
if !resp.Success() {
- return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
+ return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
- logger.DebugCF("feishu", "Feishu message sent", map[string]any{
- "chat_id": msg.ChatID,
- })
+ if resp.Data != nil && resp.Data.MessageId != nil {
+ return *resp.Data.MessageId, nil
+ }
+ return "", nil
+}
+
+// ReactToMessage implements channels.ReactionCapable.
+// Adds an "Pin" reaction and returns an undo function to remove it.
+func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ req := larkim.NewCreateMessageReactionReqBuilder().
+ MessageId(messageID).
+ Body(larkim.NewCreateMessageReactionReqBodyBuilder().
+ ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
+ "message_id": messageID,
+ "error": err.Error(),
+ })
+ return func() {}, fmt.Errorf("feishu react: %w", err)
+ }
+ if !resp.Success() {
+ logger.ErrorCF("feishu", "Reaction API error", map[string]any{
+ "message_id": messageID,
+ "code": resp.Code,
+ "msg": resp.Msg,
+ })
+ return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+
+ var reactionID string
+ if resp.Data != nil && resp.Data.ReactionId != nil {
+ reactionID = *resp.Data.ReactionId
+ }
+ if reactionID == "" {
+ return func() {}, nil
+ }
+
+ var undone atomic.Bool
+ undo := func() {
+ if !undone.CompareAndSwap(false, true) {
+ return
+ }
+ delReq := larkim.NewDeleteMessageReactionReqBuilder().
+ MessageId(messageID).
+ ReactionId(reactionID).
+ Build()
+ _, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq)
+ }
+ return undo, nil
+}
+
+// SendMedia implements channels.MediaSender.
+// Uploads images/files via Feishu API then sends as messages.
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ if msg.ChatID == "" {
+ return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
+ return err
+ }
+ }
return nil
}
+// sendMediaPart resolves and sends a single media part.
+func (c *FeishuChannel) sendMediaPart(
+ ctx context.Context,
+ chatID string,
+ part bus.MediaPart,
+ store media.MediaStore,
+) error {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ return nil // skip this part
+ }
+
+ file, err := os.Open(localPath)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ return nil // skip this part
+ }
+ defer file.Close()
+
+ switch part.Type {
+ case "image":
+ err = c.sendImage(ctx, chatID, file)
+ default:
+ filename := part.Filename
+ if filename == "" {
+ filename = "file"
+ }
+ err = c.sendFile(ctx, chatID, file, filename, part.Type)
+ }
+
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to send media", map[string]any{
+ "type": part.Type,
+ "error": err.Error(),
+ })
+ return fmt.Errorf("feishu send media: %w", channels.ErrTemporary)
+ }
+ return nil
+}
+
+// --- Inbound message handling ---
+
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
if event == nil || event.Event == nil || event.Event.Message == nil {
return nil
@@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
senderID = "unknown"
}
- content := extractFeishuMessageContent(message)
+ messageType := stringValue(message.MessageType)
+ messageID := stringValue(message.MessageId)
+ rawContent := stringValue(message.Content)
+
+ // Check allowlist early to avoid downloading media for rejected senders.
+ // BaseChannel.HandleMessage will check again, but this avoids wasted network I/O.
+ senderInfo := bus.SenderInfo{
+ Platform: "feishu",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("feishu", senderID),
+ }
+ if !c.IsAllowedSender(senderInfo) {
+ return nil
+ }
+
+ // Extract content based on message type
+ content := extractContent(messageType, rawContent)
+
+ // Handle media messages (download and store)
+ var mediaRefs []string
+ if store := c.GetMediaStore(); store != nil && messageID != "" {
+ mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
+ }
+
+ // Append media tags to content (like Telegram does)
+ content = appendMediaTags(content, messageType, mediaRefs)
+
if content == "" {
content = "[empty message]"
}
metadata := map[string]string{}
- messageID := ""
- if mid := stringValue(message.MessageId); mid != "" {
- messageID = mid
+ if messageID != "" {
+ metadata["message_id"] = messageID
}
- if messageType := stringValue(message.MessageType); messageType != "" {
+ if messageType != "" {
metadata["message_type"] = messageType
}
- if chatType := stringValue(message.ChatType); chatType != "" {
+ chatType := stringValue(message.ChatType)
+ if chatType != "" {
metadata["chat_type"] = chatType
}
if sender != nil && sender.TenantKey != nil {
metadata["tenant_key"] = *sender.TenantKey
}
- chatType := stringValue(message.ChatType)
var peer bus.Peer
if chatType == "p2p" {
peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
peer = bus.Peer{Kind: "group", ID: chatID}
+
+ // Check if bot was mentioned
+ isMentioned := c.isBotMentioned(message)
+
+ // Strip mention placeholders from content before group trigger check
+ if len(message.Mentions) > 0 {
+ content = stripMentionPlaceholders(content, message.Mentions)
+ }
+
// In group chats, apply unified group trigger filtering
- respond, cleaned := c.ShouldRespondInGroup(false, content)
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
if !respond {
return nil
}
@@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
}
logger.InfoCF("feishu", "Feishu message received", map[string]any{
- "sender_id": senderID,
- "chat_id": chatID,
- "preview": utils.Truncate(content, 80),
+ "sender_id": senderID,
+ "chat_id": chatID,
+ "message_id": messageID,
+ "preview": utils.Truncate(content, 80),
})
- senderInfo := bus.SenderInfo{
- Platform: "feishu",
- PlatformID: senderID,
- CanonicalID: identity.BuildCanonicalID("feishu", senderID),
+ c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
+ return nil
+}
+
+// --- Internal helpers ---
+
+// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id.
+func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
+ resp, err := c.client.Do(ctx, &larkcore.ApiReq{
+ HttpMethod: http.MethodGet,
+ ApiPath: "/open-apis/bot/v3/info",
+ SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant},
+ })
+ if err != nil {
+ return fmt.Errorf("bot info request: %w", err)
}
- if !c.IsAllowedSender(senderInfo) {
- return nil
+ var result struct {
+ Code int `json:"code"`
+ Bot struct {
+ OpenID string `json:"open_id"`
+ } `json:"bot"`
+ }
+ if err := json.Unmarshal(resp.RawBody, &result); err != nil {
+ return fmt.Errorf("bot info parse: %w", err)
+ }
+ if result.Code != 0 {
+ return fmt.Errorf("bot info api error (code=%d)", result.Code)
+ }
+ if result.Bot.OpenID == "" {
+ return fmt.Errorf("bot info: empty open_id")
}
- c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo)
+ c.botOpenID.Store(result.Bot.OpenID)
+ logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{
+ "open_id": result.Bot.OpenID,
+ })
+ return nil
+}
+
+// isBotMentioned checks if the bot was @mentioned in the message.
+func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool {
+ if message.Mentions == nil {
+ return false
+ }
+
+ knownID, _ := c.botOpenID.Load().(string)
+ if knownID == "" {
+ logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil)
+ return false
+ }
+
+ for _, m := range message.Mentions {
+ if m.Id == nil {
+ continue
+ }
+ if m.Id.OpenId != nil && *m.Id.OpenId == knownID {
+ return true
+ }
+ }
+ return false
+}
+
+// extractContent extracts text content from different message types.
+func extractContent(messageType, rawContent string) string {
+ if rawContent == "" {
+ return ""
+ }
+
+ switch messageType {
+ case larkim.MsgTypeText:
+ var textPayload struct {
+ Text string `json:"text"`
+ }
+ if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil {
+ return textPayload.Text
+ }
+ return rawContent
+
+ case larkim.MsgTypePost:
+ // Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
+ return rawContent
+
+ case larkim.MsgTypeImage:
+ // Image messages don't have text content
+ return ""
+
+ case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
+ // File/audio/video messages may have a filename
+ name := extractFileName(rawContent)
+ if name != "" {
+ return name
+ }
+ return ""
+
+ default:
+ return rawContent
+ }
+}
+
+// downloadInboundMedia downloads media from inbound messages and stores in MediaStore.
+func (c *FeishuChannel) downloadInboundMedia(
+ ctx context.Context,
+ chatID, messageID, messageType, rawContent string,
+ store media.MediaStore,
+) []string {
+ var refs []string
+ scope := channels.BuildMediaScope("feishu", chatID, messageID)
+
+ switch messageType {
+ case larkim.MsgTypeImage:
+ imageKey := extractImageKey(rawContent)
+ if imageKey == "" {
+ return nil
+ }
+ ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
+ if ref != "" {
+ refs = append(refs, ref)
+ }
+
+ case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
+ fileKey := extractFileKey(rawContent)
+ if fileKey == "" {
+ return nil
+ }
+ // Derive a fallback extension from the message type.
+ var ext string
+ switch messageType {
+ case larkim.MsgTypeAudio:
+ ext = ".ogg"
+ case larkim.MsgTypeMedia:
+ ext = ".mp4"
+ default:
+ ext = "" // generic file — rely on resp.FileName
+ }
+ ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope)
+ if ref != "" {
+ refs = append(refs, ref)
+ }
+ }
+
+ return refs
+}
+
+// downloadResource downloads a message resource (image/file) from Feishu,
+// writes it to the project media directory, and stores the reference in MediaStore.
+// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
+func (c *FeishuChannel) downloadResource(
+ ctx context.Context,
+ messageID, fileKey, resourceType, fallbackExt string,
+ store media.MediaStore,
+ scope string,
+) string {
+ req := larkim.NewGetMessageResourceReqBuilder().
+ MessageId(messageID).
+ FileKey(fileKey).
+ Type(resourceType).
+ Build()
+
+ resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
+ "message_id": messageID,
+ "file_key": fileKey,
+ "error": err.Error(),
+ })
+ return ""
+ }
+ if !resp.Success() {
+ logger.ErrorCF("feishu", "Resource download api error", map[string]any{
+ "code": resp.Code,
+ "msg": resp.Msg,
+ })
+ return ""
+ }
+
+ if resp.File == nil {
+ return ""
+ }
+ // Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
+ if closer, ok := resp.File.(io.Closer); ok {
+ defer closer.Close()
+ }
+
+ filename := resp.FileName
+ if filename == "" {
+ filename = fileKey
+ }
+ // If filename still has no extension, append the fallback (like Telegram's ext parameter).
+ if filepath.Ext(filename) == "" && fallbackExt != "" {
+ filename += fallbackExt
+ }
+
+ // Write to the shared picoclaw_media directory using a unique name to avoid collisions.
+ mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
+ if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
+ logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
+ "error": mkdirErr.Error(),
+ })
+ return ""
+ }
+ ext := filepath.Ext(filename)
+ localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext))
+
+ out, err := os.Create(localPath)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{
+ "error": err.Error(),
+ })
+ return ""
+ }
+
+ if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
+ out.Close()
+ os.Remove(localPath)
+ logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
+ "error": copyErr.Error(),
+ })
+ return ""
+ }
+ out.Close()
+
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "feishu",
+ }, scope)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{
+ "file_key": fileKey,
+ "error": err.Error(),
+ })
+ os.Remove(localPath)
+ return ""
+ }
+
+ return ref
+}
+
+// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
+func appendMediaTags(content, messageType string, mediaRefs []string) string {
+ if len(mediaRefs) == 0 {
+ return content
+ }
+
+ var tag string
+ switch messageType {
+ case larkim.MsgTypeImage:
+ tag = "[image: photo]"
+ case larkim.MsgTypeAudio:
+ tag = "[audio]"
+ case larkim.MsgTypeMedia:
+ tag = "[video]"
+ case larkim.MsgTypeFile:
+ tag = "[file]"
+ default:
+ tag = "[attachment]"
+ }
+
+ if content == "" {
+ return tag
+ }
+ return content + " " + tag
+}
+
+// sendCard sends an interactive card message to a chat.
+func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeInteractive).
+ Content(cardContent).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
+ }
+
+ if !resp.Success() {
+ return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
+ }
+
+ logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
+ "chat_id": chatID,
+ })
+
+ return nil
+}
+
+// sendImage uploads an image and sends it as a message.
+func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error {
+ // Upload image to get image_key
+ uploadReq := larkim.NewCreateImageReqBuilder().
+ Body(larkim.NewCreateImageReqBodyBuilder().
+ ImageType("message").
+ Image(file).
+ Build()).
+ Build()
+
+ uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq)
+ if err != nil {
+ return fmt.Errorf("feishu image upload: %w", err)
+ }
+ if !uploadResp.Success() {
+ return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
+ }
+ if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
+ return fmt.Errorf("feishu image upload: no image_key returned")
+ }
+
+ imageKey := *uploadResp.Data.ImageKey
+
+ // Send image message
+ content, _ := json.Marshal(map[string]string{"image_key": imageKey})
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeImage).
+ Content(string(content)).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu image send: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+ return nil
+}
+
+// sendFile uploads a file and sends it as a message.
+func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error {
+ // Map part type to Feishu file type
+ feishuFileType := "stream"
+ switch fileType {
+ case "audio":
+ feishuFileType = "opus"
+ case "video":
+ feishuFileType = "mp4"
+ }
+
+ // Upload file to get file_key
+ uploadReq := larkim.NewCreateFileReqBuilder().
+ Body(larkim.NewCreateFileReqBodyBuilder().
+ FileType(feishuFileType).
+ FileName(filename).
+ File(file).
+ Build()).
+ Build()
+
+ uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq)
+ if err != nil {
+ return fmt.Errorf("feishu file upload: %w", err)
+ }
+ if !uploadResp.Success() {
+ return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
+ }
+ if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
+ return fmt.Errorf("feishu file upload: no file_key returned")
+ }
+
+ fileKey := *uploadResp.Data.FileKey
+
+ // Send file message
+ content, _ := json.Marshal(map[string]string{"file_key": fileKey})
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeFile).
+ Content(string(content)).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu file send: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
return nil
}
@@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
return ""
}
-
-func extractFeishuMessageContent(message *larkim.EventMessage) string {
- if message == nil || message.Content == nil || *message.Content == "" {
- return ""
- }
-
- if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
- var textPayload struct {
- Text string `json:"text"`
- }
- if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
- return textPayload.Text
- }
- }
-
- return *message.Content
-}
diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go
new file mode 100644
index 000000000..dc3eab2e7
--- /dev/null
+++ b/pkg/channels/feishu/feishu_64_test.go
@@ -0,0 +1,256 @@
+//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
+
+package feishu
+
+import (
+ "testing"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+func TestExtractContent(t *testing.T) {
+ tests := []struct {
+ name string
+ messageType string
+ rawContent string
+ want string
+ }{
+ {
+ name: "text message",
+ messageType: "text",
+ rawContent: `{"text": "hello world"}`,
+ want: "hello world",
+ },
+ {
+ name: "text message invalid JSON",
+ messageType: "text",
+ rawContent: `not json`,
+ want: "not json",
+ },
+ {
+ name: "post message returns raw JSON",
+ messageType: "post",
+ rawContent: `{"title": "test post"}`,
+ want: `{"title": "test post"}`,
+ },
+ {
+ name: "image message returns empty",
+ messageType: "image",
+ rawContent: `{"image_key": "img_xxx"}`,
+ want: "",
+ },
+ {
+ name: "file message with filename",
+ messageType: "file",
+ rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
+ want: "report.pdf",
+ },
+ {
+ name: "file message without filename",
+ messageType: "file",
+ rawContent: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "audio message with filename",
+ messageType: "audio",
+ rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`,
+ want: "recording.ogg",
+ },
+ {
+ name: "media message with filename",
+ messageType: "media",
+ rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`,
+ want: "video.mp4",
+ },
+ {
+ name: "unknown message type returns raw",
+ messageType: "sticker",
+ rawContent: `{"sticker_id": "sticker_xxx"}`,
+ want: `{"sticker_id": "sticker_xxx"}`,
+ },
+ {
+ name: "empty raw content",
+ messageType: "text",
+ rawContent: "",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractContent(tt.messageType, tt.rawContent)
+ if got != tt.want {
+ t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAppendMediaTags(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ messageType string
+ mediaRefs []string
+ want string
+ }{
+ {
+ name: "no refs returns content unchanged",
+ content: "hello",
+ messageType: "image",
+ mediaRefs: nil,
+ want: "hello",
+ },
+ {
+ name: "empty refs returns content unchanged",
+ content: "hello",
+ messageType: "image",
+ mediaRefs: []string{},
+ want: "hello",
+ },
+ {
+ name: "image with content",
+ content: "check this",
+ messageType: "image",
+ mediaRefs: []string{"ref1"},
+ want: "check this [image: photo]",
+ },
+ {
+ name: "image empty content",
+ content: "",
+ messageType: "image",
+ mediaRefs: []string{"ref1"},
+ want: "[image: photo]",
+ },
+ {
+ name: "audio",
+ content: "listen",
+ messageType: "audio",
+ mediaRefs: []string{"ref1"},
+ want: "listen [audio]",
+ },
+ {
+ name: "media/video",
+ content: "watch",
+ messageType: "media",
+ mediaRefs: []string{"ref1"},
+ want: "watch [video]",
+ },
+ {
+ name: "file",
+ content: "report.pdf",
+ messageType: "file",
+ mediaRefs: []string{"ref1"},
+ want: "report.pdf [file]",
+ },
+ {
+ name: "unknown type",
+ content: "something",
+ messageType: "sticker",
+ mediaRefs: []string{"ref1"},
+ want: "something [attachment]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs)
+ if got != tt.want {
+ t.Errorf(
+ "appendMediaTags(%q, %q, %v) = %q, want %q",
+ tt.content,
+ tt.messageType,
+ tt.mediaRefs,
+ got,
+ tt.want,
+ )
+ }
+ })
+ }
+}
+
+func TestExtractFeishuSenderID(t *testing.T) {
+ strPtr := func(s string) *string { return &s }
+
+ tests := []struct {
+ name string
+ sender *larkim.EventSender
+ want string
+ }{
+ {
+ name: "nil sender",
+ sender: nil,
+ want: "",
+ },
+ {
+ name: "nil sender ID",
+ sender: &larkim.EventSender{SenderId: nil},
+ want: "",
+ },
+ {
+ name: "userId preferred",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr("u_abc123"),
+ OpenId: strPtr("ou_def456"),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "u_abc123",
+ },
+ {
+ name: "openId fallback",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr("ou_def456"),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "ou_def456",
+ },
+ {
+ name: "unionId fallback",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr(""),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "on_ghi789",
+ },
+ {
+ name: "all empty strings",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr(""),
+ UnionId: strPtr(""),
+ },
+ },
+ want: "",
+ },
+ {
+ name: "nil userId pointer falls through",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: nil,
+ OpenId: strPtr("ou_def456"),
+ UnionId: nil,
+ },
+ },
+ want: "ou_def456",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFeishuSenderID(tt.sender)
+ if got != tt.want {
+ t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go
index 9fac2831c..398f12e6b 100644
--- a/pkg/channels/line/line.go
+++ b/pkg/channels/line/line.go
@@ -45,11 +45,13 @@ type replyTokenEntry struct {
type LINEChannel struct {
*channels.BaseChannel
config config.LINEConfig
- botUserID string // Bot's user ID
- botBasicID string // Bot's basic ID (e.g. @216ru...)
- botDisplayName string // Bot's display name for text-based mention detection
- replyTokens sync.Map // chatID -> replyTokenEntry
- quoteTokens sync.Map // chatID -> quoteToken (string)
+ infoClient *http.Client // for bot info lookups (short timeout)
+ apiClient *http.Client // for messaging API calls
+ botUserID string // Bot's user ID
+ botBasicID string // Bot's basic ID (e.g. @216ru...)
+ botDisplayName string // Bot's display name for text-based mention detection
+ replyTokens sync.Map // chatID -> replyTokenEntry
+ quoteTokens sync.Map // chatID -> quoteToken (string)
ctx context.Context
cancel context.CancelFunc
}
@@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
return &LINEChannel{
BaseChannel: base,
config: cfg,
+ infoClient: &http.Client{Timeout: 10 * time.Second},
+ apiClient: &http.Client{Timeout: 30 * time.Second},
}, nil
}
@@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
}
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
- client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.infoClient.Do(req)
if err != nil {
return err
}
@@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
- client := &http.Client{Timeout: 30 * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.apiClient.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index 155e50b39..fdd6d0c1f 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -255,6 +255,10 @@ func (m *Manager) initChannels() error {
m.initChannel("wecom", "WeCom")
}
+ if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
+ m.initChannel("wecom_aibot", "WeCom AI Bot")
+ }
+
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
m.initChannel("wecom_app", "WeCom App")
}
@@ -539,86 +543,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
})
}
-func (m *Manager) dispatchOutbound(ctx context.Context) {
- logger.InfoC("channels", "Outbound dispatcher started")
+func dispatchLoop[M any](
+ ctx context.Context,
+ m *Manager,
+ subscribe func(context.Context) (M, bool),
+ getChannel func(M) string,
+ enqueue func(context.Context, *channelWorker, M) bool,
+ startMsg, stopMsg, unknownMsg, noWorkerMsg string,
+) {
+ logger.InfoC("channels", startMsg)
for {
- msg, ok := m.bus.SubscribeOutbound(ctx)
+ msg, ok := subscribe(ctx)
if !ok {
- logger.InfoC("channels", "Outbound dispatcher stopped")
+ logger.InfoC("channels", stopMsg)
return
}
+ channel := getChannel(msg)
+
// Silently skip internal channels
- if constants.IsInternalChannel(msg.Channel) {
+ if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
- _, exists := m.channels[msg.Channel]
- w, wExists := m.workers[msg.Channel]
+ _, exists := m.channels[channel]
+ w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
- logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
- "channel": msg.Channel,
- })
+ logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
- select {
- case w.queue <- msg:
- case <-ctx.Done():
+ if !enqueue(ctx, w, msg) {
return
}
} else if exists {
- logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{
- "channel": msg.Channel,
- })
+ logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
+func (m *Manager) dispatchOutbound(ctx context.Context) {
+ dispatchLoop(
+ ctx, m,
+ m.bus.SubscribeOutbound,
+ func(msg bus.OutboundMessage) string { return msg.Channel },
+ func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
+ select {
+ case w.queue <- msg:
+ return true
+ case <-ctx.Done():
+ return false
+ }
+ },
+ "Outbound dispatcher started",
+ "Outbound dispatcher stopped",
+ "Unknown channel for outbound message",
+ "Channel has no active worker, skipping message",
+ )
+}
+
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
- logger.InfoC("channels", "Outbound media dispatcher started")
-
- for {
- msg, ok := m.bus.SubscribeOutboundMedia(ctx)
- if !ok {
- logger.InfoC("channels", "Outbound media dispatcher stopped")
- return
- }
-
- // Silently skip internal channels
- if constants.IsInternalChannel(msg.Channel) {
- continue
- }
-
- m.mu.RLock()
- _, exists := m.channels[msg.Channel]
- w, wExists := m.workers[msg.Channel]
- m.mu.RUnlock()
-
- if !exists {
- logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{
- "channel": msg.Channel,
- })
- continue
- }
-
- if wExists && w != nil {
+ dispatchLoop(
+ ctx, m,
+ m.bus.SubscribeOutboundMedia,
+ func(msg bus.OutboundMediaMessage) string { return msg.Channel },
+ func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
case w.mediaQueue <- msg:
+ return true
case <-ctx.Done():
- return
+ return false
}
- } else if exists {
- logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{
- "channel": msg.Channel,
- })
- }
- }
+ },
+ "Outbound media dispatcher started",
+ "Outbound media dispatcher stopped",
+ "Unknown channel for outbound media message",
+ "Channel has no active worker, skipping media message",
+ )
}
// runMediaWorker processes outbound media messages for a single channel.
diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go
index 6b9f151c3..f09ecfe2f 100644
--- a/pkg/channels/manager_test.go
+++ b/pkg/channels/manager_test.go
@@ -274,13 +274,12 @@ func TestWorkerRateLimiter(t *testing.T) {
limiter: rate.NewLimiter(2, 1),
}
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ ctx := t.Context()
go m.runWorker(ctx, "test", w)
// Enqueue 4 messages
- for i := 0; i < 4; i++ {
+ for i := range 4 {
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
}
@@ -352,8 +351,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
limiter: rate.NewLimiter(rate.Inf, 1),
}
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ ctx := t.Context()
go m.runWorker(ctx, "test", w)
@@ -576,7 +574,7 @@ func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) {
m := newTestManager()
var wg sync.WaitGroup
- for i := 0; i < 100; i++ {
+ for i := range 100 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -591,7 +589,7 @@ func TestRecordTypingStop_ConcurrentSafe(t *testing.T) {
m := newTestManager()
var wg sync.WaitGroup
- for i := 0; i < 100; i++ {
+ for i := range 100 {
wg.Add(1)
go func(i int) {
defer wg.Done()
@@ -834,7 +832,7 @@ func TestLazyWorkerCreation(t *testing.T) {
func TestBuildMediaScope_FastIDUniqueness(t *testing.T) {
seen := make(map[string]bool)
- for i := 0; i < 1000; i++ {
+ for range 1000 {
scope := BuildMediaScope("test", "chat1", "")
if seen[scope] {
t.Fatalf("duplicate scope generated: %s", scope)
diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go
index 89cba4ae0..62a9eb34a 100644
--- a/pkg/channels/onebot/onebot.go
+++ b/pkg/channels/onebot/onebot.go
@@ -337,10 +337,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
}
func (c *OneBotChannel) reconnectLoop() {
- interval := time.Duration(c.config.ReconnectInterval) * time.Second
- if interval < 5*time.Second {
- interval = 5 * time.Second
- }
+ interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
for {
select {
diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go
index 2ae82d8da..8d8b62a67 100644
--- a/pkg/channels/pico/pico.go
+++ b/pkg/channels/pico/pico.go
@@ -292,8 +292,8 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
// Check Authorization header
auth := r.Header.Get("Authorization")
- if strings.HasPrefix(auth, "Bearer ") {
- if strings.TrimPrefix(auth, "Bearer ") == token {
+ if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
+ if after == token {
return true
}
}
diff --git a/pkg/channels/split.go b/pkg/channels/split.go
index 1c951a31f..bb26c6d8f 100644
--- a/pkg/channels/split.go
+++ b/pkg/channels/split.go
@@ -23,10 +23,7 @@ func SplitMessage(content string, maxLen int) []string {
var messages []string
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
- codeBlockBuffer := maxLen / 10
- if codeBlockBuffer < 50 {
- codeBlockBuffer = 50
- }
+ codeBlockBuffer := max(maxLen/10, 50)
if codeBlockBuffer > maxLen/2 {
codeBlockBuffer = maxLen / 2
}
@@ -40,10 +37,7 @@ func SplitMessage(content string, maxLen int) []string {
}
// Effective split point: maxLen minus buffer, to leave room for code blocks
- effectiveLimit := maxLen - codeBlockBuffer
- if effectiveLimit < maxLen/2 {
- effectiveLimit = maxLen / 2
- }
+ effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2)
end := start + effectiveLimit
@@ -85,10 +79,9 @@ func SplitMessage(content string, maxLen int) []string {
// If we have a reasonable amount of content after the header, split inside
if msgEnd > headerEndIdx+20 {
// Find a better split point closer to maxLen
- innerLimit := start + maxLen - 5 // Leave room for "\n```"
- if innerLimit > totalLen {
- innerLimit = totalLen
- }
+ innerLimit := min(
+ // Leave room for "\n```"
+ start+maxLen-5, totalLen)
betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200)
if betterEnd > headerEndIdx {
msgEnd = betterEnd
@@ -117,10 +110,7 @@ func SplitMessage(content string, maxLen int) []string {
if unclosedIdx-start > 20 {
msgEnd = unclosedIdx
} else {
- splitAt := start + maxLen - 5
- if splitAt > totalLen {
- splitAt = totalLen
- }
+ splitAt := min(start+maxLen-5, totalLen)
chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```"
messages = append(messages, chunk)
remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen]))
@@ -196,10 +186,7 @@ func findNewlineFrom(runes []rune, from int) int {
// findLastNewlineInRange finds the last newline within the last searchWindow runes
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
- searchStart := end - searchWindow
- if searchStart < start {
- searchStart = start
- }
+ searchStart := max(end-searchWindow, start)
for i := end - 1; i >= searchStart; i-- {
if runes[i] == '\n' {
return i
@@ -211,10 +198,7 @@ func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
// findLastSpaceInRange finds the last space/tab within the last searchWindow runes
// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int {
- searchStart := end - searchWindow
- if searchStart < start {
- searchStart = start
- }
+ searchStart := max(end-searchWindow, start)
for i := end - 1; i >= searchStart; i-- {
if runes[i] == ' ' || runes[i] == '\t' {
return i
diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go
index a11cf53b8..f328f32b8 100644
--- a/pkg/channels/telegram/telegram.go
+++ b/pkg/channels/telegram/telegram.go
@@ -7,12 +7,12 @@ import (
"net/url"
"os"
"regexp"
+ "slices"
"strconv"
"strings"
"time"
"github.com/mymmrac/telego"
- "github.com/mymmrac/telego/telegohandler"
th "github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
@@ -41,7 +41,7 @@ var (
type TelegramChannel struct {
*channels.BaseChannel
bot *telego.Bot
- bh *telegohandler.BotHandler
+ bh *th.BotHandler
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
@@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
}))
}
+ if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
+ opts = append(opts, telego.WithAPIServer(baseURL))
+ }
+
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
@@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
c.ctx, c.cancel = context.WithCancel(ctx)
+ if err := c.initBotCommands(c.ctx); err != nil {
+ logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
+ "error": err.Error(),
+ })
+ }
+
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
@@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start long polling: %w", err)
}
- bh, err := telegohandler.NewBotHandler(c.bot, updates)
+ bh, err := th.NewBotHandler(c.bot, updates)
if err != nil {
c.cancel()
return fmt.Errorf("failed to create bot handler: %w", err)
}
c.bh = bh
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- c.commands.Help(ctx, message)
- return nil
- }, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.commands.Help(ctx, message)
+ }, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
@@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
"username": c.bot.Username(),
})
- go bh.Start()
+ go func() {
+ if err = bh.Start(); err != nil {
+ logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
return nil
}
@@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
// Stop the bot handler
if c.bh != nil {
- c.bh.Stop()
+ _ = c.bh.StopWithContext(ctx)
}
// Cancel our context (stops long polling)
@@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
return nil
}
+func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
+ currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
+ Scope: tu.ScopeDefault(),
+ })
+ if err != nil {
+ return fmt.Errorf("get commands: %w", err)
+ }
+
+ commands := []telego.BotCommand{
+ {
+ Command: "start",
+ Description: "Start the bot",
+ },
+ {
+ Command: "help",
+ Description: "Show a help message",
+ },
+ {
+ Command: "show",
+ Description: "Show current configuration",
+ },
+ {
+ Command: "list",
+ Description: "List available options",
+ },
+ }
+
+ // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
+ if !slices.Equal(currentCommands, commands) {
+ logger.InfoC("telegram", "Updating bot commands")
+
+ err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
+ Commands: commands,
+ Scope: tu.ScopeDefault(),
+ })
+ if err != nil {
+ return fmt.Errorf("set commands: %w", err)
+ }
+ } else {
+ logger.DebugC("telegram", "Bot commands are up to date")
+ }
+
+ return nil
+}
+
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return channels.ErrNotRunning
diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go
new file mode 100644
index 000000000..6c5aca40b
--- /dev/null
+++ b/pkg/channels/wecom/aibot.go
@@ -0,0 +1,1014 @@
+package wecom
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math/big"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人)
+type WeComAIBotChannel struct {
+ *channels.BaseChannel
+ config config.WeComAIBotConfig
+ ctx context.Context
+ cancel context.CancelFunc
+ streamTasks map[string]*streamTask // streamID -> task (for poll lookups)
+ chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO)
+ taskMu sync.RWMutex
+}
+
+// streamTask represents a streaming task for AI Bot.
+//
+// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written
+// while holding WeComAIBotChannel.taskMu. Immutable fields (StreamID, ChatID,
+// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set
+// once at creation and never modified, so they are safe to read without a lock.
+type streamTask struct {
+ // immutable after creation
+ StreamID string
+ ChatID string // used by Send() to find this task
+ ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once)
+ Question string
+ CreatedTime time.Time
+ Deadline time.Time // ~30s, we close the stream here and switch to response_url
+ answerCh chan string // receives agent reply from Send()
+ ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine
+ cancel context.CancelFunc // call on task removal to cancel ctx
+
+ // mutable — guarded by WeComAIBotChannel.taskMu
+ StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url
+ StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup
+ Finished bool // fully done
+}
+
+// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot
+// Ref: https://developer.work.weixin.qq.com/document/path/100719
+type WeComAIBotMessage struct {
+ MsgID string `json:"msgid"`
+ AIBotID string `json:"aibotid"`
+ ChatID string `json:"chatid"` // only for group chat
+ ChatType string `json:"chattype"` // "single" or "group"
+ From struct {
+ UserID string `json:"userid"`
+ } `json:"from"`
+ ResponseURL string `json:"response_url"` // temporary URL for proactive reply
+ MsgType string `json:"msgtype"`
+ // text message
+ Text *struct {
+ Content string `json:"content"`
+ } `json:"text,omitempty"`
+ // stream polling refresh
+ Stream *struct {
+ ID string `json:"id"`
+ } `json:"stream,omitempty"`
+ // image message
+ Image *struct {
+ URL string `json:"url"`
+ } `json:"image,omitempty"`
+ // mixed message (text + image)
+ Mixed *struct {
+ MsgItem []struct {
+ MsgType string `json:"msgtype"`
+ Text *struct {
+ Content string `json:"content"`
+ } `json:"text,omitempty"`
+ Image *struct {
+ URL string `json:"url"`
+ } `json:"image,omitempty"`
+ } `json:"msg_item"`
+ } `json:"mixed,omitempty"`
+ // event field
+ Event *struct {
+ EventType string `json:"eventtype"`
+ } `json:"event,omitempty"`
+}
+
+// WeComAIBotMsgItemImage holds the image payload inside a stream message item.
+type WeComAIBotMsgItemImage struct {
+ Base64 string `json:"base64"`
+ MD5 string `json:"md5"`
+}
+
+// WeComAIBotMsgItem is a single item inside a stream's msg_item list.
+type WeComAIBotMsgItem struct {
+ MsgType string `json:"msgtype"`
+ Image *WeComAIBotMsgItemImage `json:"image,omitempty"`
+}
+
+// WeComAIBotStreamInfo represents the detailed stream content in streaming responses.
+type WeComAIBotStreamInfo struct {
+ ID string `json:"id"`
+ Finish bool `json:"finish"`
+ Content string `json:"content,omitempty"`
+ MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"`
+}
+
+// WeComAIBotStreamResponse represents the streaming response format
+type WeComAIBotStreamResponse struct {
+ MsgType string `json:"msgtype"`
+ Stream WeComAIBotStreamInfo `json:"stream"`
+}
+
+// WeComAIBotEncryptedResponse represents the encrypted response wrapper
+// Fields match WXBizJsonMsgCrypt.generate() in Python SDK
+type WeComAIBotEncryptedResponse struct {
+ Encrypt string `json:"encrypt"`
+ MsgSignature string `json:"msgsignature"`
+ Timestamp string `json:"timestamp"`
+ Nonce string `json:"nonce"`
+}
+
+// NewWeComAIBotChannel creates a new WeCom AI Bot channel instance
+func NewWeComAIBotChannel(
+ cfg config.WeComAIBotConfig,
+ messageBus *bus.MessageBus,
+) (*WeComAIBotChannel, error) {
+ if cfg.Token == "" || cfg.EncodingAESKey == "" {
+ return nil, fmt.Errorf("token and encoding_aes_key are required for WeCom AI Bot")
+ }
+
+ base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(2048),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+
+ return &WeComAIBotChannel{
+ BaseChannel: base,
+ config: cfg,
+ streamTasks: make(map[string]*streamTask),
+ chatTasks: make(map[string][]*streamTask),
+ }, nil
+}
+
+// Name returns the channel name
+func (c *WeComAIBotChannel) Name() string {
+ return "wecom_aibot"
+}
+
+// Start initializes the WeCom AI Bot channel
+func (c *WeComAIBotChannel) Start(ctx context.Context) error {
+ logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // Start cleanup goroutine for old tasks
+ go c.cleanupLoop()
+
+ c.SetRunning(true)
+ logger.InfoC("wecom_aibot", "WeCom AI Bot channel started")
+
+ return nil
+}
+
+// Stop gracefully stops the WeCom AI Bot channel
+func (c *WeComAIBotChannel) Stop(ctx context.Context) error {
+ logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...")
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ c.SetRunning(false)
+ logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped")
+ return nil
+}
+
+// Send delivers the agent reply into the active streamTask for msg.ChatID.
+// It writes into the earliest unfinished task in the queue (FIFO per chatID).
+// If the stream has already closed (deadline passed), it posts directly to response_url.
+func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+ c.taskMu.Lock()
+ queue := c.chatTasks[msg.ChatID]
+ // Only compact Finished tasks at the head of the queue.
+ // Tasks that are Finished in the middle are NOT removed here: doing a full
+ // scan on every Send() call would be O(n) and is unnecessary given that
+ // removeTask() always splices the task out of the queue immediately.
+ // Any Finished task left stranded in the middle (e.g. due to an unexpected
+ // code path) will be collected by cleanupOldTasks.
+ for len(queue) > 0 && queue[0].Finished {
+ queue = queue[1:]
+ }
+ c.chatTasks[msg.ChatID] = queue
+ var task *streamTask
+ var streamClosed bool
+ var responseURL string
+ if len(queue) > 0 {
+ task = queue[0]
+ // Read mutable fields while holding c.taskMu to avoid data races.
+ streamClosed = task.StreamClosed
+ responseURL = task.ResponseURL
+ }
+ c.taskMu.Unlock()
+
+ if task == nil {
+ logger.DebugCF(
+ "wecom_aibot",
+ "Send: no active task for chat (may have timed out)",
+ map[string]any{
+ "chat_id": msg.ChatID,
+ },
+ )
+ return nil
+ }
+
+ if streamClosed {
+ // Stream already ended with a "please wait" notice; send the real reply via response_url.
+ // Note: task.StreamID and task.ChatID are immutable, safe to read without a lock.
+ logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{
+ "stream_id": task.StreamID,
+ "chat_id": msg.ChatID,
+ })
+ if responseURL != "" {
+ if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{
+ "error": err,
+ "stream_id": task.StreamID,
+ })
+ c.removeTask(task)
+ return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed)
+ }
+ } else {
+ logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{
+ "stream_id": task.StreamID,
+ })
+ }
+ c.removeTask(task)
+ return nil
+ }
+
+ // Stream still open: deliver via answerCh for the next poll response.
+ select {
+ case task.answerCh <- msg.Content:
+ case <-task.ctx.Done():
+ // Task was canceled (cleanup removed it); silently drop the reply.
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ return nil
+}
+
+// WebhookPath returns the path for registering on the shared HTTP server
+func (c *WeComAIBotChannel) WebhookPath() string {
+ if c.config.WebhookPath == "" {
+ return "/webhook/wecom-aibot"
+ }
+ return c.config.WebhookPath
+}
+
+// ServeHTTP implements http.Handler for the shared HTTP server
+func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ c.handleWebhook(w, r)
+}
+
+// HealthPath returns the health check endpoint path
+func (c *WeComAIBotChannel) HealthPath() string {
+ return c.WebhookPath() + "/health"
+}
+
+// HealthHandler handles health check requests
+func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ c.handleHealth(w, r)
+}
+
+// handleWebhook handles incoming webhook requests from WeCom AI Bot
+func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+
+ // Log all incoming requests for debugging
+ logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{
+ "method": r.Method,
+ "path": r.URL.Path,
+ "query": r.URL.RawQuery,
+ })
+
+ switch r.Method {
+ case http.MethodGet:
+ // URL verification
+ c.handleVerification(ctx, w, r)
+ case http.MethodPost:
+ // Message callback
+ c.handleMessageCallback(ctx, w, r)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+// handleVerification handles the URL verification request from WeCom
+func (c *WeComAIBotChannel) handleVerification(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+) {
+ msgSignature := r.URL.Query().Get("msg_signature")
+ timestamp := r.URL.Query().Get("timestamp")
+ nonce := r.URL.Query().Get("nonce")
+ echostr := r.URL.Query().Get("echostr")
+
+ logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{
+ "msg_signature": msgSignature,
+ "timestamp": timestamp,
+ "nonce": nonce,
+ })
+
+ // Verify signature
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ logger.ErrorC("wecom_aibot", "Signature verification failed")
+ http.Error(w, "Signature verification failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Decrypt echostr
+ // For WeCom AI Bot (智能机器人), receiveid should be empty string
+ decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Remove BOM and whitespace as per WeCom documentation
+ decrypted = strings.TrimPrefix(decrypted, "\ufeff")
+ decrypted = strings.TrimSpace(decrypted)
+
+ logger.InfoC("wecom_aibot", "URL verification successful")
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(decrypted))
+}
+
+// handleMessageCallback handles incoming messages from WeCom AI Bot
+func (c *WeComAIBotChannel) handleMessageCallback(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+) {
+ msgSignature := r.URL.Query().Get("msg_signature")
+ timestamp := r.URL.Query().Get("timestamp")
+ nonce := r.URL.Query().Get("nonce")
+
+ // Read request body (limit to 4 MB to prevent memory exhaustion)
+ const maxBodySize = 4 << 20 // 4 MB
+ body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1))
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Failed to read body", http.StatusBadRequest)
+ return
+ }
+ if len(body) > maxBodySize {
+ http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
+ return
+ }
+
+ // Parse JSON body to get encrypted message
+ // Format: {"encrypt": "base64_encrypted_string"}
+ var encryptedMsg struct {
+ Encrypt string `json:"encrypt"`
+ }
+ if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{
+ "error": unmarshalErr,
+ "body": string(body),
+ })
+ http.Error(w, "Failed to parse JSON", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ logger.ErrorC("wecom_aibot", "Signature verification failed")
+ http.Error(w, "Signature verification failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Decrypt message
+ // For WeCom AI Bot (智能机器人), receiveid is empty string
+ decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Parse decrypted JSON message
+ var msg WeComAIBotMessage
+ if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{
+ "error": unmarshalErr,
+ "decrypted": decrypted,
+ })
+ http.Error(w, "Failed to parse message", http.StatusInternalServerError)
+ return
+ }
+
+ logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+
+ // Process the message and get streaming response
+ response := c.processMessage(ctx, msg, timestamp, nonce)
+
+ // Check if response is empty (e.g. due to unsupported message type)
+ if response == "" {
+ response = c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ // Return encrypted JSON response
+ w.Header().Set("Content-Type", "application/json; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(response))
+}
+
+// processMessage processes the received message and returns encrypted response
+func (c *WeComAIBotChannel) processMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.DebugCF("wecom_aibot", "Processing message", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+
+ switch msg.MsgType {
+ case "text":
+ return c.handleTextMessage(ctx, msg, timestamp, nonce)
+ case "stream":
+ return c.handleStreamMessage(ctx, msg, timestamp, nonce)
+ case "image":
+ return c.handleImageMessage(ctx, msg, timestamp, nonce)
+ case "mixed":
+ return c.handleMixedMessage(ctx, msg, timestamp, nonce)
+ case "event":
+ return c.handleEventMessage(ctx, msg, timestamp, nonce)
+ default:
+ logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: "Unsupported message type: " + msg.MsgType,
+ },
+ })
+ }
+}
+
+// handleTextMessage handles text messages by starting a new streaming task
+func (c *WeComAIBotChannel) handleTextMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ if msg.Text == nil {
+ logger.ErrorC("wecom_aibot", "text message missing text field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ content := msg.Text.Content
+ userID := msg.From.UserID
+ if userID == "" {
+ userID = "unknown"
+ }
+
+ // chatID: group chat uses chatid, single chat uses userid
+ chatID := msg.ChatID
+ if chatID == "" {
+ chatID = userID
+ }
+
+ streamID := c.generateStreamID()
+
+ // WeCom stops sending stream-refresh callbacks after 6 minutes.
+ // Set a slightly shorter deadline so we can send a timeout notice before it gives up.
+ deadline := time.Now().Add(30 * time.Second)
+
+ // Each task gets its own context derived from the channel lifetime context.
+ // Canceling taskCancel interrupts the agent goroutine when the task is removed.
+ taskCtx, taskCancel := context.WithCancel(c.ctx)
+
+ task := &streamTask{
+ StreamID: streamID,
+ ChatID: chatID,
+ ResponseURL: msg.ResponseURL,
+ Question: content,
+ CreatedTime: time.Now(),
+ Deadline: deadline,
+ Finished: false,
+ answerCh: make(chan string, 1),
+ ctx: taskCtx,
+ cancel: taskCancel,
+ }
+
+ c.taskMu.Lock()
+ c.streamTasks[streamID] = task
+ c.chatTasks[chatID] = append(c.chatTasks[chatID], task)
+ c.taskMu.Unlock()
+
+ // Publish to agent asynchronously; agent will call Send() with reply.
+ // Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed.
+ go func() {
+ sender := bus.SenderInfo{
+ Platform: "wecom_aibot",
+ PlatformID: userID,
+ CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID),
+ DisplayName: userID,
+ }
+ peerKind := "direct"
+ if msg.ChatType == "group" {
+ peerKind = "group"
+ }
+ peer := bus.Peer{Kind: peerKind, ID: chatID}
+ metadata := map[string]string{
+ "channel": "wecom_aibot",
+ "chat_type": msg.ChatType,
+ "msg_type": "text",
+ "msgid": msg.MsgID,
+ "aibotid": msg.AIBotID,
+ "stream_id": streamID,
+ "response_url": msg.ResponseURL,
+ }
+ c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID,
+ content, nil, metadata, sender)
+ }()
+
+ // Return first streaming response immediately (finish=false, content empty)
+ return c.getStreamResponse(task, timestamp, nonce)
+}
+
+// handleStreamMessage handles stream polling requests
+func (c *WeComAIBotChannel) handleStreamMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ if msg.Stream == nil {
+ logger.ErrorC("wecom_aibot", "Stream message missing stream field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ streamID := msg.Stream.ID
+
+ c.taskMu.RLock()
+ task, exists := c.streamTasks[streamID]
+ c.taskMu.RUnlock()
+
+ if !exists {
+ logger.DebugCF(
+ "wecom_aibot",
+ "Stream task not found (may be from previous session)",
+ map[string]any{
+ "stream_id": streamID,
+ },
+ )
+ return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: streamID,
+ Finish: true,
+ Content: "Task not found or already finished. Please resend your message to start a new session.",
+ },
+ })
+ }
+
+ // Get next response
+ return c.getStreamResponse(task, timestamp, nonce)
+}
+
+// handleImageMessage handles image messages
+func (c *WeComAIBotChannel) handleImageMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.WarnC("wecom_aibot", "Image message type not yet fully implemented")
+ if msg.Image == nil {
+ logger.ErrorC("wecom_aibot", "Image message missing image field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ imageURL := msg.Image.URL
+
+ // For now, just acknowledge receipt without echoing the image
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: fmt.Sprintf(
+ "Image received (URL: %s), but image messages are not yet supported",
+ imageURL,
+ ),
+ },
+ })
+}
+
+// handleMixedMessage handles mixed (text + image) messages
+func (c *WeComAIBotChannel) handleMixedMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented")
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: "Mixed message type is not yet supported",
+ },
+ })
+}
+
+// handleEventMessage handles event messages
+func (c *WeComAIBotChannel) handleEventMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ eventType := ""
+ if msg.Event != nil {
+ eventType = msg.Event.EventType
+ }
+ logger.DebugCF("wecom_aibot", "Received event", map[string]any{
+ "event_type": eventType,
+ })
+
+ // Send welcome message when user opens the chat window
+ if eventType == "enter_chat" && c.config.WelcomeMessage != "" {
+ streamID := c.generateStreamID()
+ return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: streamID,
+ Finish: true,
+ Content: c.config.WelcomeMessage,
+ },
+ })
+ }
+
+ return c.encryptEmptyResponse(timestamp, nonce)
+}
+
+// getStreamResponse gets the next streaming response for a task.
+// - If agent replied: return finish=true with the real answer.
+// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url.
+// - Otherwise: return finish=false (empty), client will poll again.
+func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string {
+ var content string
+ var finish bool
+ var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending)
+
+ select {
+ case answer := <-task.answerCh:
+ // Agent replied before deadline — normal finish.
+ content = answer
+ finish = true
+ default:
+ if time.Now().After(task.Deadline) {
+ // Deadline reached: close the stream with a notice, then wait for agent via response_url.
+ content = "⏳ Processing, please wait. The results will be sent shortly."
+ finish = true
+ closeStreamOnly = true
+ logger.InfoCF(
+ "wecom_aibot",
+ "Stream deadline reached, switching to response_url mode",
+ map[string]any{
+ "stream_id": task.StreamID,
+ "chat_id": task.ChatID,
+ "response_url": task.ResponseURL != "",
+ },
+ )
+ }
+ // else: still waiting, return finish=false
+ }
+
+ if finish && !closeStreamOnly {
+ // Normal finish: remove from all maps.
+ c.removeTask(task)
+ } else if closeStreamOnly {
+ // Mark stream as closed and remove from streamTasks under a single lock
+ // to keep StreamClosed/StreamClosedAt consistent with map membership.
+ c.taskMu.Lock()
+ task.StreamClosed = true
+ task.StreamClosedAt = time.Now()
+ delete(c.streamTasks, task.StreamID)
+ c.taskMu.Unlock()
+ }
+
+ response := WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: task.StreamID,
+ Finish: finish,
+ Content: content,
+ },
+ }
+
+ return c.encryptResponse(task.StreamID, timestamp, nonce, response)
+}
+
+// removeTask removes a task from both streamTasks and chatTasks, marks it finished,
+// and cancels its context to interrupt the associated agent goroutine.
+func (c *WeComAIBotChannel) removeTask(task *streamTask) {
+ // Cancel first so the agent goroutine stops as soon as possible,
+ // before we acquire the write lock.
+ task.cancel()
+
+ c.taskMu.Lock()
+ task.Finished = true // written under c.taskMu, consistent with all readers
+ delete(c.streamTasks, task.StreamID)
+ queue := c.chatTasks[task.ChatID]
+ for i, t := range queue {
+ if t == task {
+ c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
+ break
+ }
+ }
+ if len(c.chatTasks[task.ChatID]) == 0 {
+ delete(c.chatTasks, task.ChatID)
+ }
+ c.taskMu.Unlock()
+}
+
+// sendViaResponseURL posts a markdown reply to the WeCom response_url.
+// response_url is valid for 1 hour and can only be used once per callback.
+// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary,
+// or channels.ErrSendFailed so the manager can apply the right retry policy.
+func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error {
+ payload := map[string]any{
+ "msgtype": "markdown",
+ "markdown": map[string]string{
+ "content": content,
+ },
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal payload: %w", err)
+ }
+
+ ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+
+ client := &http.Client{Timeout: 15 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode == http.StatusOK {
+ return nil
+ }
+
+ respBody, _ := io.ReadAll(resp.Body)
+ switch {
+ case resp.StatusCode == http.StatusTooManyRequests:
+ return fmt.Errorf("response_url rate limited (%d): %s: %w",
+ resp.StatusCode, respBody, channels.ErrRateLimit)
+ case resp.StatusCode >= 500:
+ return fmt.Errorf("response_url server error (%d): %s: %w",
+ resp.StatusCode, respBody, channels.ErrTemporary)
+ default:
+ return fmt.Errorf("response_url returned %d: %s: %w",
+ resp.StatusCode, respBody, channels.ErrSendFailed)
+ }
+}
+
+// encryptResponse encrypts a streaming response
+func (c *WeComAIBotChannel) encryptResponse(
+ streamID, timestamp, nonce string,
+ response WeComAIBotStreamResponse,
+) string {
+ // Marshal response to JSON
+ plaintext, err := json.Marshal(response)
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{
+ "stream_id": streamID,
+ "finish": response.Stream.Finish,
+ "preview": utils.Truncate(response.Stream.Content, 100),
+ })
+
+ // Encrypt message
+ encrypted, err := c.encryptMessage(string(plaintext), "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ // Generate signature
+ signature := computeSignature(c.config.Token, timestamp, nonce, encrypted)
+
+ // Build encrypted response
+ encryptedResp := WeComAIBotEncryptedResponse{
+ Encrypt: encrypted,
+ MsgSignature: signature,
+ Timestamp: timestamp,
+ Nonce: nonce,
+ }
+
+ respJSON, err := json.Marshal(encryptedResp)
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{
+ "stream_id": streamID,
+ })
+
+ return string(respJSON)
+}
+
+// encryptEmptyResponse returns a minimal valid encrypted response
+func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string {
+ // Construct a zero-value stream response and encrypt it so that
+ // WeCom always receives a syntactically valid encrypted JSON object.
+ emptyResp := WeComAIBotStreamResponse{}
+ return c.encryptResponse("", timestamp, nonce, emptyResp)
+}
+
+// encryptMessage encrypts a plain text message for WeCom AI Bot
+func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) {
+ aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
+ if err != nil {
+ return "", err
+ }
+
+ frame, err := packWeComFrame(plaintext, receiveid)
+ if err != nil {
+ return "", err
+ }
+
+ // PKCS7 padding then AES-CBC encrypt
+ paddedFrame := pkcs7Pad(frame, blockSize)
+ ciphertext, err := encryptAESCBC(aesKey, paddedFrame)
+ if err != nil {
+ return "", err
+ }
+
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+// generateStreamID generates a random stream ID
+func (c *WeComAIBotChannel) generateStreamID() string {
+ const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ b := make([]byte, 10)
+ for i := range b {
+ n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
+ b[i] = letters[n.Int64()]
+ }
+ return string(b)
+}
+
+// cleanupLoop periodically cleans up old streaming tasks
+func (c *WeComAIBotChannel) cleanupLoop() {
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ c.cleanupOldTasks()
+ case <-c.ctx.Done():
+ return
+ }
+ }
+}
+
+// cleanupOldTasks removes tasks that have exceeded their expected lifetime:
+// - Active tasks (in streamTasks): cleaned up after 1 hour (response_url validity window).
+// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod.
+// These tasks are waiting for the agent to call Send() via response_url. If the agent
+// crashes or times out without calling Send(), we must not let them accumulate indefinitely.
+// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour,
+// preventing chatTasks from filling up when many requests time out in quick succession.
+const (
+ streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes
+ taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity)
+)
+
+func (c *WeComAIBotChannel) cleanupOldTasks() {
+ c.taskMu.Lock()
+ defer c.taskMu.Unlock()
+
+ now := time.Now()
+ cutoff := now.Add(-taskMaxLifetime)
+ for id, task := range c.streamTasks {
+ if task.CreatedTime.Before(cutoff) {
+ delete(c.streamTasks, id)
+ task.cancel() // interrupt agent goroutine still waiting for LLM
+ queue := c.chatTasks[task.ChatID]
+ for i, t := range queue {
+ if t == task {
+ c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
+ break
+ }
+ }
+ if len(c.chatTasks[task.ChatID]) == 0 {
+ delete(c.chatTasks, task.ChatID)
+ }
+ logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{
+ "stream_id": id,
+ })
+ }
+ }
+ // Clean up StreamClosed tasks from chatTasks.
+ // Two expiry conditions are checked:
+ // 1. Absolute expiry: task was created more than taskMaxLifetime ago.
+ // 2. Grace expiry: stream closed more than streamClosedGracePeriod ago
+ // (agent had enough time to reply; it is not coming back).
+ for chatID, queue := range c.chatTasks {
+ filtered := queue[:0]
+ for i, t := range queue {
+ absoluteExpired := t.CreatedTime.Before(cutoff)
+ graceExpired := t.StreamClosed &&
+ !t.StreamClosedAt.IsZero() &&
+ t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod))
+ if t.Finished {
+ // Finished tasks should have been removed by removeTask().
+ // Finding one here (especially not at position 0) means an
+ // unexpected code path left it stranded, causing the queue to
+ // grow silently. Log a warning so it is visible, then drop it.
+ if i > 0 {
+ logger.WarnCF("wecom_aibot",
+ "Found stranded Finished task in the middle of chatTasks queue; "+
+ "this should not happen — removeTask() should have spliced it out",
+ map[string]any{
+ "chat_id": chatID,
+ "stream_id": t.StreamID,
+ "position": i,
+ })
+ }
+ // The task is already finished; its context was already canceled
+ // by removeTask(), so no further action is required.
+ continue
+ } else if !absoluteExpired && !graceExpired {
+ filtered = append(filtered, t)
+ } else {
+ t.cancel() // cancel any lingering agent goroutine
+ }
+ }
+ if len(filtered) == 0 {
+ delete(c.chatTasks, chatID)
+ } else {
+ c.chatTasks[chatID] = filtered
+ }
+ }
+}
+
+// handleHealth handles health check requests
+func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
+ status := "ok"
+ if !c.IsRunning() {
+ status = "not running"
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": status,
+ })
+}
diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go
new file mode 100644
index 000000000..6f0664187
--- /dev/null
+++ b/pkg/channels/wecom/aibot_test.go
@@ -0,0 +1,210 @@
+package wecom
+
+import (
+ "context"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestNewWeComAIBotChannel(t *testing.T) {
+ t.Run("success with valid config", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ WebhookPath: "/webhook/test",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, err := NewWeComAIBotChannel(cfg, messageBus)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ if ch == nil {
+ t.Fatal("Expected channel to be created")
+ }
+
+ if ch.Name() != "wecom_aibot" {
+ t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
+ }
+ })
+
+ t.Run("error with missing token", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ _, err := NewWeComAIBotChannel(cfg, messageBus)
+
+ if err == nil {
+ t.Fatal("Expected error for missing token, got nil")
+ }
+ })
+
+ t.Run("error with missing encoding key", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ }
+
+ messageBus := bus.NewMessageBus()
+ _, err := NewWeComAIBotChannel(cfg, messageBus)
+
+ if err == nil {
+ t.Fatal("Expected error for missing encoding key, got nil")
+ }
+ })
+}
+
+func TestWeComAIBotChannelStartStop(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, err := NewWeComAIBotChannel(cfg, messageBus)
+ if err != nil {
+ t.Fatalf("Failed to create channel: %v", err)
+ }
+
+ ctx := context.Background()
+
+ // Test Start
+ if err := ch.Start(ctx); err != nil {
+ t.Fatalf("Failed to start channel: %v", err)
+ }
+
+ if !ch.IsRunning() {
+ t.Error("Expected channel to be running")
+ }
+
+ // Test Stop
+ if err := ch.Stop(ctx); err != nil {
+ t.Fatalf("Failed to stop channel: %v", err)
+ }
+
+ if ch.IsRunning() {
+ t.Error("Expected channel to be stopped")
+ }
+}
+
+func TestWeComAIBotChannelWebhookPath(t *testing.T) {
+ t.Run("default path", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ expectedPath := "/webhook/wecom-aibot"
+ if ch.WebhookPath() != expectedPath {
+ t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath())
+ }
+ })
+
+ t.Run("custom path", func(t *testing.T) {
+ customPath := "/custom/webhook"
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ WebhookPath: customPath,
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ if ch.WebhookPath() != customPath {
+ t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath())
+ }
+ })
+}
+
+func TestGenerateStreamID(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ // Generate multiple IDs and check they are unique
+ ids := make(map[string]bool)
+ for i := 0; i < 100; i++ {
+ id := ch.generateStreamID()
+
+ if len(id) != 10 {
+ t.Errorf("Expected stream ID length 10, got %d", len(id))
+ }
+
+ if ids[id] {
+ t.Errorf("Duplicate stream ID generated: %s", id)
+ }
+ ids[id] = true
+ }
+}
+
+func TestEncryptDecrypt(t *testing.T) {
+ // Use a valid 43-character base64 key (企业微信标准格式)
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ plaintext := "Hello, World!"
+ receiveid := ""
+
+ // Encrypt
+ encrypted, err := ch.encryptMessage(plaintext, receiveid)
+ if err != nil {
+ t.Fatalf("Failed to encrypt message: %v", err)
+ }
+
+ if encrypted == "" {
+ t.Fatal("Encrypted message is empty")
+ }
+
+ // Decrypt
+ decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
+ if err != nil {
+ t.Fatalf("Failed to decrypt message: %v", err)
+ }
+
+ if decrypted != plaintext {
+ t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
+ }
+}
+
+func TestGenerateSignature(t *testing.T) {
+ token := "test_token"
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ encrypt := "encrypted_msg"
+
+ signature := computeSignature(token, timestamp, nonce, encrypt)
+
+ if signature == "" {
+ t.Error("Generated signature is empty")
+ }
+
+ // Verify signature using verifySignature function
+ if !verifySignature(token, signature, timestamp, nonce, encrypt) {
+ t.Error("Generated signature does not verify correctly")
+ }
+}
diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go
index 771603f3e..717815b9f 100644
--- a/pkg/channels/wecom/app.go
+++ b/pkg/channels/wecom/app.go
@@ -32,13 +32,13 @@ const (
type WeComAppChannel struct {
*channels.BaseChannel
config config.WeComAppConfig
+ client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
- processedMsgs map[string]bool // Message deduplication: msg_id -> processed
- msgMu sync.RWMutex
+ processedMsgs *MessageDeduplicator
}
// WeComXMLMessage represents the XML message structure from WeCom
@@ -129,13 +129,21 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
+ // Client timeout must be >= the configured ReplyTimeout so the
+ // per-request context deadline is always the effective limit.
+ clientTimeout := 30 * time.Second
+ if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
+ clientTimeout = d
+ }
+
ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
+ client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
- processedMsgs: make(map[string]bool),
+ processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -148,6 +156,10 @@ func (c *WeComAppChannel) Name() string {
func (c *WeComAppChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_app", "Starting WeCom App channel...")
+ // Cancel the context created in the constructor to avoid a resource leak.
+ if c.cancel != nil {
+ c.cancel()
+ }
c.ctx, c.cancel = context.WithCancel(ctx)
// Get initial access token
@@ -302,8 +314,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
}
req.Header.Set("Content-Type", writer.FormDataContentType())
- client := &http.Client{Timeout: 30 * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.client.Do(req)
if err != nil {
return "", channels.ClassifyNetError(err)
}
@@ -330,18 +341,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp
return result.MediaID, nil
}
-// sendImageMessage sends an image message using a media_id.
-func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
+// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
+func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
- msg := WeComImageMessage{
- ToUser: userID,
- MsgType: "image",
- AgentID: c.config.AgentID,
- }
- msg.Image.MediaID = mediaID
-
- jsonData, err := json.Marshal(msg)
+ jsonData, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
@@ -360,8 +364,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
}
req.Header.Set("Content-Type", "application/json")
- client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
@@ -389,6 +392,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use
return nil
}
+// sendImageMessage sends an image message using a media_id.
+func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
+ msg := WeComImageMessage{
+ ToUser: userID,
+ MsgType: "image",
+ AgentID: c.config.AgentID,
+ }
+ msg.Image.MediaID = mediaID
+ return c.sendWeComMessage(ctx, accessToken, msg)
+}
+
// WebhookPath returns the path for registering on the shared HTTP server.
func (c *WeComAppChannel) WebhookPath() string {
if c.config.WebhookPath != "" {
@@ -592,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
// Message deduplication: Use msg_id to prevent duplicate processing
// As per WeCom documentation, use msg_id for deduplication
msgID := fmt.Sprintf("%d", msg.MsgId)
- c.msgMu.Lock()
- if c.processedMsgs[msgID] {
- c.msgMu.Unlock()
+ if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
- c.processedMsgs[msgID] = true
- c.msgMu.Unlock()
-
- // Clean up old messages periodically (keep last 1000)
- if len(c.processedMsgs) > 1000 {
- c.msgMu.Lock()
- c.processedMsgs = make(map[string]bool)
- c.msgMu.Unlock()
- }
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
@@ -711,64 +714,15 @@ func (c *WeComAppChannel) getAccessToken() string {
return c.accessToken
}
-// sendTextMessage sends a text message to a user
+// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
- apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
-
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
-
- jsonData, err := json.Marshal(msg)
- if err != nil {
- return fmt.Errorf("failed to marshal message: %w", err)
- }
-
- // Use configurable timeout (default 5 seconds)
- timeout := c.config.ReplyTimeout
- if timeout <= 0 {
- timeout = 5
- }
-
- reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
- resp, err := client.Do(req)
- if err != nil {
- return channels.ClassifyNetError(err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body)))
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("failed to read response: %w", err)
- }
-
- var sendResp WeComSendMessageResponse
- if err := json.Unmarshal(body, &sendResp); err != nil {
- return fmt.Errorf("failed to parse response: %w", err)
- }
-
- if sendResp.ErrCode != 0 {
- return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
- }
-
- return nil
+ return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go
index 5420949de..7f230494f 100644
--- a/pkg/channels/wecom/app_test.go
+++ b/pkg/channels/wecom/app_test.go
@@ -43,7 +43,7 @@ func encryptTestMessageApp(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + corp_id
random := make([]byte, 0, 16)
- for i := 0; i < 16; i++ {
+ for i := range 16 {
random = append(random, byte(i+1))
}
@@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) {
})
}
-func TestWeComAppPKCS7Unpad(t *testing.T) {
- tests := []struct {
- name string
- input []byte
- expected []byte
- }{
- {
- name: "empty input",
- input: []byte{},
- expected: []byte{},
- },
- {
- name: "valid padding 3 bytes",
- input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
- expected: []byte("hello"),
- },
- {
- name: "valid padding 16 bytes (full block)",
- input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
- expected: []byte("123456789012345"),
- },
- {
- name: "invalid padding larger than data",
- input: []byte{20},
- expected: nil, // should return error
- },
- {
- name: "invalid padding zero",
- input: append([]byte("test"), byte(0)),
- expected: nil, // should return error
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := pkcs7Unpad(tt.input)
- if tt.expected == nil {
- // This case should return an error
- if err == nil {
- t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
- }
- return
- }
- if err != nil {
- t.Errorf("pkcs7Unpad() unexpected error: %v", err)
- return
- }
- if !bytes.Equal(result, tt.expected) {
- t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
- }
- })
- }
-}
-
func TestWeComAppHandleVerification(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKeyApp()
diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go
index e99c710ef..9126a847d 100644
--- a/pkg/channels/wecom/bot.go
+++ b/pkg/channels/wecom/bot.go
@@ -9,7 +9,6 @@ import (
"io"
"net/http"
"strings"
- "sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -25,10 +24,10 @@ import (
type WeComBotChannel struct {
*channels.BaseChannel
config config.WeComConfig
+ client *http.Client
ctx context.Context
cancel context.CancelFunc
- processedMsgs map[string]bool // Message deduplication: msg_id -> processed
- msgMu sync.RWMutex
+ processedMsgs *MessageDeduplicator
}
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
@@ -93,13 +92,21 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
)
+ // Client timeout must be >= the configured ReplyTimeout so the
+ // per-request context deadline is always the effective limit.
+ clientTimeout := 30 * time.Second
+ if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
+ clientTimeout = d
+ }
+
ctx, cancel := context.WithCancel(context.Background())
return &WeComBotChannel{
BaseChannel: base,
config: cfg,
+ client: &http.Client{Timeout: clientTimeout},
ctx: ctx,
cancel: cancel,
- processedMsgs: make(map[string]bool),
+ processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -112,6 +119,10 @@ func (c *WeComBotChannel) Name() string {
func (c *WeComBotChannel) Start(ctx context.Context) error {
logger.InfoC("wecom", "Starting WeCom Bot channel...")
+ // Cancel the context created in the constructor to avoid a resource leak.
+ if c.cancel != nil {
+ c.cancel()
+ }
c.ctx, c.cancel = context.WithCancel(ctx)
c.SetRunning(true)
@@ -317,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
// Message deduplication: Use msg_id to prevent duplicate processing
msgID := msg.MsgID
- c.msgMu.Lock()
- if c.processedMsgs[msgID] {
- c.msgMu.Unlock()
+ if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
- c.processedMsgs[msgID] = true
- c.msgMu.Unlock()
-
- // Clean up old messages periodically (keep last 1000)
- if len(c.processedMsgs) > 1000 {
- c.msgMu.Lock()
- c.processedMsgs = make(map[string]bool)
- c.msgMu.Unlock()
- }
senderID := msg.From.UserID
@@ -446,8 +446,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
}
req.Header.Set("Content-Type", "application/json")
- client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.client.Do(req)
if err != nil {
return channels.ClassifyNetError(err)
}
diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go
index 328b145c2..c053578b1 100644
--- a/pkg/channels/wecom/bot_test.go
+++ b/pkg/channels/wecom/bot_test.go
@@ -42,7 +42,7 @@ func encryptTestMessage(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + receiveid
random := make([]byte, 0, 16)
- for i := 0; i < 16; i++ {
+ for i := range 16 {
random = append(random, byte(i))
}
@@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
- t.Run("valid direct message callback", func(t *testing.T) {
- // Create JSON message for direct chat (single)
- jsonMsg := `{
- "msgid": "test_msg_id_123",
- "aibotid": "test_aibot_id",
- "chattype": "single",
- "from": {"userid": "user123"},
- "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
- "msgtype": "text",
- "text": {"content": "Hello World"}
- }`
-
- // Encrypt message
+ runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
+ t.Helper()
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
-
- // Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
@@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
-
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
-
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
-
ch.handleMessageCallback(context.Background(), w, req)
+ return w
+ }
+ t.Run("valid direct message callback", func(t *testing.T) {
+ w := runBotMessageCallback(t, `{
+ "msgid": "test_msg_id_123",
+ "aibotid": "test_aibot_id",
+ "chattype": "single",
+ "from": {"userid": "user123"},
+ "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ "msgtype": "text",
+ "text": {"content": "Hello World"}
+ }`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
@@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
})
t.Run("valid group message callback", func(t *testing.T) {
- // Create JSON message for group chat
- jsonMsg := `{
+ w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_456",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
@@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello Group"}
- }`
-
- // Encrypt message
- encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
-
- // Create encrypted XML wrapper
- encryptedWrapper := struct {
- XMLName xml.Name `xml:"xml"`
- Encrypt string `xml:"Encrypt"`
- }{
- Encrypt: encrypted,
- }
- wrapperData, _ := xml.Marshal(encryptedWrapper)
-
- timestamp := "1234567890"
- nonce := "test_nonce"
- signature := generateSignature("test_token", timestamp, nonce, encrypted)
-
- req := httptest.NewRequest(
- http.MethodPost,
- "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
- bytes.NewReader(wrapperData),
- )
- w := httptest.NewRecorder()
-
- ch.handleMessageCallback(context.Background(), w, req)
-
+ }`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go
index 3c1629577..6510e6f81 100644
--- a/pkg/channels/wecom/common.go
+++ b/pkg/channels/wecom/common.go
@@ -1,12 +1,15 @@
package wecom
import (
+ "bytes"
"crypto/aes"
"crypto/cipher"
+ "crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
+ "math/big"
"sort"
"strings"
)
@@ -14,25 +17,23 @@ import (
// blockSize is the PKCS7 block size used by WeCom (32)
const blockSize = 32
+// computeSignature computes the WeCom message signature from the given parameters.
+// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
+func computeSignature(token, timestamp, nonce, encrypt string) string {
+ params := []string{token, timestamp, nonce, encrypt}
+ sort.Strings(params)
+ str := strings.Join(params, "")
+ hash := sha1.Sum([]byte(str))
+ return fmt.Sprintf("%x", hash)
+}
+
// verifySignature verifies the message signature for WeCom
// This is a common function used by both WeCom Bot and WeCom App
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
if token == "" {
return true // Skip verification if token is not set
}
-
- // Sort parameters
- params := []string{token, timestamp, nonce, msgEncrypt}
- sort.Strings(params)
-
- // Concatenate
- str := strings.Join(params, "")
-
- // SHA1 hash
- hash := sha1.Sum([]byte(str))
- expectedSignature := fmt.Sprintf("%x", hash)
-
- return expectedSignature == msgSignature
+ return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
}
// decryptMessage decrypts the encrypted message using AES
@@ -53,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s
return string(decoded), nil
}
- // Decode AES key (base64)
- aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
+ aesKey, err := decodeWeComAESKey(encodingAESKey)
if err != nil {
- return "", fmt.Errorf("failed to decode AES key: %w", err)
+ return "", err
}
- // Decode encrypted message
cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
return "", fmt.Errorf("failed to decode message: %w", err)
}
- // AES decrypt
+ plainText, err := decryptAESCBC(aesKey, cipherText)
+ if err != nil {
+ return "", err
+ }
+
+ return unpackWeComFrame(plainText, receiveid)
+}
+
+// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
+// appended automatically) and validates that the result is exactly 32 bytes.
+// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
+func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
+ aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode AES key: %w", err)
+ }
+ if len(aesKey) != 32 {
+ return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
+ }
+ return aesKey, nil
+}
+
+// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
+// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
+// plaintext to a multiple of aes.BlockSize before calling.
+func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(aesKey)
if err != nil {
- return "", fmt.Errorf("failed to create cipher: %w", err)
+ return nil, fmt.Errorf("failed to create cipher: %w", err)
}
-
- if len(cipherText) < aes.BlockSize {
- return "", fmt.Errorf("ciphertext too short")
- }
-
- // IV is the first 16 bytes of AESKey
iv := aesKey[:aes.BlockSize]
- mode := cipher.NewCBCDecrypter(block, iv)
- plainText := make([]byte, len(cipherText))
- mode.CryptBlocks(plainText, cipherText)
+ ciphertext := make([]byte, len(plaintext))
+ cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
+ return ciphertext, nil
+}
- // Remove PKCS7 padding
- plainText, err = pkcs7Unpad(plainText)
- if err != nil {
- return "", fmt.Errorf("failed to unpad: %w", err)
+// packWeComFrame builds the WeCom wire format:
+//
+// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
+func packWeComFrame(msg, receiveid string) ([]byte, error) {
+ randomBytes := make([]byte, 16)
+ for i := range 16 {
+ n, err := rand.Int(rand.Reader, big.NewInt(10))
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate random: %w", err)
+ }
+ randomBytes[i] = byte('0' + n.Int64())
}
+ msgBytes := []byte(msg)
+ msgLenBytes := make([]byte, 4)
+ binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
+ var buf bytes.Buffer
+ buf.Write(randomBytes)
+ buf.Write(msgLenBytes)
+ buf.Write(msgBytes)
+ buf.WriteString(receiveid)
+ return buf.Bytes(), nil
+}
- // Parse message structure
- // Format: random(16) + msg_len(4) + msg + receiveid
- if len(plainText) < 20 {
- return "", fmt.Errorf("decrypted message too short")
+// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
+// If receiveid is non-empty it verifies the frame's trailing receiveid field.
+func unpackWeComFrame(data []byte, receiveid string) (string, error) {
+ if len(data) < 20 {
+ return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
}
-
- msgLen := binary.BigEndian.Uint32(plainText[16:20])
- if int(msgLen) > len(plainText)-20 {
- return "", fmt.Errorf("invalid message length")
+ msgLen := binary.BigEndian.Uint32(data[16:20])
+ if int(msgLen) > len(data)-20 {
+ return "", fmt.Errorf("invalid message length: %d", msgLen)
}
-
- msg := plainText[20 : 20+msgLen]
-
- // Verify receiveid if provided
- if receiveid != "" && len(plainText) > 20+int(msgLen) {
- actualReceiveID := string(plainText[20+msgLen:])
+ msg := data[20 : 20+msgLen]
+ if receiveid != "" && len(data) > 20+int(msgLen) {
+ actualReceiveID := string(data[20+msgLen:])
if actualReceiveID != receiveid {
return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
}
}
-
return string(msg), nil
}
+// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
+// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
+func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
+ if len(ciphertext) == 0 {
+ return nil, fmt.Errorf("ciphertext is empty")
+ }
+ if len(ciphertext)%aes.BlockSize != 0 {
+ return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
+ }
+ block, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cipher: %w", err)
+ }
+ iv := aesKey[:aes.BlockSize]
+ plaintext := make([]byte, len(ciphertext))
+ cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
+ plaintext, err = pkcs7Unpad(plaintext)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unpad: %w", err)
+ }
+ return plaintext, nil
+}
+
+// pkcs7Pad adds PKCS7 padding
+func pkcs7Pad(data []byte, blockSize int) []byte {
+ padding := blockSize - (len(data) % blockSize)
+ if padding == 0 {
+ padding = blockSize
+ }
+ padText := bytes.Repeat([]byte{byte(padding)}, padding)
+ return append(data, padText...)
+}
+
// pkcs7Unpad removes PKCS7 padding with validation
func pkcs7Unpad(data []byte) ([]byte, error) {
if len(data) == 0 {
@@ -125,7 +190,7 @@ func pkcs7Unpad(data []byte) ([]byte, error) {
return nil, fmt.Errorf("padding size larger than data")
}
// Verify all padding bytes
- for i := 0; i < padding; i++ {
+ for i := range padding {
if data[len(data)-1-i] != byte(padding) {
return nil, fmt.Errorf("invalid padding byte at position %d", i)
}
diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go
new file mode 100644
index 000000000..865be668e
--- /dev/null
+++ b/pkg/channels/wecom/dedupe.go
@@ -0,0 +1,54 @@
+package wecom
+
+import "sync"
+
+const wecomMaxProcessedMessages = 1000
+
+// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
+// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
+// messages without causing "amnesia cliffs" when the limit is reached.
+type MessageDeduplicator struct {
+ mu sync.Mutex
+ msgs map[string]bool
+ ring []string
+ idx int
+ max int
+}
+
+// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
+func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
+ if maxEntries <= 0 {
+ maxEntries = wecomMaxProcessedMessages
+ }
+ return &MessageDeduplicator{
+ msgs: make(map[string]bool, maxEntries),
+ ring: make([]string, maxEntries),
+ max: maxEntries,
+ }
+}
+
+// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
+func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // 1. Check for duplicate
+ if d.msgs[msgID] {
+ return false
+ }
+
+ // 2. Evict the oldest message at our current ring position (if any)
+ oldestID := d.ring[d.idx]
+ if oldestID != "" {
+ delete(d.msgs, oldestID)
+ }
+
+ // 3. Store the new message
+ d.msgs[msgID] = true
+ d.ring[d.idx] = msgID
+
+ // 4. Advance the circle queue index
+ d.idx = (d.idx + 1) % d.max
+
+ return true
+}
diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go
new file mode 100644
index 000000000..10dff4cfe
--- /dev/null
+++ b/pkg/channels/wecom/dedupe_test.go
@@ -0,0 +1,83 @@
+package wecom
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
+ d := NewMessageDeduplicator(wecomMaxProcessedMessages)
+
+ if ok := d.MarkMessageProcessed("msg-1"); !ok {
+ t.Fatalf("first message should be accepted")
+ }
+
+ if ok := d.MarkMessageProcessed("msg-1"); ok {
+ t.Fatalf("duplicate message should be rejected")
+ }
+}
+
+func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
+ d := NewMessageDeduplicator(wecomMaxProcessedMessages)
+
+ const goroutines = 64
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ results := make(chan bool, goroutines)
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ results <- d.MarkMessageProcessed("msg-concurrent")
+ }()
+ }
+
+ wg.Wait()
+ close(results)
+
+ successes := 0
+ for ok := range results {
+ if ok {
+ successes++
+ }
+ }
+
+ if successes != 1 {
+ t.Fatalf("expected exactly 1 successful mark, got %d", successes)
+ }
+}
+
+func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
+ // Create a deduplicator with a very small capacity to test eviction easily.
+ capacity := 3
+ d := NewMessageDeduplicator(capacity)
+
+ // Fill the queue.
+ d.MarkMessageProcessed("msg-1")
+ d.MarkMessageProcessed("msg-2")
+ d.MarkMessageProcessed("msg-3")
+
+ // At this point, the queue is full. msg-1 is the oldest.
+ if len(d.msgs) != 3 {
+ t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
+ }
+
+ // This should evict msg-1 and add msg-4.
+ if ok := d.MarkMessageProcessed("msg-4"); !ok {
+ t.Fatalf("msg-4 should be accepted")
+ }
+
+ if len(d.msgs) != 3 {
+ t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
+ }
+
+ // msg-1 should now be forgotten (evicted).
+ if ok := d.MarkMessageProcessed("msg-1"); !ok {
+ t.Fatalf("msg-1 should be accepted again because it was evicted")
+ }
+
+ // msg-2 should have been evicted when we added msg-1 back.
+ if ok := d.MarkMessageProcessed("msg-2"); !ok {
+ t.Fatalf("msg-2 should be accepted again because it was evicted")
+ }
+}
diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go
index 3ef1ecdf3..bc5a70fa3 100644
--- a/pkg/channels/wecom/init.go
+++ b/pkg/channels/wecom/init.go
@@ -13,4 +13,7 @@ func init() {
channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewWeComAppChannel(cfg.Channels.WeComApp, b)
})
+ channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
+ })
}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index d84772d2b..4c9cda738 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -168,17 +168,28 @@ type SessionConfig struct {
}
type AgentDefaults struct {
- Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
- RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
- Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
- ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
- Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
- ModelFallbacks []string `json:"model_fallbacks,omitempty"`
- ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
- ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
- MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
- Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
- MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+ Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
+ RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
+ AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
+ Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
+ ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
+ Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
+ ModelFallbacks []string `json:"model_fallbacks,omitempty"`
+ ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
+ ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
+ MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
+ Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
+ MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+ MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
+}
+
+const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
+
+func (d *AgentDefaults) GetMaxMediaSize() int {
+ if d.MaxMediaSize > 0 {
+ return d.MaxMediaSize
+ }
+ return DefaultMaxMediaSize
}
// GetModelName returns the effective model name for the agent defaults.
@@ -191,19 +202,20 @@ func (d *AgentDefaults) GetModelName() string {
}
type ChannelsConfig struct {
- WhatsApp WhatsAppConfig `json:"whatsapp"`
- Telegram TelegramConfig `json:"telegram"`
- Feishu FeishuConfig `json:"feishu"`
- Discord DiscordConfig `json:"discord"`
- MaixCam MaixCamConfig `json:"maixcam"`
- QQ QQConfig `json:"qq"`
- DingTalk DingTalkConfig `json:"dingtalk"`
- Slack SlackConfig `json:"slack"`
- LINE LINEConfig `json:"line"`
- OneBot OneBotConfig `json:"onebot"`
- WeCom WeComConfig `json:"wecom"`
- WeComApp WeComAppConfig `json:"wecom_app"`
- Pico PicoConfig `json:"pico"`
+ WhatsApp WhatsAppConfig `json:"whatsapp"`
+ Telegram TelegramConfig `json:"telegram"`
+ Feishu FeishuConfig `json:"feishu"`
+ Discord DiscordConfig `json:"discord"`
+ MaixCam MaixCamConfig `json:"maixcam"`
+ QQ QQConfig `json:"qq"`
+ DingTalk DingTalkConfig `json:"dingtalk"`
+ Slack SlackConfig `json:"slack"`
+ LINE LINEConfig `json:"line"`
+ OneBot OneBotConfig `json:"onebot"`
+ WeCom WeComConfig `json:"wecom"`
+ WeComApp WeComAppConfig `json:"wecom_app"`
+ WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
+ Pico PicoConfig `json:"pico"`
}
// GroupTriggerConfig controls when the bot responds in group chats.
@@ -235,6 +247,7 @@ type WhatsAppConfig struct {
type TelegramConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
+ BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
@@ -251,6 +264,7 @@ type FeishuConfig struct {
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
}
@@ -359,6 +373,18 @@ type WeComAppConfig struct {
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
}
+type WeComAIBotConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
+ MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
+ WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
+}
+
type PicoConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
@@ -385,6 +411,7 @@ type DevicesConfig struct {
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
+ LiteLLM ProviderConfig `json:"litellm"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
@@ -408,6 +435,7 @@ type ProvidersConfig struct {
func (p ProvidersConfig) IsEmpty() bool {
return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" &&
p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" &&
+ p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" &&
p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" &&
p.Groq.APIKey == "" && p.Groq.APIBase == "" &&
p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" &&
@@ -523,7 +551,8 @@ type WebToolsConfig struct {
Perplexity PerplexityConfig `json:"perplexity"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
- Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+ Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+ FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
}
type CronToolsConfig struct {
@@ -531,8 +560,9 @@ type CronToolsConfig struct {
}
type ExecConfig struct {
- EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
- CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
+ EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
+ CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
+ CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
}
type MediaCleanupConfig struct {
@@ -542,11 +572,14 @@ type MediaCleanupConfig struct {
}
type ToolsConfig struct {
- Web WebToolsConfig `json:"web"`
- Cron CronToolsConfig `json:"cron"`
- Exec ExecConfig `json:"exec"`
- Skills SkillsToolsConfig `json:"skills"`
- MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
+ AllowReadPaths []string `json:"allow_read_paths" env:"PICOCLAW_TOOLS_ALLOW_READ_PATHS"`
+ AllowWritePaths []string `json:"allow_write_paths" env:"PICOCLAW_TOOLS_ALLOW_WRITE_PATHS"`
+ Web WebToolsConfig `json:"web"`
+ Cron CronToolsConfig `json:"cron"`
+ Exec ExecConfig `json:"exec"`
+ Skills SkillsToolsConfig `json:"skills"`
+ MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
+ MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
@@ -576,6 +609,34 @@ type ClawHubRegistryConfig struct {
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
+// MCPServerConfig defines configuration for a single MCP server
+type MCPServerConfig struct {
+ // Enabled indicates whether this MCP server is active
+ Enabled bool `json:"enabled"`
+ // Command is the executable to run (e.g., "npx", "python", "/path/to/server")
+ Command string `json:"command"`
+ // Args are the arguments to pass to the command
+ Args []string `json:"args,omitempty"`
+ // Env are environment variables to set for the server process (stdio only)
+ Env map[string]string `json:"env,omitempty"`
+ // EnvFile is the path to a file containing environment variables (stdio only)
+ EnvFile string `json:"env_file,omitempty"`
+ // Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
+ Type string `json:"type,omitempty"`
+ // URL is used for SSE/HTTP transport
+ URL string `json:"url,omitempty"`
+ // Headers are HTTP headers to send with requests (sse/http only)
+ Headers map[string]string `json:"headers,omitempty"`
+}
+
+// MCPConfig defines configuration for all MCP servers
+type MCPConfig struct {
+ // Enabled globally enables/disables MCP integration
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
+ // Servers is a map of server name to server configuration
+ Servers map[string]MCPServerConfig `json:"servers,omitempty"`
+}
+
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
@@ -632,7 +693,8 @@ func (c *Config) migrateChannelConfigs() {
}
// OneBot: group_trigger_prefix -> group_trigger.prefixes
- if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
+ if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
+ len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
}
}
@@ -742,25 +804,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
// HasProvidersConfig checks if any provider in the old providers config has configuration.
func (c *Config) HasProvidersConfig() bool {
- v := c.Providers
- return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
- v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
- v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
- v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
- v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
- v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
- v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
- v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
- v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
- v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
- v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
- v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
- v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
- v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
- v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
- v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
- v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
- v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
+ return !c.Providers.IsEmpty()
}
// ValidateModelList validates all ModelConfig entries in the model_list.
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 12fd10b50..6af7c209e 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -442,3 +442,28 @@ func TestDefaultConfig_DMScope(t *testing.T) {
t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
}
}
+
+func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
+ // Unset to ensure we test the default
+ t.Setenv("PICOCLAW_HOME", "")
+ // Set a known home for consistent test results
+ t.Setenv("HOME", "/tmp/home")
+
+ cfg := DefaultConfig()
+ want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
+
+ if cfg.Agents.Defaults.Workspace != want {
+ t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
+ }
+}
+
+func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
+ t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
+
+ cfg := DefaultConfig()
+ want := "/custom/picoclaw/home/workspace"
+
+ if cfg.Agents.Defaults.Workspace != want {
+ t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
+ }
+}
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index ebb924859..9fc09c5f1 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -5,12 +5,28 @@
package config
+import (
+ "os"
+ "path/filepath"
+)
+
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
+ // Determine the base path for the workspace.
+ // Priority: $PICOCLAW_HOME > ~/.picoclaw
+ var homePath string
+ if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
+ homePath = picoclawHome
+ } else {
+ userHome, _ := os.UserHomeDir()
+ homePath = filepath.Join(userHome, ".picoclaw")
+ }
+ workspacePath := filepath.Join(homePath, "workspace")
+
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
- Workspace: "~/.picoclaw/workspace",
+ Workspace: workspacePath,
RestrictToWorkspace: true,
Provider: "",
Model: "",
@@ -121,6 +137,16 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
},
+ WeComAIBot: WeComAIBotConfig{
+ Enabled: false,
+ Token: "",
+ EncodingAESKey: "",
+ WebhookPath: "/webhook/wecom-aibot",
+ AllowFrom: FlexibleStringSlice{},
+ ReplyTimeout: 5,
+ MaxSteps: 10,
+ WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
+ },
Pico: PicoConfig{
Enabled: false,
Token: "",
@@ -299,7 +325,8 @@ func DefaultConfig() *Config {
Interval: 5,
},
Web: WebToolsConfig{
- Proxy: "",
+ Proxy: "",
+ FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Brave: BraveConfig{
Enabled: false,
APIKey: "",
@@ -334,6 +361,10 @@ func DefaultConfig() *Config {
TTLSeconds: 300,
},
},
+ MCP: MCPConfig{
+ Enabled: false,
+ Servers: map[string]MCPServerConfig{},
+ },
},
Heartbeat: HeartbeatConfig{
Enabled: true,
diff --git a/pkg/config/migration.go b/pkg/config/migration.go
index aade11c1b..f1dc16acc 100644
--- a/pkg/config/migration.go
+++ b/pkg/config/migration.go
@@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
+ {
+ providerNames: []string{"litellm"},
+ protocol: "litellm",
+ buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
+ if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" {
+ return ModelConfig{}, false
+ }
+ return ModelConfig{
+ ModelName: "litellm",
+ Model: "litellm/auto",
+ APIKey: p.LiteLLM.APIKey,
+ APIBase: p.LiteLLM.APIBase,
+ Proxy: p.LiteLLM.Proxy,
+ RequestTimeout: p.LiteLLM.RequestTimeout,
+ }, true
+ },
+ },
{
providerNames: []string{"openrouter"},
protocol: "openrouter",
diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go
index 9f3631d08..e5db91737 100644
--- a/pkg/config/migration_test.go
+++ b/pkg/config/migration_test.go
@@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
}
}
+func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
+ cfg := &Config{
+ Providers: ProvidersConfig{
+ LiteLLM: ProviderConfig{
+ APIKey: "litellm-key",
+ APIBase: "http://localhost:4000/v1",
+ },
+ },
+ }
+
+ result := ConvertProvidersToModelList(cfg)
+
+ if len(result) != 1 {
+ t.Fatalf("len(result) = %d, want 1", len(result))
+ }
+
+ if result[0].ModelName != "litellm" {
+ t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm")
+ }
+ if result[0].Model != "litellm/auto" {
+ t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto")
+ }
+ if result[0].APIBase != "http://localhost:4000/v1" {
+ t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1")
+ }
+}
+
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
@@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
+ LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
Anthropic: ProviderConfig{APIKey: "key2"},
OpenRouter: ProviderConfig{APIKey: "key3"},
Groq: ProviderConfig{APIKey: "key4"},
@@ -137,9 +165,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
result := ConvertProvidersToModelList(cfg)
- // All 18 providers should be converted
- if len(result) != 18 {
- t.Errorf("len(result) = %d, want 18", len(result))
+ // All 19 providers should be converted
+ if len(result) != 19 {
+ t.Errorf("len(result) = %d, want 19", len(result))
}
}
diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go
index 084f50a82..da6e506f8 100644
--- a/pkg/config/model_config_test.go
+++ b/pkg/config/model_config_test.go
@@ -64,7 +64,7 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
// Test round-robin distribution
results := make(map[string]int)
- for i := 0; i < 30; i++ {
+ for range 30 {
result, err := cfg.GetModelConfig("lb-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
@@ -94,17 +94,15 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, goroutines*iterations)
- for i := 0; i < goroutines; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := 0; j < iterations; j++ {
+ for range goroutines {
+ wg.Go(func() {
+ for range iterations {
_, err := cfg.GetModelConfig("concurrent-model")
if err != nil {
errors <- err
}
}
- }()
+ })
}
wg.Wait()
diff --git a/pkg/health/server.go b/pkg/health/server.go
index de1ff60fe..5609ebdf6 100644
--- a/pkg/health/server.go
+++ b/pkg/health/server.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "maps"
"net/http"
"sync"
"time"
@@ -122,9 +123,7 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
s.mu.RLock()
ready := s.ready
checks := make(map[string]Check)
- for k, v := range s.checks {
- checks[k] = v
- }
+ maps.Copy(checks, s.checks)
s.mu.RUnlock()
if !ready {
diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go
index a7aef8c3a..3b7eeeefb 100644
--- a/pkg/heartbeat/service_test.go
+++ b/pkg/heartbeat/service_test.go
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
}
}
-func TestExecuteHeartbeat_Error(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- hs := NewHeartbeatService(tmpDir, 30, true)
- hs.stopChan = make(chan struct{}) // Enable for testing
-
- hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
- return &tools.ToolResult{
- ForLLM: "Heartbeat failed: connection error",
- ForUser: "",
- Silent: false,
- IsError: true,
- Async: false,
- }
- })
-
- // Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
-
- hs.executeHeartbeat()
-
- // Check log file for error message
- logFile := filepath.Join(tmpDir, "heartbeat.log")
- data, err := os.ReadFile(logFile)
- if err != nil {
- t.Fatalf("Failed to read log file: %v", err)
+func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
+ tests := []struct {
+ name string
+ result *tools.ToolResult
+ wantLog string
+ }{
+ {
+ name: "error result",
+ result: &tools.ToolResult{
+ ForLLM: "Heartbeat failed: connection error",
+ ForUser: "",
+ Silent: false,
+ IsError: true,
+ Async: false,
+ },
+ wantLog: "error message",
+ },
+ {
+ name: "silent result",
+ result: &tools.ToolResult{
+ ForLLM: "Heartbeat completed successfully",
+ ForUser: "",
+ Silent: true,
+ IsError: false,
+ Async: false,
+ },
+ wantLog: "completion message",
+ },
}
- logContent := string(data)
- if logContent == "" {
- t.Error("Expected log file to contain error message")
- }
-}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
-func TestExecuteHeartbeat_Silent(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
+ hs := NewHeartbeatService(tmpDir, 30, true)
+ hs.stopChan = make(chan struct{}) // Enable for testing
- hs := NewHeartbeatService(tmpDir, 30, true)
- hs.stopChan = make(chan struct{}) // Enable for testing
+ hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
+ return tt.result
+ })
- hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
- return &tools.ToolResult{
- ForLLM: "Heartbeat completed successfully",
- ForUser: "",
- Silent: true,
- IsError: false,
- Async: false,
- }
- })
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
+ hs.executeHeartbeat()
- // Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
-
- hs.executeHeartbeat()
-
- // Check log file for completion message
- logFile := filepath.Join(tmpDir, "heartbeat.log")
- data, err := os.ReadFile(logFile)
- if err != nil {
- t.Fatalf("Failed to read log file: %v", err)
- }
-
- logContent := string(data)
- if logContent == "" {
- t.Error("Expected log file to contain completion message")
+ logFile := filepath.Join(tmpDir, "heartbeat.log")
+ data, err := os.ReadFile(logFile)
+ if err != nil {
+ t.Fatalf("Failed to read log file: %v", err)
+ }
+ if string(data) == "" {
+ t.Errorf("Expected log file to contain %s", tt.wantLog)
+ }
+ })
}
}
diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go
new file mode 100644
index 000000000..7b63cc979
--- /dev/null
+++ b/pkg/mcp/manager.go
@@ -0,0 +1,532 @@
+package mcp
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// headerTransport is an http.RoundTripper that adds custom headers to requests
+type headerTransport struct {
+ base http.RoundTripper
+ headers map[string]string
+}
+
+func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Clone the request to avoid modifying the original
+ req = req.Clone(req.Context())
+
+ // Add custom headers
+ for key, value := range t.headers {
+ req.Header.Set(key, value)
+ }
+
+ // Use the base transport
+ base := t.base
+ if base == nil {
+ base = http.DefaultTransport
+ }
+ return base.RoundTrip(req)
+}
+
+// loadEnvFile loads environment variables from a file in .env format
+// Each line should be in the format: KEY=value
+// Lines starting with # are comments
+// Empty lines are ignored
+func loadEnvFile(path string) (map[string]string, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open env file: %w", err)
+ }
+ defer file.Close()
+
+ envVars := make(map[string]string)
+ scanner := bufio.NewScanner(file)
+ lineNum := 0
+
+ for scanner.Scan() {
+ lineNum++
+ line := strings.TrimSpace(scanner.Text())
+
+ // Skip empty lines and comments
+ if line == "" || strings.HasPrefix(line, "#") {
+ continue
+ }
+
+ // Parse KEY=value
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
+ }
+
+ key := strings.TrimSpace(parts[0])
+ value := strings.TrimSpace(parts[1])
+
+ if key == "" {
+ return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
+ }
+
+ // Remove surrounding quotes if present
+ if len(value) >= 2 {
+ if (value[0] == '"' && value[len(value)-1] == '"') ||
+ (value[0] == '\'' && value[len(value)-1] == '\'') {
+ value = value[1 : len(value)-1]
+ }
+ }
+
+ envVars[key] = value
+ }
+
+ if err := scanner.Err(); err != nil {
+ return nil, fmt.Errorf("error reading env file: %w", err)
+ }
+
+ return envVars, nil
+}
+
+// ServerConnection represents a connection to an MCP server
+type ServerConnection struct {
+ Name string
+ Client *mcp.Client
+ Session *mcp.ClientSession
+ Tools []*mcp.Tool
+}
+
+// Manager manages multiple MCP server connections
+type Manager struct {
+ servers map[string]*ServerConnection
+ mu sync.RWMutex
+ closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
+ wg sync.WaitGroup // tracks in-flight CallTool calls
+}
+
+// NewManager creates a new MCP manager
+func NewManager() *Manager {
+ return &Manager{
+ servers: make(map[string]*ServerConnection),
+ }
+}
+
+// LoadFromConfig loads MCP servers from configuration
+func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
+ return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
+}
+
+// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
+// This is the minimal dependency version that doesn't require the full Config object.
+func (m *Manager) LoadFromMCPConfig(
+ ctx context.Context,
+ mcpCfg config.MCPConfig,
+ workspacePath string,
+) error {
+ if !mcpCfg.Enabled {
+ logger.InfoCF("mcp", "MCP integration is disabled", nil)
+ return nil
+ }
+
+ if len(mcpCfg.Servers) == 0 {
+ logger.InfoCF("mcp", "No MCP servers configured", nil)
+ return nil
+ }
+
+ logger.InfoCF("mcp", "Initializing MCP servers",
+ map[string]any{
+ "count": len(mcpCfg.Servers),
+ })
+
+ var wg sync.WaitGroup
+ errs := make(chan error, len(mcpCfg.Servers))
+ enabledCount := 0
+
+ for name, serverCfg := range mcpCfg.Servers {
+ if !serverCfg.Enabled {
+ logger.DebugCF("mcp", "Skipping disabled server",
+ map[string]any{
+ "server": name,
+ })
+ continue
+ }
+
+ enabledCount++
+ wg.Add(1)
+ go func(name string, serverCfg config.MCPServerConfig, workspace string) {
+ defer wg.Done()
+
+ // Resolve relative envFile paths relative to workspace
+ if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
+ if workspace == "" {
+ err := fmt.Errorf(
+ "workspace path is empty while resolving relative envFile %q for server %s",
+ serverCfg.EnvFile,
+ name,
+ )
+ logger.ErrorCF("mcp", "Invalid MCP server configuration",
+ map[string]any{
+ "server": name,
+ "env_file": serverCfg.EnvFile,
+ "error": err.Error(),
+ })
+ errs <- err
+ return
+ }
+ serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
+ }
+
+ if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
+ logger.ErrorCF("mcp", "Failed to connect to MCP server",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
+ }
+ }(name, serverCfg, workspacePath)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ // Collect errors
+ var allErrors []error
+ for err := range errs {
+ allErrors = append(allErrors, err)
+ }
+
+ connectedCount := len(m.GetServers())
+
+ // If all enabled servers failed to connect, return aggregated error
+ if enabledCount > 0 && connectedCount == 0 {
+ logger.ErrorCF("mcp", "All MCP servers failed to connect",
+ map[string]any{
+ "failed": len(allErrors),
+ "total": enabledCount,
+ })
+ return errors.Join(allErrors...)
+ }
+
+ if len(allErrors) > 0 {
+ logger.WarnCF("mcp", "Some MCP servers failed to connect",
+ map[string]any{
+ "failed": len(allErrors),
+ "connected": connectedCount,
+ "total": enabledCount,
+ })
+ // Don't fail completely if some servers successfully connected
+ }
+
+ logger.InfoCF("mcp", "MCP server initialization complete",
+ map[string]any{
+ "connected": connectedCount,
+ "total": enabledCount,
+ })
+
+ return nil
+}
+
+// ConnectServer connects to a single MCP server
+func (m *Manager) ConnectServer(
+ ctx context.Context,
+ name string,
+ cfg config.MCPServerConfig,
+) error {
+ logger.InfoCF("mcp", "Connecting to MCP server",
+ map[string]any{
+ "server": name,
+ "command": cfg.Command,
+ "args_count": len(cfg.Args),
+ })
+
+ // Create client
+ client := mcp.NewClient(&mcp.Implementation{
+ Name: "picoclaw",
+ Version: "1.0.0",
+ }, nil)
+
+ // Create transport based on configuration
+ // Auto-detect transport type if not explicitly specified
+ var transport mcp.Transport
+ transportType := cfg.Type
+
+ // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
+ if transportType == "" {
+ if cfg.URL != "" {
+ transportType = "sse"
+ } else if cfg.Command != "" {
+ transportType = "stdio"
+ } else {
+ return fmt.Errorf("either URL or command must be provided")
+ }
+ }
+
+ switch transportType {
+ case "sse", "http":
+ if cfg.URL == "" {
+ return fmt.Errorf("URL is required for SSE/HTTP transport")
+ }
+ logger.DebugCF("mcp", "Using SSE/HTTP transport",
+ map[string]any{
+ "server": name,
+ "url": cfg.URL,
+ })
+
+ sseTransport := &mcp.StreamableClientTransport{
+ Endpoint: cfg.URL,
+ }
+
+ // Add custom headers if provided
+ if len(cfg.Headers) > 0 {
+ // Create a custom HTTP client with header-injecting transport
+ sseTransport.HTTPClient = &http.Client{
+ Transport: &headerTransport{
+ base: http.DefaultTransport,
+ headers: cfg.Headers,
+ },
+ }
+ logger.DebugCF("mcp", "Added custom HTTP headers",
+ map[string]any{
+ "server": name,
+ "header_count": len(cfg.Headers),
+ })
+ }
+
+ transport = sseTransport
+ case "stdio":
+ if cfg.Command == "" {
+ return fmt.Errorf("command is required for stdio transport")
+ }
+ logger.DebugCF("mcp", "Using stdio transport",
+ map[string]any{
+ "server": name,
+ "command": cfg.Command,
+ })
+ // Create command with context
+ cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
+
+ // Build environment variables with proper override semantics
+ // Use a map to ensure config variables override file variables
+ envMap := make(map[string]string)
+
+ // Start with parent process environment
+ for _, e := range cmd.Environ() {
+ if idx := strings.Index(e, "="); idx > 0 {
+ envMap[e[:idx]] = e[idx+1:]
+ }
+ }
+
+ // Load environment variables from file if specified
+ if cfg.EnvFile != "" {
+ envVars, err := loadEnvFile(cfg.EnvFile)
+ if err != nil {
+ return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
+ }
+ for k, v := range envVars {
+ envMap[k] = v
+ }
+ logger.DebugCF("mcp", "Loaded environment variables from file",
+ map[string]any{
+ "server": name,
+ "envFile": cfg.EnvFile,
+ "var_count": len(envVars),
+ })
+ }
+
+ // Environment variables from config override those from file
+ for k, v := range cfg.Env {
+ envMap[k] = v
+ }
+
+ // Convert map to slice
+ env := make([]string, 0, len(envMap))
+ for k, v := range envMap {
+ env = append(env, fmt.Sprintf("%s=%s", k, v))
+ }
+ cmd.Env = env
+
+ transport = &mcp.CommandTransport{Command: cmd}
+ default:
+ return fmt.Errorf(
+ "unsupported transport type: %s (supported: stdio, sse, http)",
+ transportType,
+ )
+ }
+
+ // Connect to server
+ session, err := client.Connect(ctx, transport, nil)
+ if err != nil {
+ return fmt.Errorf("failed to connect: %w", err)
+ }
+
+ // Get server info
+ initResult := session.InitializeResult()
+ logger.InfoCF("mcp", "Connected to MCP server",
+ map[string]any{
+ "server": name,
+ "serverName": initResult.ServerInfo.Name,
+ "serverVersion": initResult.ServerInfo.Version,
+ "protocol": initResult.ProtocolVersion,
+ })
+
+ // List available tools if supported
+ var tools []*mcp.Tool
+ if initResult.Capabilities.Tools != nil {
+ for tool, err := range session.Tools(ctx, nil) {
+ if err != nil {
+ logger.WarnCF("mcp", "Error listing tool",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ continue
+ }
+ tools = append(tools, tool)
+ }
+
+ logger.InfoCF("mcp", "Listed tools from MCP server",
+ map[string]any{
+ "server": name,
+ "toolCount": len(tools),
+ })
+ }
+
+ // Store connection
+ m.mu.Lock()
+ m.servers[name] = &ServerConnection{
+ Name: name,
+ Client: client,
+ Session: session,
+ Tools: tools,
+ }
+ m.mu.Unlock()
+
+ return nil
+}
+
+// GetServers returns all connected servers
+func (m *Manager) GetServers() map[string]*ServerConnection {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ result := make(map[string]*ServerConnection, len(m.servers))
+ for k, v := range m.servers {
+ result[k] = v
+ }
+ return result
+}
+
+// GetServer returns a specific server connection
+func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ conn, ok := m.servers[name]
+ return conn, ok
+}
+
+// CallTool calls a tool on a specific server
+func (m *Manager) CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+) (*mcp.CallToolResult, error) {
+ // Check if closed before acquiring lock (fast path)
+ if m.closed.Load() {
+ return nil, fmt.Errorf("manager is closed")
+ }
+
+ m.mu.RLock()
+ // Double-check after acquiring lock to prevent TOCTOU race
+ if m.closed.Load() {
+ m.mu.RUnlock()
+ return nil, fmt.Errorf("manager is closed")
+ }
+ conn, ok := m.servers[serverName]
+ if ok {
+ m.wg.Add(1) // Add to WaitGroup while holding the lock
+ }
+ m.mu.RUnlock()
+
+ if !ok {
+ return nil, fmt.Errorf("server %s not found", serverName)
+ }
+ defer m.wg.Done()
+
+ params := &mcp.CallToolParams{
+ Name: toolName,
+ Arguments: arguments,
+ }
+
+ result, err := conn.Session.CallTool(ctx, params)
+ if err != nil {
+ return nil, fmt.Errorf("failed to call tool: %w", err)
+ }
+
+ return result, nil
+}
+
+// Close closes all server connections
+func (m *Manager) Close() error {
+ // Use Swap to atomically set closed=true and get the previous value
+ // This prevents TOCTOU race with CallTool's closed check
+ if m.closed.Swap(true) {
+ return nil // already closed
+ }
+
+ // Wait for all in-flight CallTool calls to finish before closing sessions
+ // After closed=true is set, no new CallTool can start (they check closed first)
+ m.wg.Wait()
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ logger.InfoCF("mcp", "Closing all MCP server connections",
+ map[string]any{
+ "count": len(m.servers),
+ })
+
+ var errs []error
+ for name, conn := range m.servers {
+ if err := conn.Session.Close(); err != nil {
+ logger.ErrorCF("mcp", "Failed to close server connection",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ errs = append(errs, fmt.Errorf("server %s: %w", name, err))
+ }
+ }
+
+ m.servers = make(map[string]*ServerConnection)
+
+ if len(errs) > 0 {
+ return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
+ }
+
+ return nil
+}
+
+// GetAllTools returns all tools from all connected servers
+func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ result := make(map[string][]*mcp.Tool)
+ for name, conn := range m.servers {
+ if len(conn.Tools) > 0 {
+ result[name] = conn.Tools
+ }
+ }
+ return result
+}
diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go
new file mode 100644
index 000000000..8ce81d09e
--- /dev/null
+++ b/pkg/mcp/manager_test.go
@@ -0,0 +1,298 @@
+package mcp
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestLoadEnvFile(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ expected map[string]string
+ expectErr bool
+ }{
+ {
+ name: "basic env file",
+ content: `API_KEY=secret123
+DATABASE_URL=postgres://localhost/db
+PORT=8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with comments and empty lines",
+ content: `# This is a comment
+API_KEY=secret123
+
+# Another comment
+DATABASE_URL=postgres://localhost/db
+
+PORT=8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with quoted values",
+ content: `API_KEY="secret with spaces"
+NAME='single quoted'
+PLAIN=no-quotes`,
+ expected: map[string]string{
+ "API_KEY": "secret with spaces",
+ "NAME": "single quoted",
+ "PLAIN": "no-quotes",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with spaces around equals",
+ content: `API_KEY = secret123
+DATABASE_URL= postgres://localhost/db
+PORT =8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "invalid format - no equals",
+ content: `INVALID_LINE`,
+ expectErr: true,
+ },
+ {
+ name: "empty file",
+ content: ``,
+ expected: map[string]string{},
+ expectErr: false,
+ },
+ {
+ name: "only comments",
+ content: `# Comment 1
+# Comment 2`,
+ expected: map[string]string{},
+ expectErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ envFile := filepath.Join(tmpDir, ".env")
+
+ if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ result, err := loadEnvFile(envFile)
+
+ if tt.expectErr {
+ if err == nil {
+ t.Errorf("Expected error but got none")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ return
+ }
+
+ if len(result) != len(tt.expected) {
+ t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
+ }
+
+ for key, expectedValue := range tt.expected {
+ if actualValue, ok := result[key]; !ok {
+ t.Errorf("Expected key %s not found", key)
+ } else if actualValue != expectedValue {
+ t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
+ }
+ }
+ })
+ }
+}
+
+func TestLoadEnvFileNotFound(t *testing.T) {
+ _, err := loadEnvFile("/nonexistent/file.env")
+ if err == nil {
+ t.Error("Expected error for nonexistent file")
+ }
+}
+
+func TestEnvFilePriority(t *testing.T) {
+ // Create a temporary .env file
+ tmpDir := t.TempDir()
+ envFile := filepath.Join(tmpDir, ".env")
+
+ envContent := `API_KEY=from_file
+DATABASE_URL=from_file
+SHARED_VAR=from_file`
+
+ if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
+ t.Fatalf("Failed to create .env file: %v", err)
+ }
+
+ // Load envFile
+ envVars, err := loadEnvFile(envFile)
+ if err != nil {
+ t.Fatalf("Failed to load env file: %v", err)
+ }
+
+ // Verify envFile variables
+ if envVars["API_KEY"] != "from_file" {
+ t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
+ }
+
+ // Simulate config.Env overriding envFile
+ configEnv := map[string]string{
+ "SHARED_VAR": "from_config",
+ "NEW_VAR": "from_config",
+ }
+
+ // Merge: envFile first, then config overrides
+ merged := make(map[string]string)
+ for k, v := range envVars {
+ merged[k] = v
+ }
+ for k, v := range configEnv {
+ merged[k] = v
+ }
+
+ // Verify priority: config.Env should override envFile
+ if merged["SHARED_VAR"] != "from_config" {
+ t.Errorf(
+ "Expected SHARED_VAR=from_config (config should override file), got %s",
+ merged["SHARED_VAR"],
+ )
+ }
+ if merged["API_KEY"] != "from_file" {
+ t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
+ }
+ if merged["NEW_VAR"] != "from_config" {
+ t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
+ }
+}
+
+func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
+ mgr := NewManager()
+
+ mcpCfg := config.MCPConfig{
+ Enabled: true,
+ Servers: map[string]config.MCPServerConfig{
+ "test-server": {
+ Enabled: true,
+ Command: "echo",
+ Args: []string{"ok"},
+ EnvFile: ".env",
+ },
+ },
+ }
+
+ err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
+ if err == nil {
+ t.Fatal("expected error for relative env_file with empty workspace path, got nil")
+ }
+
+ if !strings.Contains(err.Error(), "workspace path is empty") {
+ t.Fatalf("expected workspace path validation error, got: %v", err)
+ }
+}
+
+func TestNewManager_InitialState(t *testing.T) {
+ mgr := NewManager()
+ if mgr == nil {
+ t.Fatal("expected manager instance, got nil")
+ }
+ if len(mgr.GetServers()) != 0 {
+ t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
+ }
+}
+
+func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
+ mgr := NewManager()
+
+ err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
+ if err != nil {
+ t.Fatalf("expected nil error when MCP disabled, got: %v", err)
+ }
+
+ err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
+ if err != nil {
+ t.Fatalf("expected nil error when no servers configured, got: %v", err)
+ }
+}
+
+func TestGetServers_ReturnsCopy(t *testing.T) {
+ mgr := NewManager()
+ mgr.servers["s1"] = &ServerConnection{Name: "s1"}
+
+ servers := mgr.GetServers()
+ delete(servers, "s1")
+
+ if _, ok := mgr.GetServer("s1"); !ok {
+ t.Fatal("expected internal manager state to remain unchanged")
+ }
+}
+
+func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
+ mgr := NewManager()
+ mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
+ mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
+
+ all := mgr.GetAllTools()
+ if _, ok := all["empty"]; ok {
+ t.Fatal("expected server without tools to be excluded")
+ }
+ if _, ok := all["with-tools"]; !ok {
+ t.Fatal("expected server with tools to be included")
+ }
+}
+
+func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
+ t.Run("manager closed", func(t *testing.T) {
+ mgr := NewManager()
+ mgr.closed.Store(true)
+
+ _, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
+ if err == nil || !strings.Contains(err.Error(), "manager is closed") {
+ t.Fatalf("expected manager closed error, got: %v", err)
+ }
+ })
+
+ t.Run("server missing", func(t *testing.T) {
+ mgr := NewManager()
+
+ _, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
+ if err == nil || !strings.Contains(err.Error(), "not found") {
+ t.Fatalf("expected server not found error, got: %v", err)
+ }
+ })
+}
+
+func TestClose_IdempotentOnEmptyManager(t *testing.T) {
+ mgr := NewManager()
+
+ if err := mgr.Close(); err != nil {
+ t.Fatalf("first close should succeed, got: %v", err)
+ }
+ if err := mgr.Close(); err != nil {
+ t.Fatalf("second close should be idempotent, got: %v", err)
+ }
+}
diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go
index 989f90d7c..1dcfdf350 100644
--- a/pkg/media/store_test.go
+++ b/pkg/media/store_test.go
@@ -49,7 +49,7 @@ func TestReleaseAll(t *testing.T) {
paths := make([]string, 3)
refs := make([]string, 3)
- for i := 0; i < 3; i++ {
+ for i := range 3 {
paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
var err error
refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
@@ -228,12 +228,12 @@ func TestConcurrentSafety(t *testing.T) {
var wg sync.WaitGroup
wg.Add(goroutines)
- for g := 0; g < goroutines; g++ {
+ for g := range goroutines {
go func(gIdx int) {
defer wg.Done()
scope := strings.Repeat("s", gIdx+1)
- for i := 0; i < filesPerGoroutine; i++ {
+ for i := range filesPerGoroutine {
path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
if err != nil {
@@ -448,11 +448,11 @@ func TestConcurrentCleanupSafety(t *testing.T) {
wg.Add(workers * 4)
// Store workers
- for w := 0; w < workers; w++ {
+ for w := range workers {
go func(wIdx int) {
defer wg.Done()
scope := fmt.Sprintf("scope-%d", wIdx)
- for i := 0; i < ops; i++ {
+ for i := range ops {
p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i))
store.Store(p, MediaMeta{Source: "test"}, scope)
}
@@ -460,30 +460,30 @@ func TestConcurrentCleanupSafety(t *testing.T) {
}
// Resolve workers
- for w := 0; w < workers; w++ {
+ for range workers {
go func() {
defer wg.Done()
- for i := 0; i < ops; i++ {
+ for range ops {
store.Resolve("media://nonexistent")
}
}()
}
// ReleaseAll workers
- for w := 0; w < workers; w++ {
+ for w := range workers {
go func(wIdx int) {
defer wg.Done()
- for i := 0; i < ops; i++ {
+ for range ops {
store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx))
}
}(w)
}
// CleanExpired workers
- for w := 0; w < workers; w++ {
+ for range workers {
go func() {
defer wg.Done()
- for i := 0; i < ops; i++ {
+ for range ops {
store.CleanExpired()
}
}()
diff --git a/pkg/migrate/internal/common_test.go b/pkg/migrate/internal/common_test.go
index a089157f5..a67293c19 100644
--- a/pkg/migrate/internal/common_test.go
+++ b/pkg/migrate/internal/common_test.go
@@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) {
assert.GreaterOrEqual(t, len(actions), 1)
}
-func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) {
- tmpDir := t.TempDir()
- srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
- dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
+func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
+ tests := []struct {
+ name string
+ force bool
+ wantActionType ActionType
+ }{
+ {
+ name: "backup when not forced",
+ force: false,
+ wantActionType: ActionBackup,
+ },
+ {
+ name: "copy when forced",
+ force: true,
+ wantActionType: ActionCopy,
+ },
+ }
- err := os.MkdirAll(srcWorkspace, 0o755)
- require.NoError(t, err)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
+ dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
- err = os.MkdirAll(dstWorkspace, 0o755)
- require.NoError(t, err)
+ err := os.MkdirAll(srcWorkspace, 0o755)
+ require.NoError(t, err)
- err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
- require.NoError(t, err)
+ err = os.MkdirAll(dstWorkspace, 0o755)
+ require.NoError(t, err)
- err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
- require.NoError(t, err)
+ err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
+ require.NoError(t, err)
- actions, err := PlanWorkspaceMigration(
- srcWorkspace,
- dstWorkspace,
- []string{"file1.txt"},
- []string{},
- false,
- )
- require.NoError(t, err)
+ err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
+ require.NoError(t, err)
- require.GreaterOrEqual(t, len(actions), 1)
- assert.Equal(t, ActionBackup, actions[0].Type)
-}
+ actions, err := PlanWorkspaceMigration(
+ srcWorkspace,
+ dstWorkspace,
+ []string{"file1.txt"},
+ []string{},
+ tt.force,
+ )
+ require.NoError(t, err)
-func TestPlanWorkspaceMigrationForce(t *testing.T) {
- tmpDir := t.TempDir()
- srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
- dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
-
- err := os.MkdirAll(srcWorkspace, 0o755)
- require.NoError(t, err)
-
- err = os.MkdirAll(dstWorkspace, 0o755)
- require.NoError(t, err)
-
- err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
- require.NoError(t, err)
-
- err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
- require.NoError(t, err)
-
- actions, err := PlanWorkspaceMigration(
- srcWorkspace,
- dstWorkspace,
- []string{"file1.txt"},
- []string{},
- true,
- )
- require.NoError(t, err)
-
- require.GreaterOrEqual(t, len(actions), 1)
- assert.Equal(t, ActionCopy, actions[0].Type)
+ require.GreaterOrEqual(t, len(actions), 1)
+ assert.Equal(t, tt.wantActionType, actions[0].Type)
+ })
+ }
}
func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go
index 9162174c9..1bb15f771 100644
--- a/pkg/providers/anthropic/provider.go
+++ b/pkg/providers/anthropic/provider.go
@@ -212,14 +212,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
- var content string
+ var content strings.Builder
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
- content += tb.Text
+ content.WriteString(tb.Text)
case "tool_use":
tu := block.AsToolUse()
var args map[string]any
@@ -246,7 +246,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
}
return &LLMResponse{
- Content: content,
+ Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
@@ -264,8 +264,8 @@ func normalizeBaseURL(apiBase string) string {
}
base = strings.TrimRight(base, "/")
- if strings.HasSuffix(base, "/v1") {
- base = strings.TrimSuffix(base, "/v1")
+ if before, ok := strings.CutSuffix(base, "/v1"); ok {
+ base = before
}
if base == "" {
return defaultBaseURL
diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go
index 74ec33b98..6c4f6a767 100644
--- a/pkg/providers/claude_cli_provider.go
+++ b/pkg/providers/claude_cli_provider.go
@@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
}
if len(tools) > 0 {
- parts = append(parts, p.buildToolsPrompt(tools))
+ parts = append(parts, buildCLIToolsPrompt(tools))
}
return strings.Join(parts, "\n\n")
}
-// buildToolsPrompt creates the tool definitions section for the system prompt.
-func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
- var sb strings.Builder
-
- sb.WriteString("## Available Tools\n\n")
- sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
- sb.WriteString("```json\n")
- sb.WriteString(
- `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
- )
- sb.WriteString("\n```\n\n")
- sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
- sb.WriteString("### Tool Definitions:\n\n")
-
- for _, tool := range tools {
- if tool.Type != "function" {
- continue
- }
- sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
- if tool.Function.Description != "" {
- sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
- }
- if len(tool.Function.Parameters) > 0 {
- paramsJSON, _ := json.Marshal(tool.Function.Parameters)
- sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
- }
- sb.WriteString("\n")
- }
-
- return sb.String()
-}
-
// parseClaudeCliResponse parses the JSON output from the claude CLI.
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
var resp claudeCliJSONResponse
diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go
index 3a3cafaca..d4d648f5a 100644
--- a/pkg/providers/claude_cli_provider_test.go
+++ b/pkg/providers/claude_cli_provider_test.go
@@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
// --- buildToolsPrompt tests ---
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "skip_me") {
t.Error("buildToolsPrompt() should skip non-function tools")
}
@@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
}
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if !strings.Contains(got, "bare_tool") {
t.Error("should include tool name")
}
@@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) {
}
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{
Name: "no_params_tool",
Description: "A tool with no parameters",
}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "Parameters:") {
t.Error("should not include Parameters: section when nil")
}
diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go
index 4c783ece5..13f53ad9e 100644
--- a/pkg/providers/codex_cli_provider.go
+++ b/pkg/providers/codex_cli_provider.go
@@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
}
if len(tools) > 0 {
- sb.WriteString(p.buildToolsPrompt(tools))
+ sb.WriteString(buildCLIToolsPrompt(tools))
sb.WriteString("\n\n")
}
@@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
return sb.String()
}
-// buildToolsPrompt creates a tool definitions section for the prompt.
-func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
- var sb strings.Builder
-
- sb.WriteString("## Available Tools\n\n")
- sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
- sb.WriteString("```json\n")
- sb.WriteString(
- `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
- )
- sb.WriteString("\n```\n\n")
- sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
- sb.WriteString("### Tool Definitions:\n\n")
-
- for _, tool := range tools {
- if tool.Type != "function" {
- continue
- }
- sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
- if tool.Function.Description != "" {
- sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
- }
- if len(tool.Function.Parameters) > 0 {
- paramsJSON, _ := json.Marshal(tool.Function.Parameters)
- sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
- }
- sb.WriteString("\n")
- }
-
- return sb.String()
-}
-
// codexEvent represents a single JSONL event from `codex exec --json`.
type codexEvent struct {
Type string `json:"type"`
diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go
index dcc740ba4..47618300a 100644
--- a/pkg/providers/codex_provider.go
+++ b/pkg/providers/codex_provider.go
@@ -163,8 +163,8 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "empty model"
}
- if strings.HasPrefix(m, "openai/") {
- m = strings.TrimPrefix(m, "openai/")
+ if after, ok := strings.CutPrefix(m, "openai/"); ok {
+ m = after
} else if strings.Contains(m, "/") {
return codexDefaultModel, "non-openai model namespace"
}
diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go
index 47f43ad5c..b517e7feb 100644
--- a/pkg/providers/cooldown_test.go
+++ b/pkg/providers/cooldown_test.go
@@ -138,7 +138,7 @@ func TestCooldown_FailureWindowReset(t *testing.T) {
ct, current := newTestTracker(now)
// 4 errors → 1h cooldown
- for i := 0; i < 4; i++ {
+ for range 4 {
ct.MarkFailure("openai", FailoverRateLimit)
*current = current.Add(2 * time.Second) // small advance between errors
}
@@ -230,7 +230,7 @@ func TestCooldown_ConcurrentAccess(t *testing.T) {
ct := NewCooldownTracker()
var wg sync.WaitGroup
- for i := 0; i < 100; i++ {
+ for range 100 {
wg.Add(3)
go func() {
defer wg.Done()
diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go
index a0f003006..fd9bf1e81 100644
--- a/pkg/providers/error_classifier.go
+++ b/pkg/providers/error_classifier.go
@@ -6,6 +6,13 @@ import (
"strings"
)
+// Common patterns in Go HTTP error messages
+var httpStatusPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`status[:\s]+(\d{3})`),
+ regexp.MustCompile(`http[/\s]+\d*\.?\d*\s+(\d{3})`),
+ regexp.MustCompile(`\b([3-5]\d{2})\b`),
+}
+
// errorPattern defines a single pattern (string or regex) for error classification.
type errorPattern struct {
substring string
@@ -198,20 +205,13 @@ func classifyByMessage(msg string) FailoverReason {
}
// extractHTTPStatus extracts an HTTP status code from an error message.
-// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
+// Looks for patterns like "status: 429", "status 429", "http/1.1 429", "http 429", or standalone "429".
func extractHTTPStatus(msg string) int {
- // Common patterns in Go HTTP error messages
- patterns := []*regexp.Regexp{
- regexp.MustCompile(`status[:\s]+(\d{3})`),
- regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
- }
-
- for _, p := range patterns {
+ for _, p := range httpStatusPatterns {
if m := p.FindStringSubmatch(msg); len(m) > 1 {
return parseDigits(m[1])
}
}
-
return 0
}
diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go
index 865aea57a..67d9af62b 100644
--- a/pkg/providers/error_classifier_test.go
+++ b/pkg/providers/error_classifier_test.go
@@ -305,7 +305,8 @@ func TestExtractHTTPStatus(t *testing.T) {
}{
{"status: 429 rate limited", 429},
{"status 401 unauthorized", 401},
- {"HTTP/1.1 502 Bad Gateway", 502},
+ {"http/1.1 502 bad gateway", 502},
+ {"error 429", 429},
{"no status code here", 0},
{"random number 12345", 0},
}
diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go
index 11af14da4..5b3e42b9e 100644
--- a/pkg/providers/factory.go
+++ b/pkg/providers/factory.go
@@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
+ case "litellm":
+ if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
+ sel.apiKey = cfg.Providers.LiteLLM.APIKey
+ sel.apiBase = cfg.Providers.LiteLLM.APIBase
+ sel.proxy = cfg.Providers.LiteLLM.Proxy
+ if sel.apiBase == "" {
+ sel.apiBase = "http://localhost:4000/v1"
+ }
+ }
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go
index 53f7a08a0..155317a3b 100644
--- a/pkg/providers/factory_provider.go
+++ b/pkg/providers/factory_provider.go
@@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
-// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
+// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
- case "openrouter", "groq", "zhipu", "gemini", "nvidia",
+ case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
@@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.openai.com/v1"
case "openrouter":
return "https://openrouter.ai/api/v1"
+ case "litellm":
+ return "http://localhost:4000/v1"
case "groq":
return "https://api.groq.com/openai/v1"
case "zhipu":
diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go
index e0c0eddef..78389f331 100644
--- a/pkg/providers/factory_provider_test.go
+++ b/pkg/providers/factory_provider_test.go
@@ -135,6 +135,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
}
}
+func TestGetDefaultAPIBase_LiteLLM(t *testing.T) {
+ if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" {
+ t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1")
+ }
+}
+
+func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "test-litellm",
+ Model: "litellm/my-proxy-alias",
+ APIKey: "test-key",
+ APIBase: "http://localhost:4000/v1",
+ }
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+ if modelID != "my-proxy-alias" {
+ t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias")
+ }
+}
+
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go
index 5680f23b3..f7a916d9e 100644
--- a/pkg/providers/factory_test.go
+++ b/pkg/providers/factory_test.go
@@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) {
wantProxy string
wantErrSubstr string
}{
+ {
+ name: "explicit litellm provider uses configured base",
+ setup: func(cfg *config.Config) {
+ cfg.Agents.Defaults.Provider = "litellm"
+ cfg.Providers.LiteLLM.APIKey = "litellm-key"
+ cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
+ cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
+ },
+ wantType: providerTypeHTTPCompat,
+ wantAPIBase: "http://localhost:4000/v1",
+ wantProxy: "http://127.0.0.1:7890",
+ },
+ {
+ name: "explicit litellm provider defaults base when only key is configured",
+ setup: func(cfg *config.Config) {
+ cfg.Agents.Defaults.Provider = "litellm"
+ cfg.Providers.LiteLLM.APIKey = "litellm-key"
+ },
+ wantType: providerTypeHTTPCompat,
+ wantAPIBase: "http://localhost:4000/v1",
+ },
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go
index 3fb15db2f..6d642b2b5 100644
--- a/pkg/providers/github_copilot_provider.go
+++ b/pkg/providers/github_copilot_provider.go
@@ -26,8 +26,9 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi
switch connectMode {
case "stdio":
- // TODO:
- return nil, fmt.Errorf("stdio mode not implemented")
+ // TODO: Implement stdio mode for GitHub Copilot provider
+ // See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
+ return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
case "grpc":
client := copilot.NewClient(&copilot.ClientOptions{
CLIUrl: uri,
@@ -100,9 +101,12 @@ func (p *GitHubCopilotProvider) Chat(
return nil, fmt.Errorf("provider closed")
}
- resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
+ resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
Prompt: string(fullcontent),
})
+ if err != nil {
+ return nil, fmt.Errorf("failed to send message to copilot: %w", err)
+ }
if resp == nil {
return nil, fmt.Errorf("empty response from copilot")
diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index 604331185..ff9109e96 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -116,7 +116,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
- "messages": stripSystemParts(messages),
+ "messages": serializeMessages(messages),
}
if len(tools) > 0 {
@@ -296,26 +296,62 @@ type openaiMessage struct {
ToolCallID string `json:"tool_call_id,omitempty"`
}
-// stripSystemParts converts []Message to []openaiMessage, dropping the
-// SystemParts field so it doesn't leak into the JSON payload sent to
-// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
-func stripSystemParts(messages []Message) []openaiMessage {
- out := make([]openaiMessage, len(messages))
- for i, m := range messages {
- out[i] = openaiMessage{
- Role: m.Role,
- Content: m.Content,
- ReasoningContent: m.ReasoningContent,
- ToolCalls: m.ToolCalls,
- ToolCallID: m.ToolCallID,
+// serializeMessages converts internal Message structs to the OpenAI wire format.
+// - Strips SystemParts (unknown to third-party endpoints)
+// - Converts messages with Media to multipart content format (text + image_url parts)
+// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
+func serializeMessages(messages []Message) []any {
+ out := make([]any, 0, len(messages))
+ for _, m := range messages {
+ if len(m.Media) == 0 {
+ out = append(out, openaiMessage{
+ Role: m.Role,
+ Content: m.Content,
+ ReasoningContent: m.ReasoningContent,
+ ToolCalls: m.ToolCalls,
+ ToolCallID: m.ToolCallID,
+ })
+ continue
}
+
+ // Multipart content format for messages with media
+ parts := make([]map[string]any, 0, 1+len(m.Media))
+ if m.Content != "" {
+ parts = append(parts, map[string]any{
+ "type": "text",
+ "text": m.Content,
+ })
+ }
+ for _, mediaURL := range m.Media {
+ parts = append(parts, map[string]any{
+ "type": "image_url",
+ "image_url": map[string]any{
+ "url": mediaURL,
+ },
+ })
+ }
+
+ msg := map[string]any{
+ "role": m.Role,
+ "content": parts,
+ }
+ if m.ToolCallID != "" {
+ msg["tool_call_id"] = m.ToolCallID
+ }
+ if len(m.ToolCalls) > 0 {
+ msg["tool_calls"] = m.ToolCalls
+ }
+ if m.ReasoningContent != "" {
+ msg["reasoning_content"] = m.ReasoningContent
+ }
+ out = append(out, msg)
}
return out
}
func normalizeModel(model, apiBase string) string {
- idx := strings.Index(model, "/")
- if idx == -1 {
+ before, after, ok := strings.Cut(model, "/")
+ if !ok {
return model
}
@@ -323,10 +359,10 @@ func normalizeModel(model, apiBase string) string {
return model
}
- prefix := strings.ToLower(model[:idx])
+ prefix := strings.ToLower(before)
switch prefix {
- case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
- return model[idx+1:]
+ case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
+ return after
default:
return model
}
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 8fe936f29..174bcf00d 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -5,8 +5,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
@@ -146,6 +149,56 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) {
}
}
+func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) {
+ var requestBody map[string]any
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{"content": "ok"},
+ "finish_reason": "stop",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer server.Close()
+
+ p := NewProvider("key", server.URL, "")
+
+ // Simulate a multi-turn conversation where the assistant's previous
+ // reply included reasoning_content (e.g. from kimi-k2.5).
+ messages := []Message{
+ {Role: "user", Content: "What is 1+1?"},
+ {Role: "assistant", Content: "2", ReasoningContent: "Let me think... 1+1=2"},
+ {Role: "user", Content: "What about 2+2?"},
+ }
+
+ _, err := p.Chat(t.Context(), messages, nil, "kimi-k2.5", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ // Verify reasoning_content is preserved in the serialized request.
+ reqMessages, ok := requestBody["messages"].([]any)
+ if !ok {
+ t.Fatalf("messages is not []any: %T", requestBody["messages"])
+ }
+ assistantMsg, ok := reqMessages[1].(map[string]any)
+ if !ok {
+ t.Fatalf("assistant message is not map[string]any: %T", reqMessages[1])
+ }
+ if assistantMsg["reasoning_content"] != "Let me think... 1+1=2" {
+ t.Errorf("reasoning_content not preserved in request, got %v", assistantMsg["reasoning_content"])
+ }
+}
+
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
@@ -206,6 +259,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
input string
wantModel string
}{
+ {
+ name: "strips litellm prefix and preserves proxy model name",
+ input: "litellm/my-proxy-alias",
+ wantModel: "my-proxy-alias",
+ },
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
@@ -362,61 +420,96 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
}
}
-// TestStripSystemParts_PreservesReasoningContent verifies that reasoning_content
-// is preserved in the wire message format when present, and omitted when empty.
-// Regression test for: Kimi K2 API returning 400 "reasoning_content is missing".
-func TestStripSystemParts_PreservesReasoningContent(t *testing.T) {
- messages := []Message{
- {Role: "user", Content: "What is 1+1?"},
+func TestSerializeMessages_PlainText(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
+ }
+ result := serializeMessages(messages)
+
+ data, err := json.Marshal(result)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["content"] != "hello" {
+ t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
+ }
+ if msgs[1]["reasoning_content"] != "thinking..." {
+ t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
+ }
+}
+
+func TestSerializeMessages_WithMedia(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
+ }
+ result := serializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ content, ok := msgs[0]["content"].([]any)
+ if !ok {
+ t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
+ }
+ if len(content) != 2 {
+ t.Fatalf("expected 2 content parts, got %d", len(content))
+ }
+
+ textPart := content[0].(map[string]any)
+ if textPart["type"] != "text" || textPart["text"] != "describe this" {
+ t.Fatalf("text part mismatch: %v", textPart)
+ }
+
+ imgPart := content[1].(map[string]any)
+ if imgPart["type"] != "image_url" {
+ t.Fatalf("expected image_url type, got %v", imgPart["type"])
+ }
+ imgURL := imgPart["image_url"].(map[string]any)
+ if imgURL["url"] != "data:image/png;base64,abc123" {
+ t.Fatalf("image url mismatch: %v", imgURL["url"])
+ }
+}
+
+func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
+ }
+ result := serializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["tool_call_id"] != "call_1" {
+ t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
+ }
+ // Content should be multipart array
+ if _, ok := msgs[0]["content"].([]any); !ok {
+ t.Fatalf("expected array content, got %T", msgs[0]["content"])
+ }
+}
+
+func TestSerializeMessages_StripsSystemParts(t *testing.T) {
+ messages := []protocoltypes.Message{
{
- Role: "assistant",
- Content: "The answer is 2",
- ReasoningContent: "Let me think step by step... 1+1=2",
+ Role: "system",
+ Content: "you are helpful",
+ SystemParts: []protocoltypes.ContentBlock{
+ {Type: "text", Text: "you are helpful"},
+ },
},
- {Role: "user", Content: "Thanks"},
}
+ result := serializeMessages(messages)
- result := stripSystemParts(messages)
-
- if len(result) != 3 {
- t.Fatalf("len(result) = %d, want 3", len(result))
- }
-
- // Assistant message should preserve reasoning_content
- if result[1].ReasoningContent != "Let me think step by step... 1+1=2" {
- t.Errorf("ReasoningContent = %q, want %q", result[1].ReasoningContent, "Let me think step by step... 1+1=2")
- }
-
- // Verify it serializes to JSON correctly
- data, err := json.Marshal(result[1])
- if err != nil {
- t.Fatalf("json.Marshal error: %v", err)
- }
-
- jsonStr := string(data)
- if !contains(jsonStr, `"reasoning_content"`) {
- t.Errorf("JSON should contain reasoning_content field, got: %s", jsonStr)
- }
-
- // User message should have empty reasoning_content (omitted via omitempty)
- data2, err := json.Marshal(result[0])
- if err != nil {
- t.Fatalf("json.Marshal error: %v", err)
- }
- if contains(string(data2), `"reasoning_content"`) {
- t.Errorf("JSON should omit empty reasoning_content, got: %s", string(data2))
+ data, _ := json.Marshal(result)
+ raw := string(data)
+ if strings.Contains(raw, "system_parts") {
+ t.Fatal("system_parts should not appear in serialized output")
}
}
-
-func contains(s, substr string) bool {
- return len(s) >= len(substr) && searchString(s, substr)
-}
-
-func searchString(s, substr string) bool {
- for i := 0; i+len(substr) <= len(s); i++ {
- if s[i:i+len(substr)] == substr {
- return true
- }
- }
- return false
-}
diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go
index 99f13334e..194c1aa6f 100644
--- a/pkg/providers/protocoltypes/types.go
+++ b/pkg/providers/protocoltypes/types.go
@@ -65,6 +65,7 @@ type ContentBlock struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
+ Media []string `json:"media,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go
index 49218b1b1..a33e1eb5c 100644
--- a/pkg/providers/toolcall_utils.go
+++ b/pkg/providers/toolcall_utils.go
@@ -5,7 +5,43 @@
package providers
-import "encoding/json"
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
+func buildCLIToolsPrompt(tools []ToolDefinition) string {
+ var sb strings.Builder
+
+ sb.WriteString("## Available Tools\n\n")
+ sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
+ sb.WriteString("```json\n")
+ sb.WriteString(
+ `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
+ )
+ sb.WriteString("\n```\n\n")
+ sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
+ sb.WriteString("### Tool Definitions:\n\n")
+
+ for _, tool := range tools {
+ if tool.Type != "function" {
+ continue
+ }
+ sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
+ if tool.Function.Description != "" {
+ sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
+ }
+ if len(tool.Function.Parameters) > 0 {
+ paramsJSON, _ := json.Marshal(tool.Function.Parameters)
+ sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
+ }
+ sb.WriteString("\n")
+ }
+
+ return sb.String()
+}
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
diff --git a/pkg/routing/agent_id_test.go b/pkg/routing/agent_id_test.go
index 050fe0645..f9a65c969 100644
--- a/pkg/routing/agent_id_test.go
+++ b/pkg/routing/agent_id_test.go
@@ -1,6 +1,9 @@
package routing
-import "testing"
+import (
+ "strings"
+ "testing"
+)
func TestNormalizeAgentID_Empty(t *testing.T) {
if got := NormalizeAgentID(""); got != DefaultAgentID {
@@ -57,11 +60,11 @@ func TestNormalizeAgentID_AllInvalid(t *testing.T) {
}
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
- long := ""
- for i := 0; i < 100; i++ {
- long += "a"
+ var long strings.Builder
+ for range 100 {
+ long.WriteString("a")
}
- got := NormalizeAgentID(long)
+ got := NormalizeAgentID(long.String())
if len(got) > MaxAgentIDLength {
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
}
diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go
index 20f6a49d9..c9f19f25d 100644
--- a/pkg/skills/installer.go
+++ b/pkg/skills/installer.go
@@ -2,7 +2,6 @@ package skills
import (
"context"
- "encoding/json"
"fmt"
"io"
"net/http"
@@ -18,14 +17,6 @@ type SkillInstaller struct {
workspace string
}
-type AvailableSkill struct {
- Name string `json:"name"`
- Repository string `json:"repository"`
- Description string `json:"description"`
- Author string `json:"author"`
- Tags []string `json:"tags"`
-}
-
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
@@ -89,35 +80,3 @@ func (si *SkillInstaller) Uninstall(skillName string) error {
return nil
}
-
-func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableSkill, error) {
- url := "https://raw.githubusercontent.com/sipeed/picoclaw-skills/main/skills.json"
-
- client := &http.Client{Timeout: 15 * time.Second}
- req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- resp, err := utils.DoRequestWithRetry(client, req)
- if err != nil {
- return nil, fmt.Errorf("failed to fetch skills list: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- return nil, fmt.Errorf("failed to fetch skills list: HTTP %d", resp.StatusCode)
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response: %w", err)
- }
-
- var skills []AvailableSkill
- if err := json.Unmarshal(body, &skills); err != nil {
- return nil, fmt.Errorf("failed to parse skills list: %w", err)
- }
-
- return skills, nil
-}
diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go
index 67d3e70e0..30d84635a 100644
--- a/pkg/skills/loader.go
+++ b/pkg/skills/loader.go
@@ -64,6 +64,29 @@ type SkillsLoader struct {
builtinSkills string // builtin skills
}
+// SkillRoots returns all unique skill root directories used by this loader.
+// The order follows resolution priority: workspace > global > builtin.
+func (sl *SkillsLoader) SkillRoots() []string {
+ roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills}
+ seen := make(map[string]struct{}, len(roots))
+ out := make([]string, 0, len(roots))
+
+ for _, root := range roots {
+ trimmed := strings.TrimSpace(root)
+ if trimmed == "" {
+ continue
+ }
+ clean := filepath.Clean(trimmed)
+ if _, ok := seen[clean]; ok {
+ continue
+ }
+ seen[clean] = struct{}{}
+ out = append(out, clean)
+ }
+
+ return out
+}
+
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
return &SkillsLoader{
workspace: workspace,
@@ -240,7 +263,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
normalized := strings.ReplaceAll(content, "\r\n", "\n")
normalized = strings.ReplaceAll(normalized, "\r", "\n")
- for _, line := range strings.Split(normalized, "\n") {
+ for line := range strings.SplitSeq(normalized, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go
index 9428bea62..31619f9c2 100644
--- a/pkg/skills/loader_test.go
+++ b/pkg/skills/loader_test.go
@@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) {
})
}
}
+
+func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) {
+ tmp := t.TempDir()
+ workspace := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+ builtin := filepath.Join(tmp, "builtin")
+
+ sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n")
+ roots := sl.SkillRoots()
+
+ assert.Equal(t, []string{
+ filepath.Join(workspace, "skills"),
+ global,
+ builtin,
+ }, roots)
+}
diff --git a/pkg/skills/search_cache.go b/pkg/skills/search_cache.go
index 5d7d2797e..1686e3f98 100644
--- a/pkg/skills/search_cache.go
+++ b/pkg/skills/search_cache.go
@@ -1,7 +1,7 @@
package skills
import (
- "sort"
+ "slices"
"strings"
"sync"
"time"
@@ -183,7 +183,7 @@ func buildTrigrams(s string) []uint32 {
}
// Sort and Deduplication
- sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
+ slices.Sort(trigrams)
n := 1
for i := 1; i < len(trigrams); i++ {
if trigrams[i] != trigrams[i-1] {
diff --git a/pkg/skills/search_cache_test.go b/pkg/skills/search_cache_test.go
index 816bdfb93..6bbb0e6eb 100644
--- a/pkg/skills/search_cache_test.go
+++ b/pkg/skills/search_cache_test.go
@@ -153,7 +153,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent writes
go func() {
- for i := 0; i < 100; i++ {
+ for i := range 100 {
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
}
done <- struct{}{}
@@ -161,7 +161,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent reads
go func() {
- for i := 0; i < 100; i++ {
+ for range 100 {
cache.Get("query-write-a")
}
done <- struct{}{}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 1663faa4c..57f371f12 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -40,7 +40,9 @@ func NewManager(workspace string) *Manager {
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
- os.MkdirAll(stateDir, 0o755)
+ if err := os.MkdirAll(stateDir, 0o755); err != nil {
+ log.Fatalf("[FATAL] state: failed to create state directory: %v", err)
+ }
sm := &Manager{
workspace: workspace,
@@ -54,13 +56,17 @@ func NewManager(workspace string) *Manager {
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
- sm.saveAtomic()
+ if err := sm.saveAtomic(); err != nil {
+ log.Printf("[WARN] state: failed to save state: %v", err)
+ }
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
- sm.load()
+ if err := sm.load(); err != nil {
+ log.Printf("[WARN] state: failed to load state: %v", err)
+ }
}
return sm
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index f717a5bb4..e5e116ef6 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -2,8 +2,10 @@ package state
import (
"encoding/json"
+ "errors"
"fmt"
"os"
+ "os/exec"
"path/filepath"
"testing"
)
@@ -135,7 +137,7 @@ func TestConcurrentAccess(t *testing.T) {
// Test concurrent writes
done := make(chan bool, 10)
- for i := 0; i < 10; i++ {
+ for i := range 10 {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
@@ -144,7 +146,7 @@ func TestConcurrentAccess(t *testing.T) {
}
// Wait for all goroutines to complete
- for i := 0; i < 10; i++ {
+ for range 10 {
<-done
}
@@ -214,3 +216,39 @@ func TestNewManager_EmptyWorkspace(t *testing.T) {
t.Error("Expected zero timestamp for new state")
}
}
+
+func TestNewManager_MkdirFailureCrashes(t *testing.T) {
+ // Since log.Fatalf calls os.Exit(1), we cannot test it normally
+ // Otherwise, the test suite would stop altogether.
+ // We use the standard pattern of Go: rerun this test in a subprocess.
+ if os.Getenv("BE_CRASHER") == "1" {
+ tmpDir := os.Getenv("CRASH_DIR")
+
+ statePath := filepath.Join(tmpDir, "state")
+ if err := os.WriteFile(statePath, []byte("I'm a file, not a folder"), 0o644); err != nil {
+ fmt.Printf("setup failed: %v", err)
+ os.Exit(0)
+ }
+
+ NewManager(tmpDir)
+ os.Exit(0)
+ }
+
+ tmpDir, err := os.MkdirTemp("", "state-crash-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes")
+ cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir)
+
+ err = cmd.Run()
+
+ var e *exec.ExitError
+ if errors.As(err, &e) && !e.Success() {
+ return
+ }
+
+ t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err)
+}
diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go
index f91397c5d..6888d1326 100644
--- a/pkg/tools/cron.go
+++ b/pkg/tools/cron.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "strings"
"sync"
"time"
@@ -222,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult("No scheduled jobs")
}
- result := "Scheduled jobs:\n"
+ var result strings.Builder
+ result.WriteString("Scheduled jobs:\n")
for _, j := range jobs {
var scheduleInfo string
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
@@ -234,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
} else {
scheduleInfo = "unknown"
}
- result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
+ result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
}
- return SilentResult(result)
+ return SilentResult(result.String())
}
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go
index d3ab267bf..d5bebf4a2 100644
--- a/pkg/tools/edit.go
+++ b/pkg/tools/edit.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/fs"
+ "regexp"
"strings"
)
@@ -15,14 +16,12 @@ type EditFileTool struct {
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
-func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &EditFileTool{fs: fs}
+ return &EditFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *EditFileTool) Name() string {
@@ -80,14 +79,12 @@ type AppendFileTool struct {
fs fileSystem
}
-func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &AppendFileTool{fs: fs}
+ return &AppendFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *AppendFileTool) Name() string {
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go
index 03d461dcc..cd8da3195 100644
--- a/pkg/tools/filesystem.go
+++ b/pkg/tools/filesystem.go
@@ -6,6 +6,7 @@ import (
"io/fs"
"os"
"path/filepath"
+ "regexp"
"strings"
"time"
@@ -87,14 +88,12 @@ type ReadFileTool struct {
fs fileSystem
}
-func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &ReadFileTool{fs: fs}
+ return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ReadFileTool) Name() string {
@@ -135,14 +134,12 @@ type WriteFileTool struct {
fs fileSystem
}
-func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &WriteFileTool{fs: fs}
+ return &WriteFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *WriteFileTool) Name() string {
@@ -192,14 +189,12 @@ type ListDirTool struct {
fs fileSystem
}
-func NewListDirTool(workspace string, restrict bool) *ListDirTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewListDirTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ListDirTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &ListDirTool{fs: fs}
+ return &ListDirTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ListDirTool) Name() string {
@@ -394,6 +389,57 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
return entries, err
}
+// whitelistFs wraps a sandboxFs and allows access to specific paths outside
+// the workspace when they match any of the provided patterns.
+type whitelistFs struct {
+ sandbox *sandboxFs
+ host hostFs
+ patterns []*regexp.Regexp
+}
+
+func (w *whitelistFs) matches(path string) bool {
+ for _, p := range w.patterns {
+ if p.MatchString(path) {
+ return true
+ }
+ }
+ return false
+}
+
+func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
+ if w.matches(path) {
+ return w.host.ReadFile(path)
+ }
+ return w.sandbox.ReadFile(path)
+}
+
+func (w *whitelistFs) WriteFile(path string, data []byte) error {
+ if w.matches(path) {
+ return w.host.WriteFile(path, data)
+ }
+ return w.sandbox.WriteFile(path, data)
+}
+
+func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
+ if w.matches(path) {
+ return w.host.ReadDir(path)
+ }
+ return w.sandbox.ReadDir(path)
+}
+
+// buildFs returns the appropriate fileSystem implementation based on restriction
+// settings and optional path whitelist patterns.
+func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
+ if !restrict {
+ return &hostFs{}
+ }
+ sandbox := &sandboxFs{workspace: workspace}
+ if len(patterns) > 0 {
+ return &whitelistFs{sandbox: sandbox, patterns: patterns}
+ }
+ return sandbox
+}
+
// Helper to get a safe relative path for os.Root usage
func getSafeRelPath(workspace, path string) (string, error) {
if workspace == "" {
diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go
index 6f896e22d..666004cd4 100644
--- a/pkg/tools/filesystem_test.go
+++ b/pkg/tools/filesystem_test.go
@@ -5,6 +5,7 @@ import (
"io"
"os"
"path/filepath"
+ "regexp"
"strings"
"testing"
@@ -486,3 +487,36 @@ func TestRootRW_Write(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, newData, content)
}
+
+// TestWhitelistFs_AllowsMatchingPaths verifies that whitelistFs allows access to
+// paths matching the whitelist patterns while blocking non-matching paths.
+func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
+ workspace := t.TempDir()
+ outsideDir := t.TempDir()
+ outsideFile := filepath.Join(outsideDir, "allowed.txt")
+ os.WriteFile(outsideFile, []byte("outside content"), 0o644)
+
+ // Pattern allows access to the outsideDir.
+ patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
+
+ tool := NewReadFileTool(workspace, true, patterns)
+
+ // Read from whitelisted path should succeed.
+ result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
+ if result.IsError {
+ t.Errorf("expected whitelisted path to be readable, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "outside content") {
+ t.Errorf("expected file content, got: %s", result.ForLLM)
+ }
+
+ // Read from non-whitelisted path outside workspace should fail.
+ otherDir := t.TempDir()
+ otherFile := filepath.Join(otherDir, "blocked.txt")
+ os.WriteFile(otherFile, []byte("blocked"), 0o644)
+
+ result = tool.Execute(context.Background(), map[string]any{"path": otherFile})
+ if !result.IsError {
+ t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
+ }
+}
diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go
new file mode 100644
index 000000000..6e53cf354
--- /dev/null
+++ b/pkg/tools/mcp_tool.go
@@ -0,0 +1,246 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "hash/fnv"
+ "strings"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// MCPManager defines the interface for MCP manager operations
+// This allows for easier testing with mock implementations
+type MCPManager interface {
+ CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+ ) (*mcp.CallToolResult, error)
+}
+
+// MCPTool wraps an MCP tool to implement the Tool interface
+type MCPTool struct {
+ manager MCPManager
+ serverName string
+ tool *mcp.Tool
+}
+
+// NewMCPTool creates a new MCP tool wrapper
+func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
+ return &MCPTool{
+ manager: manager,
+ serverName: serverName,
+ tool: tool,
+ }
+}
+
+// sanitizeIdentifierComponent normalizes a string so it can be safely used
+// as part of a tool/function identifier for downstream providers.
+// It:
+// - lowercases the string
+// - replaces any character not in [a-z0-9_-] with '_'
+// - collapses multiple consecutive '_' into a single '_'
+// - trims leading/trailing '_'
+// - falls back to "unnamed" if the result is empty
+// - truncates overly long components to a reasonable length
+func sanitizeIdentifierComponent(s string) string {
+ const maxLen = 64
+
+ s = strings.ToLower(s)
+ var b strings.Builder
+ b.Grow(len(s))
+
+ prevUnderscore := false
+ for _, r := range s {
+ isAllowed := (r >= 'a' && r <= 'z') ||
+ (r >= '0' && r <= '9') ||
+ r == '_' || r == '-'
+
+ if !isAllowed {
+ // Normalize any disallowed character to '_'
+ if !prevUnderscore {
+ b.WriteRune('_')
+ prevUnderscore = true
+ }
+ continue
+ }
+
+ if r == '_' {
+ if prevUnderscore {
+ continue
+ }
+ prevUnderscore = true
+ } else {
+ prevUnderscore = false
+ }
+
+ b.WriteRune(r)
+ }
+
+ result := strings.Trim(b.String(), "_")
+ if result == "" {
+ result = "unnamed"
+ }
+
+ if len(result) > maxLen {
+ result = result[:maxLen]
+ }
+
+ return result
+}
+
+// Name returns the tool name, prefixed with the server name.
+// The total length is capped at 64 characters (OpenAI-compatible API limit).
+// A short hash of the original (unsanitized) server and tool names is appended
+// whenever sanitization is lossy or the name is truncated, ensuring that two
+// names which differ only in disallowed characters remain distinct after sanitization.
+func (t *MCPTool) Name() string {
+ // Prefix with server name to avoid conflicts, and sanitize components
+ sanitizedServer := sanitizeIdentifierComponent(t.serverName)
+ sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
+ full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
+
+ // Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
+ lossless := strings.ToLower(t.serverName) == sanitizedServer &&
+ strings.ToLower(t.tool.Name) == sanitizedTool
+
+ const maxTotal = 64
+ if lossless && len(full) <= maxTotal {
+ return full
+ }
+
+ // Sanitization was lossy or name too long: append hash of the ORIGINAL names
+ // (not the sanitized names) so different originals always yield different hashes.
+ h := fnv.New32a()
+ _, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
+ suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
+
+ base := full
+ if len(base) > maxTotal-9 {
+ base = strings.TrimRight(full[:maxTotal-9], "_")
+ }
+ return base + "_" + suffix
+}
+
+// Description returns the tool description
+func (t *MCPTool) Description() string {
+ desc := t.tool.Description
+ if desc == "" {
+ desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
+ }
+ // Add server info to description
+ return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
+}
+
+// Parameters returns the tool parameters schema
+func (t *MCPTool) Parameters() map[string]any {
+ // The InputSchema is already a JSON Schema object
+ schema := t.tool.InputSchema
+
+ // Handle nil schema
+ if schema == nil {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ // Try direct conversion first (fast path)
+ if schemaMap, ok := schema.(map[string]any); ok {
+ return schemaMap
+ }
+
+ // Handle json.RawMessage and []byte - unmarshal directly
+ var jsonData []byte
+ if rawMsg, ok := schema.(json.RawMessage); ok {
+ jsonData = rawMsg
+ } else if bytes, ok := schema.([]byte); ok {
+ jsonData = bytes
+ }
+
+ if jsonData != nil {
+ var result map[string]any
+ if err := json.Unmarshal(jsonData, &result); err == nil {
+ return result
+ }
+ // Fallback on error
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ // For other types (structs, etc.), convert via JSON marshal/unmarshal
+ var err error
+ jsonData, err = json.Marshal(schema)
+ if err != nil {
+ // Fallback to empty schema if marshaling fails
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ var result map[string]any
+ if err := json.Unmarshal(jsonData, &result); err != nil {
+ // Fallback to empty schema if unmarshaling fails
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ return result
+}
+
+// Execute executes the MCP tool
+func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
+ }
+
+ if result == nil {
+ nilErr := fmt.Errorf("MCP tool returned nil result without error")
+ return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
+ }
+
+ // Handle error result from server
+ if result.IsError {
+ errMsg := extractContentText(result.Content)
+ return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
+ WithError(fmt.Errorf("MCP tool error: %s", errMsg))
+ }
+
+ // Extract text content from result
+ output := extractContentText(result.Content)
+
+ return &ToolResult{
+ ForLLM: output,
+ IsError: false,
+ }
+}
+
+// extractContentText extracts text from MCP content array
+func extractContentText(content []mcp.Content) string {
+ var parts []string
+ for _, c := range content {
+ switch v := c.(type) {
+ case *mcp.TextContent:
+ parts = append(parts, v.Text)
+ case *mcp.ImageContent:
+ // For images, just indicate that an image was returned
+ parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
+ default:
+ // For other content types, use string representation
+ parts = append(parts, fmt.Sprintf("[Content: %T]", v))
+ }
+ }
+ return strings.Join(parts, "\n")
+}
diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go
new file mode 100644
index 000000000..95bb0f992
--- /dev/null
+++ b/pkg/tools/mcp_tool_test.go
@@ -0,0 +1,492 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// MockMCPManager is a mock implementation of MCPManager interface for testing
+type MockMCPManager struct {
+ callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
+}
+
+func (m *MockMCPManager) CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+) (*mcp.CallToolResult, error) {
+ if m.callToolFunc != nil {
+ return m.callToolFunc(ctx, serverName, toolName, arguments)
+ }
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "mock result"},
+ },
+ IsError: false,
+ }, nil
+}
+
+// TestNewMCPTool verifies MCP tool creation
+func TestNewMCPTool(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ Description: "A test tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "input": map[string]any{
+ "type": "string",
+ "description": "Test input",
+ },
+ },
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ if mcpTool == nil {
+ t.Fatal("NewMCPTool should not return nil")
+ }
+ // Verify tool properties we can access
+ if mcpTool.Name() != "mcp_test_server_test_tool" {
+ t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
+ }
+}
+
+// TestMCPTool_Name verifies tool name with server prefix
+func TestMCPTool_Name(t *testing.T) {
+ tests := []struct {
+ name string
+ serverName string
+ toolName string
+ expected string
+ }{
+ {
+ name: "simple name",
+ serverName: "github",
+ toolName: "create_issue",
+ expected: "mcp_github_create_issue",
+ },
+ {
+ name: "filesystem server",
+ serverName: "filesystem",
+ toolName: "read_file",
+ expected: "mcp_filesystem_read_file",
+ },
+ {
+ name: "remote server",
+ serverName: "remote-api",
+ toolName: "fetch_data",
+ expected: "mcp_remote-api_fetch_data",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{Name: tt.toolName}
+ mcpTool := NewMCPTool(manager, tt.serverName, tool)
+
+ result := mcpTool.Name()
+ if result != tt.expected {
+ t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
+ }
+ })
+ }
+}
+
+// TestMCPTool_Description verifies tool description generation
+func TestMCPTool_Description(t *testing.T) {
+ tests := []struct {
+ name string
+ serverName string
+ toolDescription string
+ expectContains []string
+ }{
+ {
+ name: "with description",
+ serverName: "github",
+ toolDescription: "Create a GitHub issue",
+ expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
+ },
+ {
+ name: "empty description",
+ serverName: "filesystem",
+ toolDescription: "",
+ expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ Description: tt.toolDescription,
+ }
+ mcpTool := NewMCPTool(manager, tt.serverName, tool)
+
+ result := mcpTool.Description()
+
+ for _, expected := range tt.expectContains {
+ if !strings.Contains(result, expected) {
+ t.Errorf("Description should contain '%s', got: %s", expected, result)
+ }
+ }
+ })
+ }
+}
+
+// TestMCPTool_Parameters verifies parameter schema conversion
+func TestMCPTool_Parameters(t *testing.T) {
+ tests := []struct {
+ name string
+ inputSchema any
+ expectType string
+ checkProperty string
+ expectProperty bool
+ }{
+ {
+ name: "map schema",
+ inputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "query": map[string]any{
+ "type": "string",
+ "description": "Search query",
+ },
+ },
+ "required": []string{"query"},
+ },
+ expectType: "object",
+ checkProperty: "query",
+ expectProperty: true,
+ },
+ {
+ name: "nil schema",
+ inputSchema: nil,
+ expectType: "object",
+ expectProperty: false,
+ },
+ {
+ name: "json.RawMessage schema",
+ inputSchema: []byte(`{
+ "type": "object",
+ "properties": {
+ "repo": {
+ "type": "string",
+ "description": "Repository name"
+ },
+ "stars": {
+ "type": "integer",
+ "description": "Minimum stars"
+ }
+ },
+ "required": ["repo"]
+ }`),
+ expectType: "object",
+ checkProperty: "repo",
+ expectProperty: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ InputSchema: tt.inputSchema,
+ }
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ params := mcpTool.Parameters()
+
+ if params == nil {
+ t.Fatal("Parameters should not be nil")
+ }
+
+ if params["type"] != tt.expectType {
+ t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
+ }
+
+ // Check if property exists when expected
+ if tt.checkProperty != "" {
+ properties, ok := params["properties"].(map[string]any)
+ if !ok && tt.expectProperty {
+ t.Errorf("Expected properties to be a map")
+ return
+ }
+ if ok {
+ _, hasProperty := properties[tt.checkProperty]
+ if hasProperty != tt.expectProperty {
+ t.Errorf("Expected property '%s' existence: %v, got: %v",
+ tt.checkProperty, tt.expectProperty, hasProperty)
+ }
+ }
+ }
+ })
+ }
+}
+
+// TestMCPTool_Execute_Success tests successful tool execution
+func TestMCPTool_Execute_Success(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ // Verify correct parameters passed
+ if serverName != "github" {
+ t.Errorf("Expected serverName 'github', got '%s'", serverName)
+ }
+ if toolName != "search_repos" {
+ t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
+ }
+
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "Found 3 repositories"},
+ },
+ IsError: false,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{
+ Name: "search_repos",
+ Description: "Search GitHub repositories",
+ }
+ mcpTool := NewMCPTool(manager, "github", tool)
+
+ ctx := context.Background()
+ args := map[string]any{
+ "query": "golang mcp",
+ }
+
+ result := mcpTool.Execute(ctx, args)
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if result.IsError {
+ t.Errorf("Expected no error, got error: %s", result.ForLLM)
+ }
+ if result.ForLLM != "Found 3 repositories" {
+ t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_ManagerError tests execution when manager returns error
+func TestMCPTool_Execute_ManagerError(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return nil, fmt.Errorf("connection failed")
+ },
+ }
+
+ tool := &mcp.Tool{Name: "test_tool"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if !result.IsError {
+ t.Error("Expected IsError to be true")
+ }
+ if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
+ t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "connection failed") {
+ t.Errorf("Error message should include original error, got: %s", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_ServerError tests execution when server returns error
+func TestMCPTool_Execute_ServerError(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "Invalid API key"},
+ },
+ IsError: true,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{Name: "test_tool"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if !result.IsError {
+ t.Error("Expected IsError to be true")
+ }
+ if !strings.Contains(result.ForLLM, "MCP tool returned error") {
+ t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "Invalid API key") {
+ t.Errorf("Error message should include server message, got: %s", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
+func TestMCPTool_Execute_MultipleContent(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "First line"},
+ &mcp.TextContent{Text: "Second line"},
+ &mcp.TextContent{Text: "Third line"},
+ },
+ IsError: false,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{Name: "multi_output"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result.IsError {
+ t.Errorf("Expected no error, got: %s", result.ForLLM)
+ }
+
+ expected := "First line\nSecond line\nThird line"
+ if result.ForLLM != expected {
+ t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
+ }
+}
+
+// TestExtractContentText_TextContent tests text content extraction
+func TestExtractContentText_TextContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.TextContent{Text: "Hello World"},
+ &mcp.TextContent{Text: "Second message"},
+ }
+
+ result := extractContentText(content)
+ expected := "Hello World\nSecond message"
+
+ if result != expected {
+ t.Errorf("Expected '%s', got '%s'", expected, result)
+ }
+}
+
+// TestExtractContentText_ImageContent tests image content extraction
+func TestExtractContentText_ImageContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.ImageContent{
+ Data: []byte("base64data"),
+ MIMEType: "image/png",
+ },
+ }
+
+ result := extractContentText(content)
+
+ if !strings.Contains(result, "[Image:") {
+ t.Errorf("Expected image indicator, got: %s", result)
+ }
+ if !strings.Contains(result, "image/png") {
+ t.Errorf("Expected MIME type in output, got: %s", result)
+ }
+}
+
+// TestExtractContentText_MixedContent tests mixed content types
+func TestExtractContentText_MixedContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.TextContent{Text: "Description"},
+ &mcp.ImageContent{
+ Data: []byte("data"),
+ MIMEType: "image/jpeg",
+ },
+ &mcp.TextContent{Text: "More text"},
+ }
+
+ result := extractContentText(content)
+
+ if !strings.Contains(result, "Description") {
+ t.Errorf("Should contain text content, got: %s", result)
+ }
+ if !strings.Contains(result, "[Image:") {
+ t.Errorf("Should contain image indicator, got: %s", result)
+ }
+ if !strings.Contains(result, "More text") {
+ t.Errorf("Should contain second text, got: %s", result)
+ }
+}
+
+// TestExtractContentText_EmptyContent tests empty content array
+func TestExtractContentText_EmptyContent(t *testing.T) {
+ content := []mcp.Content{}
+
+ result := extractContentText(content)
+
+ if result != "" {
+ t.Errorf("Expected empty string for empty content, got: %s", result)
+ }
+}
+
+// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
+func TestMCPTool_InterfaceCompliance(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{Name: "test"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ // Verify it implements Tool interface
+ var _ Tool = mcpTool
+}
+
+// TestMCPTool_Parameters_MapSchema tests schema that's already a map
+func TestMCPTool_Parameters_MapSchema(t *testing.T) {
+ manager := &MockMCPManager{}
+ schema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "description": "The name parameter",
+ },
+ },
+ "required": []string{"name"},
+ }
+
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ InputSchema: schema,
+ }
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ params := mcpTool.Parameters()
+
+ // Should return the schema as-is when it's already a map
+ if params["type"] != "object" {
+ t.Errorf("Expected type 'object', got '%v'", params["type"])
+ }
+
+ props, ok := params["properties"].(map[string]any)
+ if !ok {
+ t.Error("Properties should be a map")
+ }
+
+ nameParam, ok := props["name"].(map[string]any)
+ if !ok {
+ t.Error("Name parameter should exist")
+ }
+
+ if nameParam["type"] != "string" {
+ t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
+ }
+}
diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go
index d37a093a8..0ba983e02 100644
--- a/pkg/tools/registry.go
+++ b/pkg/tools/registry.go
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
func (r *ToolRegistry) Register(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
- r.tools[tool.Name()] = tool
+ name := tool.Name()
+ if _, exists := r.tools[name]; exists {
+ logger.WarnCF("tools", "Tool registration overwrites existing tool",
+ map[string]any{"name": name})
+ }
+ r.tools[name] = tool
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go
index 8ae13b20c..8fe88ca78 100644
--- a/pkg/tools/registry_test.go
+++ b/pkg/tools/registry_test.go
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
- for i := 0; i < 50; i++ {
+ for i := range 50 {
wg.Add(1)
go func(n int) {
defer wg.Done()
diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go
index 88e4256db..08711ae14 100644
--- a/pkg/tools/shell.go
+++ b/pkg/tools/shell.go
@@ -21,53 +21,77 @@ type ExecTool struct {
timeout time.Duration
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
+ customAllowPatterns []*regexp.Regexp
restrictToWorkspace bool
}
-var defaultDenyPatterns = []*regexp.Regexp{
- regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
- regexp.MustCompile(`\bdel\s+/[fq]\b`),
- regexp.MustCompile(`\brmdir\s+/s\b`),
- regexp.MustCompile(`(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`), // Match disk wiping commands, avoid matching --format flags
- regexp.MustCompile(`\bdd\s+if=`),
- regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
- regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
- regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
- regexp.MustCompile(`\$\([^)]+\)`),
- regexp.MustCompile(`\$\{[^}]+\}`),
- regexp.MustCompile("`[^`]+`"),
- regexp.MustCompile(`\|\s*sh\b`),
- regexp.MustCompile(`\|\s*bash\b`),
- regexp.MustCompile(`;\s*rm\s+-[rf]`),
- regexp.MustCompile(`&&\s*rm\s+-[rf]`),
- regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
- regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
- regexp.MustCompile(`<<\s*EOF`),
- regexp.MustCompile(`\$\(\s*cat\s+`),
- regexp.MustCompile(`\$\(\s*curl\s+`),
- regexp.MustCompile(`\$\(\s*wget\s+`),
- regexp.MustCompile(`\$\(\s*which\s+`),
- regexp.MustCompile(`\bsudo\b`),
- regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
- regexp.MustCompile(`\bchown\b`),
- regexp.MustCompile(`\bpkill\b`),
- regexp.MustCompile(`\bkillall\b`),
- regexp.MustCompile(`\bkill\s+-[9]\b`),
- regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
- regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
- regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
- regexp.MustCompile(`\bpip\s+install\s+--user\b`),
- regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
- regexp.MustCompile(`\byum\s+(install|remove)\b`),
- regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
- regexp.MustCompile(`\bdocker\s+run\b`),
- regexp.MustCompile(`\bdocker\s+exec\b`),
- regexp.MustCompile(`\bgit\s+push\b`),
- regexp.MustCompile(`\bgit\s+force\b`),
- regexp.MustCompile(`\bssh\b.*@`),
- regexp.MustCompile(`\beval\b`),
- regexp.MustCompile(`\bsource\s+.*\.sh\b`),
-}
+var (
+ defaultDenyPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
+ regexp.MustCompile(`\bdel\s+/[fq]\b`),
+ regexp.MustCompile(`\brmdir\s+/s\b`),
+ // Match disk wiping commands, avoid matching --format flags
+ regexp.MustCompile(
+ `(?:^|[;&|]\s*|\s+)(format|mkfs|diskpart)\s`,
+ ),
+ regexp.MustCompile(`\bdd\s+if=`),
+ // Block writes to block devices (all common naming schemes).
+ regexp.MustCompile(
+ `>\s*/dev/(sd[a-z]|hd[a-z]|vd[a-z]|xvd[a-z]|nvme\d|mmcblk\d|loop\d|dm-\d|md\d|sr\d|nbd\d)`,
+ ),
+ regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
+ regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
+ regexp.MustCompile(`\$\([^)]+\)`),
+ regexp.MustCompile(`\$\{[^}]+\}`),
+ regexp.MustCompile("`[^`]+`"),
+ regexp.MustCompile(`\|\s*sh\b`),
+ regexp.MustCompile(`\|\s*bash\b`),
+ regexp.MustCompile(`;\s*rm\s+-[rf]`),
+ regexp.MustCompile(`&&\s*rm\s+-[rf]`),
+ regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
+ regexp.MustCompile(`<<\s*EOF`),
+ regexp.MustCompile(`\$\(\s*cat\s+`),
+ regexp.MustCompile(`\$\(\s*curl\s+`),
+ regexp.MustCompile(`\$\(\s*wget\s+`),
+ regexp.MustCompile(`\$\(\s*which\s+`),
+ regexp.MustCompile(`\bsudo\b`),
+ regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
+ regexp.MustCompile(`\bchown\b`),
+ regexp.MustCompile(`\bpkill\b`),
+ regexp.MustCompile(`\bkillall\b`),
+ regexp.MustCompile(`\bkill\s+-[9]\b`),
+ regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
+ regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
+ regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
+ regexp.MustCompile(`\bpip\s+install\s+--user\b`),
+ regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
+ regexp.MustCompile(`\byum\s+(install|remove)\b`),
+ regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
+ regexp.MustCompile(`\bdocker\s+run\b`),
+ regexp.MustCompile(`\bdocker\s+exec\b`),
+ regexp.MustCompile(`\bgit\s+push\b`),
+ regexp.MustCompile(`\bgit\s+force\b`),
+ regexp.MustCompile(`\bssh\b.*@`),
+ regexp.MustCompile(`\beval\b`),
+ regexp.MustCompile(`\bsource\s+.*\.sh\b`),
+ }
+
+ // absolutePathPattern matches absolute file paths in commands (Unix and Windows).
+ absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
+
+ // safePaths are kernel pseudo-devices that are always safe to reference in
+ // commands, regardless of workspace restriction. They contain no user data
+ // and cannot cause destructive writes.
+ safePaths = map[string]bool{
+ "/dev/null": true,
+ "/dev/zero": true,
+ "/dev/random": true,
+ "/dev/urandom": true,
+ "/dev/stdin": true,
+ "/dev/stdout": true,
+ "/dev/stderr": true,
+ }
+)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
@@ -75,6 +99,7 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
+ customAllowPatterns := make([]*regexp.Regexp, 0)
if config != nil {
execConfig := config.Tools.Exec
@@ -95,6 +120,13 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
}
+ for _, pattern := range execConfig.CustomAllowPatterns {
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ return nil, fmt.Errorf("invalid custom allow pattern %q: %w", pattern, err)
+ }
+ customAllowPatterns = append(customAllowPatterns, re)
+ }
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
@@ -104,6 +136,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
timeout: 60 * time.Second,
denyPatterns: denyPatterns,
allowPatterns: nil,
+ customAllowPatterns: customAllowPatterns,
restrictToWorkspace: restrict,
}, nil
}
@@ -258,9 +291,20 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
cmd := strings.TrimSpace(command)
lower := strings.ToLower(cmd)
- for _, pattern := range t.denyPatterns {
+ // Custom allow patterns exempt a command from deny checks.
+ explicitlyAllowed := false
+ for _, pattern := range t.customAllowPatterns {
if pattern.MatchString(lower) {
- return "Command blocked by safety guard (dangerous pattern detected)"
+ explicitlyAllowed = true
+ break
+ }
+ }
+
+ if !explicitlyAllowed {
+ for _, pattern := range t.denyPatterns {
+ if pattern.MatchString(lower) {
+ return "Command blocked by safety guard (dangerous pattern detected)"
+ }
}
}
@@ -287,16 +331,18 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
return ""
}
- pathPattern := regexp.MustCompile(`(?:^|\s|=)([A-Za-z]:\\[^\\"']+|/[a-zA-Z.][^\s"']*)`)
- matches := pathPattern.FindAllStringSubmatch(cmd, -1)
+ matches := absolutePathPattern.FindAllString(cmd, -1)
- for _, match := range matches {
- raw := match[1]
+ for _, raw := range matches {
p, err := filepath.Abs(raw)
if err != nil {
continue
}
+ if safePaths[p] {
+ continue
+ }
+
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
continue
diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go
index 009a03c80..cee16603d 100644
--- a/pkg/tools/shell_test.go
+++ b/pkg/tools/shell_test.go
@@ -7,6 +7,8 @@ import (
"strings"
"testing"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
)
// TestShellTool_Success verifies successful command execution
@@ -310,6 +312,60 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
}
}
+// TestShellTool_DevNullAllowed verifies that /dev/null redirections are not blocked (issue #964).
+func TestShellTool_DevNullAllowed(t *testing.T) {
+ tmpDir := t.TempDir()
+ tool, err := NewExecTool(tmpDir, true)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ commands := []string{
+ "echo hello 2>/dev/null",
+ "echo hello >/dev/null",
+ "echo hello > /dev/null",
+ "echo hello 2> /dev/null",
+ "echo hello >/dev/null 2>&1",
+ "find " + tmpDir + " -name '*.go' 2>/dev/null",
+ }
+
+ for _, cmd := range commands {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if result.IsError && strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
+ }
+ }
+}
+
+// TestShellTool_BlockDevices verifies that writes to block devices are blocked (issue #965).
+func TestShellTool_BlockDevices(t *testing.T) {
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ blocked := []string{
+ "echo x > /dev/sda",
+ "echo x > /dev/hda",
+ "echo x > /dev/vda",
+ "echo x > /dev/xvda",
+ "echo x > /dev/nvme0n1",
+ "echo x > /dev/mmcblk0",
+ "echo x > /dev/loop0",
+ "echo x > /dev/dm-0",
+ "echo x > /dev/md0",
+ "echo x > /dev/sr0",
+ "echo x > /dev/nbd0",
+ }
+
+ for _, cmd := range blocked {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if !result.IsError {
+ t.Errorf("expected block device write to be blocked: %s", cmd)
+ }
+ }
+}
+
// TestShellTool_DenyPattern_DiskWiping verifies the deny pattern for disk wiping
// commands (format, mkfs, diskpart) blocks them when preceded by shell separators
// but does NOT block legitimate uses like --format flags.
@@ -322,7 +378,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
ctx := context.Background()
// These should be BLOCKED (disk wiping commands)
- blocked := []struct {
+ blockedCmds := []struct {
name string
cmd string
}{
@@ -334,7 +390,7 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
{"diskpart standalone", "diskpart /s script.txt"},
}
- for _, tt := range blocked {
+ for _, tt := range blockedCmds {
t.Run("blocked_"+tt.name, func(t *testing.T) {
result := tool.Execute(ctx, map[string]any{"command": tt.cmd})
if !result.IsError {
@@ -362,35 +418,60 @@ func TestShellTool_DenyPattern_DiskWiping(t *testing.T) {
}
}
-// TestShellTool_RestrictToWorkspace_HiddenDirs verifies that hidden directory
-// paths (starting with .) are properly detected by the workspace guard.
-func TestShellTool_RestrictToWorkspace_HiddenDirs(t *testing.T) {
+// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
+// are allowed even when workspace restriction is active.
+func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
tmpDir := t.TempDir()
- tool, err := NewExecTool(tmpDir, false)
+ tool, err := NewExecTool(tmpDir, true)
if err != nil {
t.Fatalf("unable to configure exec tool: %s", err)
}
- tool.SetRestrictToWorkspace(true)
- ctx := context.Background()
-
- // Reading a hidden dir outside workspace should be blocked
- result := tool.Execute(ctx, map[string]any{
- "command": "cat /.ssh/config",
- })
- if !result.IsError {
- t.Errorf("Expected /.ssh/config to be blocked with restrictToWorkspace=true")
+ // These reference paths outside workspace but should be allowed via safePaths.
+ commands := []string{
+ "cat /dev/urandom | head -c 16 | od",
+ "echo test > /dev/null",
+ "dd if=/dev/zero bs=1 count=1",
}
- // Flag-attached paths outside workspace should be blocked
- result2 := tool.Execute(ctx, map[string]any{
- "command": "grep --include=/etc/passwd pattern",
- })
- if !result2.IsError {
- // This tests the = delimiter fix; --include=/etc/passwd uses = in real
- // usage but --include /etc/passwd uses space. Both patterns should catch it.
- // If this specific form isn't blocked, it's acceptable since the primary
- // concern is the = form (--file=/etc/passwd).
- _ = result2 // acceptable either way for this pattern variant
+ for _, cmd := range commands {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
+ t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
+ }
+ }
+}
+
+// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
+// commands from deny pattern checks.
+func TestShellTool_CustomAllowPatterns(t *testing.T) {
+ cfg := &config.Config{
+ Tools: config.ToolsConfig{
+ Exec: config.ExecConfig{
+ EnableDenyPatterns: true,
+ CustomAllowPatterns: []string{`\bgit\s+push\s+origin\b`},
+ },
+ },
+ }
+
+ tool, err := NewExecToolWithConfig("", false, cfg)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ // "git push origin main" should be allowed by custom allow pattern.
+ result := tool.Execute(context.Background(), map[string]any{
+ "command": "git push origin main",
+ })
+ if result.IsError && strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
+ }
+
+ // "git push upstream main" should still be blocked (does not match allow pattern).
+ result = tool.Execute(context.Background(), map[string]any{
+ "command": "git push upstream main",
+ })
+ if !result.IsError {
+ t.Errorf("'git push upstream main' should still be blocked by deny pattern")
}
}
diff --git a/pkg/tools/web.go b/pkg/tools/web.go
index 8ba2a723a..15d2330ff 100644
--- a/pkg/tools/web.go
+++ b/pkg/tools/web.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -15,6 +16,14 @@ import (
const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
+
+ // HTTP client timeouts for web tool providers.
+ searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
+ perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
+ fetchTimeout = 60 * time.Second // WebFetchTool
+
+ defaultMaxChars = 50000
+ maxRedirects = 5
)
// Pre-compiled regexes for HTML text extraction
@@ -74,6 +83,7 @@ type SearchProvider interface {
type BraveSearchProvider struct {
apiKey string
proxy string
+ client *http.Client
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -88,11 +98,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
- client, err := createHTTPClient(p.proxy, 10*time.Second)
- if err != nil {
- return "", fmt.Errorf("failed to create HTTP client: %w", err)
- }
- resp, err := client.Do(req)
+ resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -103,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
return "", fmt.Errorf("failed to read response: %w", err)
}
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
+ }
+
var searchResp struct {
Web struct {
Results []struct {
@@ -143,6 +153,7 @@ type TavilySearchProvider struct {
apiKey string
baseURL string
proxy string
+ client *http.Client
}
func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -174,11 +185,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
- client, err := createHTTPClient(p.proxy, 10*time.Second)
- if err != nil {
- return "", fmt.Errorf("failed to create HTTP client: %w", err)
- }
- resp, err := client.Do(req)
+ resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -226,7 +233,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i
}
type DuckDuckGoSearchProvider struct {
- proxy string
+ proxy string
+ client *http.Client
}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -239,11 +247,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
- client, err := createHTTPClient(p.proxy, 10*time.Second)
- if err != nil {
- return "", fmt.Errorf("failed to create HTTP client: %w", err)
- }
- resp, err := client.Do(req)
+ resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -285,7 +289,7 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
maxItems := min(len(matches), count)
- for i := 0; i < maxItems; i++ {
+ for i := range maxItems {
urlStr := matches[i][1]
title := stripTags(matches[i][2])
title = strings.TrimSpace(title)
@@ -293,9 +297,9 @@ func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query
// URL decoding if needed
if strings.Contains(urlStr, "uddg=") {
if u, err := url.QueryUnescape(urlStr); err == nil {
- idx := strings.Index(u, "uddg=")
- if idx != -1 {
- urlStr = u[idx+5:]
+ _, after, ok := strings.Cut(u, "uddg=")
+ if ok {
+ urlStr = after
}
}
}
@@ -322,6 +326,7 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
proxy string
+ client *http.Client
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -356,11 +361,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
- client, err := createHTTPClient(p.proxy, 30*time.Second)
- if err != nil {
- return "", fmt.Errorf("failed to create HTTP client: %w", err)
- }
- resp, err := client.Do(req)
+ resp, err := p.client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
@@ -415,43 +416,60 @@ type WebSearchToolOptions struct {
Proxy string
}
-func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
+func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) {
var provider SearchProvider
maxResults := 5
// Priority: Perplexity > Brave > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
- provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
+ client, err := createHTTPClient(opts.Proxy, perplexityTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err)
+ }
+ provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
- provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
+ client, err := createHTTPClient(opts.Proxy, searchTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err)
+ }
+ provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
+ client, err := createHTTPClient(opts.Proxy, searchTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err)
+ }
provider = &TavilySearchProvider{
apiKey: opts.TavilyAPIKey,
baseURL: opts.TavilyBaseURL,
proxy: opts.Proxy,
+ client: client,
}
if opts.TavilyMaxResults > 0 {
maxResults = opts.TavilyMaxResults
}
} else if opts.DuckDuckGoEnabled {
- provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
+ client, err := createHTTPClient(opts.Proxy, searchTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err)
+ }
+ provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
- return nil
+ return nil, nil
}
return &WebSearchTool{
provider: provider,
maxResults: maxResults,
- }
+ }, nil
}
func (t *WebSearchTool) Name() string {
@@ -506,27 +524,40 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR
}
type WebFetchTool struct {
- maxChars int
- proxy string
+ maxChars int
+ proxy string
+ client *http.Client
+ fetchLimitBytes int64
}
-func NewWebFetchTool(maxChars int) *WebFetchTool {
- if maxChars <= 0 {
- maxChars = 50000
- }
- return &WebFetchTool{
- maxChars: maxChars,
- }
+func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) {
+ // createHTTPClient cannot fail with an empty proxy string.
+ return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes)
}
-func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
+func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) {
if maxChars <= 0 {
- maxChars = 50000
+ maxChars = defaultMaxChars
+ }
+ client, err := createHTTPClient(proxy, fetchTimeout)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err)
+ }
+ client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+ if len(via) >= maxRedirects {
+ return fmt.Errorf("stopped after %d redirects", maxRedirects)
+ }
+ return nil
+ }
+ if fetchLimitBytes <= 0 {
+ fetchLimitBytes = 10 * 1024 * 1024 // Security Fallback
}
return &WebFetchTool{
- maxChars: maxChars,
- proxy: proxy,
- }
+ maxChars: maxChars,
+ proxy: proxy,
+ client: client,
+ fetchLimitBytes: fetchLimitBytes,
+ }, nil
}
func (t *WebFetchTool) Name() string {
@@ -588,27 +619,21 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
req.Header.Set("User-Agent", userAgent)
- client, err := createHTTPClient(t.proxy, 60*time.Second)
- if err != nil {
- return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
- }
-
- // Configure redirect handling
- client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
- if len(via) >= 5 {
- return fmt.Errorf("stopped after 5 redirects")
- }
- return nil
- }
-
- resp, err := client.Do(req)
+ resp, err := t.client.Do(req)
if err != nil {
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
+
+ resp.Body = http.MaxBytesReader(nil, resp.Body, t.fetchLimitBytes)
+
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
+ var maxBytesErr *http.MaxBytesError
+ if errors.As(err, &maxBytesErr) {
+ return ErrorResult(fmt.Sprintf("failed to read response: size exceeded %d bytes limit", t.fetchLimitBytes))
+ }
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
@@ -652,14 +677,14 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe
resultJSON, _ := json.MarshalIndent(result, "", " ")
return &ToolResult{
- ForLLM: fmt.Sprintf(
+ ForLLM: string(resultJSON),
+ ForUser: fmt.Sprintf(
"Fetched %d bytes from %s (extractor: %s, truncated: %v)",
len(text),
urlStr,
extractor,
truncated,
),
- ForUser: string(resultJSON),
}
}
diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go
index 2cd79eb24..8a8b88131 100644
--- a/pkg/tools/web_test.go
+++ b/pkg/tools/web_test.go
@@ -1,15 +1,21 @@
package tools
import (
+ "bytes"
"context"
"encoding/json"
+ "fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
)
+const testFetchLimit = int64(10 * 1024 * 1024)
+
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -19,7 +25,11 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
}))
defer server.Close()
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ t.Fatalf("Failed to create web fetch tool: %v", err)
+ }
+
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -32,14 +42,14 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
- // ForUser should contain the fetched content
- if !strings.Contains(result.ForUser, "Test Page") {
- t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
+ // ForLLM should contain the fetched content (full JSON result)
+ if !strings.Contains(result.ForLLM, "Test Page") {
+ t.Errorf("Expected ForLLM to contain 'Test Page', got: %s", result.ForLLM)
}
- // ForLLM should contain summary
- if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
- t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
+ // ForUser should contain summary
+ if !strings.Contains(result.ForUser, "bytes") && !strings.Contains(result.ForUser, "extractor") {
+ t.Errorf("Expected ForUser to contain summary, got: %s", result.ForUser)
}
}
@@ -55,7 +65,11 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
}))
defer server.Close()
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -68,15 +82,19 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
- // ForUser should contain formatted JSON
- if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
- t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
+ // ForLLM should contain formatted JSON
+ if !strings.Contains(result.ForLLM, "key") && !strings.Contains(result.ForLLM, "value") {
+ t.Errorf("Expected ForLLM to contain JSON data, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{
"url": "not-a-valid-url",
@@ -97,7 +115,11 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{
"url": "ftp://example.com/file.txt",
@@ -118,7 +140,11 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{}
@@ -146,7 +172,11 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}))
defer server.Close()
- tool := NewWebFetchTool(1000) // Limit to 1000 chars
+ tool, err := NewWebFetchTool(1000, testFetchLimit) // Limit to 1000 chars
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -159,9 +189,9 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
- // ForUser should contain truncated content (not the full 20000 chars)
+ // ForLLM should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]any)
- json.Unmarshal([]byte(result.ForUser), &resultMap)
+ json.Unmarshal([]byte(result.ForLLM), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
@@ -174,15 +204,64 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
}
+func TestWebFetchTool_PayloadTooLarge(t *testing.T) {
+ // Create a mock HTTP server
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.WriteHeader(http.StatusOK)
+
+ // Generate a payload intentionally larger than our limit.
+ // Limit: 10 * 1024 * 1024 (10MB). We generate 10MB + 100 bytes of the letter 'A'.
+ largeData := bytes.Repeat([]byte("A"), int(testFetchLimit)+100)
+
+ w.Write(largeData)
+ }))
+ // Ensure the server is shut down at the end of the test
+ defer ts.Close()
+
+ // Initialize the tool
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
+ // Prepare the arguments pointing to the URL of our local mock server
+ args := map[string]any{
+ "url": ts.URL,
+ }
+
+ // Execute the tool
+ ctx := context.Background()
+ result := tool.Execute(ctx, args)
+
+ // Assuming ErrorResult sets the ForLLM field with the error text.
+ if result == nil {
+ t.Fatal("expected a ToolResult, got nil")
+ }
+
+ // Search for the exact error string we set earlier in the Execute method
+ expectedErrorMsg := fmt.Sprintf("size exceeded %d bytes limit", testFetchLimit)
+
+ if !strings.Contains(result.ForLLM, expectedErrorMsg) && !strings.Contains(result.ForUser, expectedErrorMsg) {
+ t.Errorf("test failed: expected error %q, but got: %+v", expectedErrorMsg, result)
+ }
+}
+
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
- tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
+ tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
if tool != nil {
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
- tool = NewWebSearchTool(WebSearchToolOptions{})
+ tool, err = NewWebSearchTool(WebSearchToolOptions{})
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
@@ -190,7 +269,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
- tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
+ tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
ctx := context.Background()
args := map[string]any{}
@@ -215,7 +297,11 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}))
defer server.Close()
- tool := NewWebFetchTool(50000)
+ tool, err := NewWebFetchTool(50000, testFetchLimit)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ }
+
ctx := context.Background()
args := map[string]any{
"url": server.URL,
@@ -228,14 +314,14 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
- // ForUser should contain extracted text (without script/style tags)
- if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
- t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
+ // ForLLM should contain extracted text (without script/style tags)
+ if !strings.Contains(result.ForLLM, "Title") && !strings.Contains(result.ForLLM, "Content") {
+ t.Errorf("Expected ForLLM to contain extracted text, got: %s", result.ForLLM)
}
- // Should NOT contain script or style tags
- if strings.Contains(result.ForUser, "