/skills` (builtin)
+
+For advanced/test setups, you can override the builtin skills root with:
+
+```bash
+export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
+```
+
### 🔒 Security Sandbox
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
@@ -794,7 +926,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
### Providers
> [!NOTE]
-> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
+> Groq provides free voice transcription via Whisper. If configured, audio messages from any channel will be automatically transcribed at the agent level.
| Provider | Purpose | Get API Key |
| -------------------------- | --------------------------------------- | -------------------------------------------------------------------- |
@@ -822,7 +954,7 @@ This design also enables **multi-agent support** with flexible provider selectio
#### 📋 All Supported Vendors
| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
-| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- |
+| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- |
| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
@@ -834,6 +966,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) |
| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) |
| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) |
+| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1 | OpenAI | Your LiteLLM proxy key |
| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local |
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) |
| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
@@ -930,10 +1063,24 @@ This design also enables **multi-agent support** with flexible provider selectio
"model_name": "my-custom-model",
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
+ "api_key": "sk-...",
+ "request_timeout": 300
+}
+```
+
+**LiteLLM Proxy**
+
+```json
+{
+ "model_name": "lite-gpt4",
+ "model": "litellm/lite-gpt4",
+ "api_base": "http://localhost:4000/v1",
"api_key": "sk-..."
}
```
+PicoClaw strips only the outer `litellm/` prefix before sending the request, so proxy aliases like `litellm/lite-gpt4` send `lite-gpt4`, while `litellm/openai/gpt-4o` sends `openai/gpt-4o`.
+
#### Load Balancing
Configure multiple endpoints for the same model name—PicoClaw will automatically round-robin between them:
@@ -1078,7 +1225,11 @@ picoclaw agent -m "Hello"
"allow_from": [""]
},
"whatsapp": {
- "enabled": false
+ "enabled": false,
+ "bridge_url": "ws://localhost:3001",
+ "use_native": false,
+ "session_store_path": "",
+ "allow_from": []
},
"feishu": {
"enabled": false,
diff --git a/README.pt-br.md b/README.pt-br.md
index 0115b7f89..67ce9e0d3 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -165,39 +165,43 @@ Você tambêm pode rodar o PicoClaw usando Docker Compose sem instalar nada loca
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
-# 2. Configure suas API keys
-cp config/config.example.json config/config.json
-vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc.
+# 2. Primeiro uso — gera docker/data/config.json automaticamente e para
+docker compose -f docker/docker-compose.yml --profile gateway up
+# O contêiner exibe "First-run setup complete." e para.
-# 3. Build & Iniciar
-docker compose --profile gateway up -d
+# 3. Configure suas API keys
+vim docker/data/config.json # Chaves de API do provedor, tokens de bot, etc.
+
+# 4. Iniciar
+docker compose -f docker/docker-compose.yml --profile gateway up -d
+```
> [!TIP]
> **Usuários Docker**: Por padrão, o Gateway ouve em `127.0.0.1`, o que não é acessível a partir do host. Se você precisar acessar os endpoints de integridade ou expor portas, defina `PICOCLAW_GATEWAY_HOST=0.0.0.0` em seu ambiente ou atualize o `config.json`.
+```bash
+# 5. Ver logs
+docker compose -f docker/docker-compose.yml logs -f picoclaw-gateway
-# 4. Ver logs
-docker compose logs -f picoclaw-gateway
-
-# 5. Parar
-docker compose --profile gateway down
+# 6. Parar
+docker compose -f docker/docker-compose.yml --profile gateway down
```
### Modo Agente (Execução única)
```bash
# Fazer uma pergunta
-docker compose run --rm picoclaw-agent -m "Quanto e 2+2?"
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent -m "Quanto e 2+2?"
# Modo interativo
-docker compose run --rm picoclaw-agent
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent
```
-### Rebuild
+### Atualizar
```bash
-docker compose --profile gateway build --no-cache
-docker compose --profile gateway up -d
+docker compose -f docker/docker-compose.yml pull
+docker compose -f docker/docker-compose.yml --profile gateway up -d
```
### 🚀 Início Rápido
@@ -222,6 +226,7 @@ picoclaw onboard
"model_name": "gpt4",
"model": "openai/gpt-5.2",
"api_key": "sk-your-openai-key",
+ "request_timeout": 300,
"api_base": "https://api.openai.com/v1"
}
],
@@ -246,6 +251,9 @@ picoclaw onboard
}
```
+> **Novo**: O formato de configuração `model_list` permite adicionar provedores sem alterar código. Veja [Configuração de Modelo](#configuração-de-modelo-model_list) para detalhes.
+> `request_timeout` é opcional e usa segundos. Se omitido ou definido como `<= 0`, o PicoClaw usa o timeout padrão (120s).
+
**3. Obter API Keys**
* **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
@@ -274,7 +282,7 @@ Converse com seu PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom.
| **QQ** | Fácil (AppID + AppSecret) |
| **DingTalk** | Médio (credenciais do app) |
| **LINE** | Médio (credenciais + webhook URL) |
-| **WeCom** | Médio (CorpID + configuração webhook) |
+| **WeCom AI Bot** | Médio (Token + chave AES) |
Telegram (Recomendado)
@@ -442,8 +450,6 @@ picoclaw gateway
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -457,11 +463,13 @@ O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel:
```bash
# Exemplo com ngrok
-ngrok http 18791
+ngrok http 18790
```
Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**.
+> **Nota**: O webhook do LINE é servido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Use um proxy reverso/HTTPS ou túnel (como ngrok) para expor o Gateway de forma segura quando necessário.
+
**4. Executar**
```bash
@@ -470,19 +478,20 @@ picoclaw gateway
> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original.
-> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook.
+> **Docker Compose**: Se você usa Docker Compose, exponha o Gateway (padrão 127.0.0.1:18790) se precisar acessar o webhook LINE externamente, por exemplo `ports: ["18790:18790"]`.
WeCom (WeChat Work)
-O PicoClaw suporta dois tipos de integração WeCom:
+O PicoClaw suporta três tipos de integração WeCom:
-**Opção 1: WeCom Bot (Robô Inteligente)** - Configuração mais fácil, suporta chats em grupo
-**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas
+**Opção 1: WeCom Bot (Robô)** - Configuração mais fácil, suporta chats em grupo
+**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas, somente chat privado
+**Opção 3: WeCom AI Bot (Robô Inteligente)** - Bot IA oficial, respostas em streaming, suporta grupo e privado
-Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para instruções detalhadas.
+Veja o [Guia de Configuração WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) para instruções detalhadas.
**Configuração Rápida - WeCom Bot:**
@@ -501,8 +510,6 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": []
}
@@ -510,6 +517,8 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
}
```
+> **Nota**: O webhook do WeCom Bot é atendido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Use um proxy reverso/HTTPS ou túnel para expor o Gateway em produção.
+
**Configuração Rápida - WeCom App:**
**1. Criar um aplicativo**
@@ -521,7 +530,7 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
**2. Configurar recebimento de mensagens**
* Nos detalhes do aplicativo, clique em "Receber Mensagens" → "Configurar API"
-* Defina a URL como `http://your-server:18792/webhook/wecom-app`
+* Defina a URL como `http://your-server:18790/webhook/wecom-app`
* Gere o **Token** e o **EncodingAESKey**
**3. Configurar**
@@ -536,8 +545,6 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": []
}
@@ -551,7 +558,40 @@ Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para
picoclaw gateway
```
-> **Nota**: O WeCom App requer a abertura da porta 18792 para callbacks de webhook. Use um proxy reverso para HTTPS em produção.
+> **Nota**: O WeCom App (callbacks de webhook) é servido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Em produção use um proxy reverso HTTPS para expor a porta do Gateway, ou atualize `PICOCLAW_GATEWAY_HOST` para `0.0.0.0` se necessário.
+
+**Configuração Rápida - WeCom AI Bot:**
+
+**1. Criar um AI Bot**
+
+* Acesse o Console de Administração WeCom → Gerenciamento de Aplicativos → AI Bot
+* Configure a URL de callback: `http://your-server:18791/webhook/wecom-aibot`
+* Copie o **Token** e gere o **EncodingAESKey**
+
+**2. Configurar**
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "Olá! Como posso ajudá-lo?"
+ }
+ }
+}
+```
+
+**3. Executar**
+
+```bash
+picoclaw gateway
+```
+
+> **Nota**: O WeCom AI Bot usa protocolo de pull em streaming — sem preocupações com timeout de resposta. Tarefas longas (>5,5 min) alternam automaticamente para entrega via `response_url`.
@@ -565,6 +605,31 @@ Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única men
Arquivo de configuração: `~/.picoclaw/config.json`
+### Variáveis de Ambiente
+
+Você pode substituir os caminhos padrão usando variáveis de ambiente. Isso é útil para instalações portáteis, implantações em contêineres ou para executar o picoclaw como um serviço do sistema. Essas variáveis são independentes e controlam caminhos diferentes.
+
+| Variável | Descrição | Caminho Padrão |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | Substitui o caminho para o arquivo de configuração. Isso informa diretamente ao picoclaw qual `config.json` carregar, ignorando todos os outros locais. | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | Substitui o diretório raiz dos dados do picoclaw. Isso altera o local padrão do `workspace` e de outros diretórios de dados. | `~/.picoclaw` |
+
+**Exemplos:**
+
+```bash
+# Executar o picoclaw usando um arquivo de configuração específico
+# O caminho do workspace será lido de dentro desse arquivo de configuração
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# Executar o picoclaw com todos os seus dados armazenados em /opt/picoclaw
+# A configuração será carregada do ~/.picoclaw/config.json padrão
+# O workspace será criado em /opt/picoclaw/workspace
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# Use ambos para uma configuração totalmente personalizada
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### Estrutura do Workspace
O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`):
@@ -758,7 +823,7 @@ O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se com
### Provedores
> [!NOTE]
-> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas.
+> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de áudio de qualquer canal serão automaticamente transcritas no nível do agente.
| Provedor | Finalidade | Obter API Key |
| --- | --- | --- |
@@ -973,6 +1038,17 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
```
> Execute `picoclaw auth login --provider anthropic` para configurar credenciais OAuth.
+**Proxy/API personalizada**
+```json
+{
+ "model_name": "my-custom-model",
+ "model": "openai/custom-model",
+ "api_base": "https://my-proxy.com/v1",
+ "api_key": "sk-...",
+ "request_timeout": 300
+}
+```
+
#### Balanceamento de Carga
Configure vários endpoints para o mesmo nome de modelo—PicoClaw fará round-robin automaticamente entre eles:
diff --git a/README.vi.md b/README.vi.md
index 015bc264e..5755896ed 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -3,7 +3,7 @@
PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go
-Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!
+Phần cứng $10 · RAM 10MB · Khởi động 1 giây · Nào, xuất phát!
@@ -145,39 +145,43 @@ Bạn cũng có thể chạy PicoClaw bằng Docker Compose mà không cần cà
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
-# 2. Thiết lập API Key
-cp config/config.example.json config/config.json
-vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v.
+# 2. Lần chạy đầu tiên — tự tạo docker/data/config.json rồi dừng lại
+docker compose -f docker/docker-compose.yml --profile gateway up
+# Container hiển thị "First-run setup complete." rồi tự dừng.
-# 3. Build & Khởi động
-docker compose --profile gateway up -d
+# 3. Thiết lập API Key
+vim docker/data/config.json # API key của provider, bot token, v.v.
+
+# 4. Khởi động
+docker compose -f docker/docker-compose.yml --profile gateway up -d
+```
> [!TIP]
> **Người dùng Docker**: Theo mặc định, Gateway lắng nghe trên `127.0.0.1`, không thể truy cập từ máy chủ. Nếu bạn cần truy cập các endpoint kiểm tra sức khỏe hoặc mở cổng, hãy đặt `PICOCLAW_GATEWAY_HOST=0.0.0.0` trong môi trường của bạn hoặc cập nhật `config.json`.
+```bash
+# 5. Xem logs
+docker compose -f docker/docker-compose.yml logs -f picoclaw-gateway
-# 4. Xem logs
-docker compose logs -f picoclaw-gateway
-
-# 5. Dừng
-docker compose --profile gateway down
+# 6. Dừng
+docker compose -f docker/docker-compose.yml --profile gateway down
```
### Chế độ Agent (chạy một lần)
```bash
# Đặt câu hỏi
-docker compose run --rm picoclaw-agent -m "2+2 bằng mấy?"
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent -m "2+2 bằng mấy?"
# Chế độ tương tác
-docker compose run --rm picoclaw-agent
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent
```
-### Build lại
+### Cập nhật
```bash
-docker compose --profile gateway build --no-cache
-docker compose --profile gateway up -d
+docker compose -f docker/docker-compose.yml pull
+docker compose -f docker/docker-compose.yml --profile gateway up -d
```
### 🚀 Bắt đầu nhanh
@@ -202,6 +206,7 @@ picoclaw onboard
"model_name": "gpt4",
"model": "openai/gpt-5.2",
"api_key": "sk-your-openai-key",
+ "request_timeout": 300,
"api_base": "https://api.openai.com/v1"
}
],
@@ -220,6 +225,9 @@ picoclaw onboard
}
```
+> **Mới**: Định dạng cấu hình `model_list` cho phép thêm nhà cung cấp mà không cần thay đổi mã nguồn. Xem [Cấu hình Mô hình](#cấu-hình-mô-hình-model_list) để biết chi tiết.
+> `request_timeout` là tùy chọn và dùng đơn vị giây. Nếu bỏ qua hoặc đặt `<= 0`, PicoClaw sẽ dùng timeout mặc định (120s).
+
**3. Lấy API Key**
* **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
@@ -248,7 +256,7 @@ Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk, LINE hoặc WeCom.
| **QQ** | Dễ (AppID + AppSecret) |
| **DingTalk** | Trung bình (app credentials) |
| **LINE** | Trung bình (credentials + webhook URL) |
-| **WeCom** | Trung bình (CorpID + cấu hình webhook) |
+| **WeCom AI Bot** | Trung bình (Token + khóa AES) |
Telegram (Khuyên dùng)
@@ -416,8 +424,6 @@ picoclaw gateway
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -431,7 +437,7 @@ LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel:
```bash
# Ví dụ với ngrok
-ngrok http 18791
+ngrok http 18790
```
Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**.
@@ -444,19 +450,20 @@ picoclaw gateway
> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc.
-> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook.
+> **Docker Compose**: Nếu bạn cần mở port webhook cục bộ, hãy thêm một rule chuyển tiếp từ port Gateway (mặc định 18790) tới host. Lưu ý: LINE webhook được phục vụ bởi Gateway HTTP chung (mặc định 127.0.0.1:18790).
WeCom (WeChat Work)
-PicoClaw hỗ trợ hai loại tích hợp WeCom:
+PicoClaw hỗ trợ ba loại tích hợp WeCom:
-**Tùy chọn 1: WeCom Bot (Robot Thông minh)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm
-**Tùy chọn 2: WeCom App (Ứng dụng Tự xây dựng)** - Nhiều tính năng hơn, nhắn tin chủ động
+**Tùy chọn 1: WeCom Bot (Robot)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm
+**Tùy chọn 2: WeCom App (Ứng dụng Tùy chỉnh)** - Nhiều tính năng hơn, nhắn tin chủ động, chỉ chat riêng tư
+**Tùy chọn 3: WeCom AI Bot (Bot Thông Minh)** - Bot AI chính thức, phản hồi streaming, hỗ trợ nhóm và riêng tư
-Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) để biết hướng dẫn chi tiết.
+Xem [Hướng dẫn Cấu hình WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) để biết hướng dẫn chi tiết.
**Thiết lập Nhanh - WeCom Bot:**
@@ -475,8 +482,6 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": []
}
@@ -484,6 +489,8 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
}
```
+> **Lưu ý:** Các endpoint webhook của WeCom Bot được phục vụ bởi máy chủ Gateway HTTP dùng chung (mặc định 127.0.0.1:18790). Nếu bạn cần truy cập từ bên ngoài, hãy cấu hình reverse proxy hoặc mở cổng Gateway tương ứng.
+
**Thiết lập Nhanh - WeCom App:**
**1. Tạo ứng dụng**
@@ -495,7 +502,7 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
**2. Cấu hình nhận tin nhắn**
* Trong chi tiết ứng dụng, nhấp vào "Nhận Tin nhắn" → "Thiết lập API"
-* Đặt URL thành `http://your-server:18792/webhook/wecom-app`
+* Đặt URL thành `http://your-server:18790/webhook/wecom-app`
* Tạo **Token** và **EncodingAESKey**
**3. Cấu hình**
@@ -510,8 +517,6 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": []
}
@@ -525,7 +530,40 @@ Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) đ
picoclaw gateway
```
-> **Lưu ý**: WeCom App yêu cầu mở cổng 18792 cho callback webhook. Sử dụng proxy ngược cho HTTPS trong môi trường sản xuất.
+> **Lưu ý**: WeCom App callback webhook được phục vụ bởi Gateway HTTP chung (mặc định 127.0.0.1:18790). Sử dụng proxy ngược để cung cấp HTTPS trong môi trường production nếu cần.
+
+**Thiết lập Nhanh - WeCom AI Bot:**
+
+**1. Tạo AI Bot**
+
+* Truy cập Bảng điều khiển Quản trị WeCom → Quản lý Ứng dụng → AI Bot
+* Cấu hình URL callback: `http://your-server:18791/webhook/wecom-aibot`
+* Sao chép **Token** và tạo **EncodingAESKey**
+
+**2. Cấu hình**
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "Xin chào! Tôi có thể giúp gì cho bạn?"
+ }
+ }
+}
+```
+
+**3. Chạy**
+
+```bash
+picoclaw gateway
+```
+
+> **Lưu ý**: WeCom AI Bot sử dụng giao thức pull streaming — không lo timeout phản hồi. Tác vụ dài (>5,5 phút) tự động chuyển sang gửi qua `response_url`.
@@ -539,6 +577,31 @@ Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một
File cấu hình: `~/.picoclaw/config.json`
+### Biến môi trường
+
+Bạn có thể ghi đè các đường dẫn mặc định bằng cách sử dụng các biến môi trường. Điều này hữu ích cho việc cài đặt di động, triển khai container hóa hoặc chạy picoclaw như một dịch vụ hệ thống. Các biến này độc lập và kiểm soát các đường dẫn khác nhau.
+
+| Biến | Mô tả | Đường dẫn mặc định |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | Ghi đè đường dẫn đến file cấu hình. Điều này trực tiếp yêu cầu picoclaw tải file `config.json` nào, bỏ qua tất cả các vị trí khác. | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | Ghi đè thư mục gốc cho dữ liệu picoclaw. Điều này thay đổi vị trí mặc định của `workspace` và các thư mục dữ liệu khác. | `~/.picoclaw` |
+
+**Ví dụ:**
+
+```bash
+# Chạy picoclaw bằng một file cấu hình cụ thể
+# Đường dẫn workspace sẽ được đọc từ trong file cấu hình đó
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# Chạy picoclaw với tất cả dữ liệu được lưu trữ trong /opt/picoclaw
+# Cấu hình sẽ được tải từ ~/.picoclaw/config.json mặc định
+# Workspace sẽ được tạo tại /opt/picoclaw/workspace
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# Sử dụng cả hai để có thiết lập tùy chỉnh hoàn toàn
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### Cấu trúc Workspace
PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`):
@@ -732,7 +795,7 @@ Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và
### Nhà cung cấp (Providers)
> [!NOTE]
-> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản.
+> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn âm thanh từ bất kỳ kênh nào sẽ được tự động chuyển thành văn bản ở cấp độ agent.
| Nhà cung cấp | Mục đích | Lấy API Key |
| --- | --- | --- |
@@ -944,6 +1007,17 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
```
> Chạy `picoclaw auth login --provider anthropic` để thiết lập thông tin xác thực OAuth.
+**Proxy/API tùy chỉnh**
+```json
+{
+ "model_name": "my-custom-model",
+ "model": "openai/custom-model",
+ "api_base": "https://my-proxy.com/v1",
+ "api_key": "sk-...",
+ "request_timeout": 300
+}
+```
+
#### Cân bằng Tải tải
Định cấu hình nhiều endpoint cho cùng một tên mô hình—PicoClaw sẽ tự động phân phối round-robin giữa chúng:
diff --git a/README.zh.md b/README.zh.md
index 4f4bde46a..bd90173f9 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -166,41 +166,43 @@ make install
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
-# 2. 设置 API Key
-cp config/config.example.json config/config.json
-vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等
+# 2. 首次运行 — 自动生成 docker/data/config.json 后退出
+docker compose -f docker/docker-compose.yml --profile gateway up
+# 容器打印 "First-run setup complete." 后自动停止
-# 3. 构建并启动
-docker compose --profile gateway up -d
+# 3. 填写 API Key 等配置
+vim docker/data/config.json # 设置 provider API key、Bot Token 等
+
+# 4. 正式启动
+docker compose -f docker/docker-compose.yml --profile gateway up -d
+```
> [!TIP]
-**Docker 用户**: 默认情况下, Gateway监听 `127.0.0.1`,这使得这个端口未暴露到容器外。如果你需要通过端口映射访问健康检查接口, 请在环境变量中设置 `PICOCLAW_GATEWAY_HOST=0.0.0.0` 或修改 `config.json`。
+> **Docker 用户**: 默认情况下, Gateway 监听 `127.0.0.1`,该端口不会暴露到容器外。如果需要通过端口映射访问健康检查接口,请在环境变量中设置 `PICOCLAW_GATEWAY_HOST=0.0.0.0` 或修改 `config.json`。
-# 4. 查看日志
-docker compose logs -f picoclaw-gateway
-
-# 5. 停止
-docker compose --profile gateway down
+```bash
+# 5. 查看日志
+docker compose -f docker/docker-compose.yml logs -f picoclaw-gateway
+# 6. 停止
+docker compose -f docker/docker-compose.yml --profile gateway down
```
### Agent 模式 (一次性运行)
```bash
# 提问
-docker compose run --rm picoclaw-agent -m "2+2 等于几?"
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent -m "2+2 等于几?"
# 交互模式
-docker compose run --rm picoclaw-agent
-
+docker compose -f docker/docker-compose.yml run --rm picoclaw-agent
```
-### 重新构建
+### 更新镜像
```bash
-docker compose --profile gateway build --no-cache
-docker compose --profile gateway up -d
-
+docker compose -f docker/docker-compose.yml pull
+docker compose -f docker/docker-compose.yml --profile gateway up -d
```
### 🚀 快速开始
@@ -234,7 +236,8 @@ picoclaw onboard
{
"model_name": "gpt4",
"model": "openai/gpt-5.2",
- "api_key": "your-api-key"
+ "api_key": "your-api-key",
+ "request_timeout": 300
},
{
"model_name": "claude-sonnet-4.6",
@@ -263,6 +266,7 @@ picoclaw onboard
```
> **新功能**: `model_list` 配置格式支持零代码添加 provider。详见[模型配置](#模型配置-model_list)章节。
+> `request_timeout` 为可选项,单位为秒。若省略或设置为 `<= 0`,PicoClaw 使用默认超时(120 秒)。
**3. 获取 API Key**
@@ -286,6 +290,8 @@ picoclaw agent -m "2+2 等于几?"
PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方。
+> **注意**: 所有 Webhook 类渠道(LINE、WeCom 等)均挂载在同一个 Gateway HTTP 服务器上(`gateway.host`:`gateway.port`,默认 `127.0.0.1:18790`),无需为每个渠道单独配置端口。注意:飞书(Feishu)使用 WebSocket/SDK 模式,不通过该共享 HTTP webhook 服务器接收消息。
+
### 核心渠道
| 渠道 | 设置难度 | 特性说明 | 文档链接 |
@@ -295,7 +301,7 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方
| **Slack** | ⭐ 简单 | **Socket Mode** (无需公网 IP),企业级支持 | [查看文档](docs/channels/slack/README.zh.md) |
| **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](docs/channels/qq/README.zh.md) |
| **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](docs/channels/dingtalk/README.zh.md) |
-| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)和自建应用(API) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) |
+| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)、自建应用(API)和智能机器人(AI Bot) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) / [AI Bot 文档](docs/channels/wecom/wecom_aibot/README.zh.md) |
| **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](docs/channels/feishu/README.zh.md) |
| **Line** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](docs/channels/line/README.zh.md) |
| **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) |
@@ -311,6 +317,31 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方
配置文件路径: `~/.picoclaw/config.json`
+### 环境变量
+
+你可以使用环境变量覆盖默认路径。这对于便携安装、容器化部署或将 picoclaw 作为系统服务运行非常有用。这些变量是独立的,控制不同的路径。
+
+| 变量 | 描述 | 默认路径 |
+|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
+| `PICOCLAW_CONFIG` | 覆盖配置文件的路径。这直接告诉 picoclaw 加载哪个 `config.json`,忽略所有其他位置。 | `~/.picoclaw/config.json` |
+| `PICOCLAW_HOME` | 覆盖 picoclaw 数据根目录。这会更改 `workspace` 和其他数据目录的默认位置。 | `~/.picoclaw` |
+
+**示例:**
+
+```bash
+# 使用特定的配置文件运行 picoclaw
+# 工作区路径将从该配置文件中读取
+PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
+
+# 在 /opt/picoclaw 中存储所有数据运行 picoclaw
+# 配置将从默认的 ~/.picoclaw/config.json 加载
+# 工作区将在 /opt/picoclaw/workspace 创建
+PICOCLAW_HOME=/opt/picoclaw picoclaw agent
+
+# 同时使用两者进行完全自定义设置
+PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
+```
+
### 工作区布局 (Workspace Layout)
PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`):
@@ -331,6 +362,20 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work
```
+### 技能来源 (Skill Sources)
+
+默认情况下,技能会按以下顺序加载:
+
+1. `~/.picoclaw/workspace/skills`(工作区)
+2. `~/.picoclaw/skills`(全局)
+3. `/skills`(内置)
+
+在高级/测试场景下,可通过以下环境变量覆盖内置技能目录:
+
+```bash
+export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
+```
+
### 心跳 / 周期性任务 (Heartbeat)
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
@@ -414,7 +459,7 @@ Agent 读取 HEARTBEAT.md
### 提供商 (Providers)
> [!NOTE]
-> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。
+> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,任意渠道的音频消息都将在 Agent 层面自动转录为文字。
| 提供商 | 用途 | 获取 API Key |
| -------------------- | ---------------------------- | -------------------------------------------------------------------- |
@@ -550,7 +595,8 @@ Agent 读取 HEARTBEAT.md
"model_name": "my-custom-model",
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
- "api_key": "sk-..."
+ "api_key": "sk-...",
+ "request_timeout": 300
}
```
diff --git a/assets/picoclaw_detect_person.mp4 b/assets/picoclaw_detect_person.mp4
deleted file mode 100644
index b56999689..000000000
Binary files a/assets/picoclaw_detect_person.mp4 and /dev/null differ
diff --git a/assets/wechat.png b/assets/wechat.png
index 776c07885..32998c122 100644
Binary files a/assets/wechat.png and b/assets/wechat.png differ
diff --git a/cmd/picoclaw-launcher-tui/internal/config/store.go b/cmd/picoclaw-launcher-tui/internal/config/store.go
new file mode 100644
index 000000000..0236de19f
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/config/store.go
@@ -0,0 +1,49 @@
+package configstore
+
+import (
+ "errors"
+ "os"
+ "path/filepath"
+
+ picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
+)
+
+const (
+ configDirName = ".picoclaw"
+ configFileName = "config.json"
+)
+
+func ConfigPath() (string, error) {
+ dir, err := ConfigDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(dir, configFileName), nil
+}
+
+func ConfigDir() (string, error) {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(home, configDirName), nil
+}
+
+func Load() (*picoclawconfig.Config, error) {
+ path, err := ConfigPath()
+ if err != nil {
+ return nil, err
+ }
+ return picoclawconfig.LoadConfig(path)
+}
+
+func Save(cfg *picoclawconfig.Config) error {
+ if cfg == nil {
+ return errors.New("config is nil")
+ }
+ path, err := ConfigPath()
+ if err != nil {
+ return err
+ }
+ return picoclawconfig.SaveConfig(path, cfg)
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/app.go b/cmd/picoclaw-launcher-tui/internal/ui/app.go
new file mode 100644
index 000000000..4947d6aea
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/app.go
@@ -0,0 +1,506 @@
+package ui
+
+import (
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "github.com/gdamore/tcell/v2"
+ "github.com/rivo/tview"
+
+ configstore "github.com/sipeed/picoclaw/cmd/picoclaw-launcher-tui/internal/config"
+ picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
+)
+
+type appState struct {
+ app *tview.Application
+ pages *tview.Pages
+ stack []string
+ config *picoclawconfig.Config
+ configPath string
+ gatewayCmd *exec.Cmd
+ menus map[string]*Menu
+ original []byte
+ hasOriginal bool
+ backupPath string
+ dirty bool
+ logPath string
+}
+
+func Run() error {
+ applyStyles()
+ cfg, err := configstore.Load()
+ if err != nil {
+ return err
+ }
+ path, err := configstore.ConfigPath()
+ if err != nil {
+ return err
+ }
+
+ if cfg == nil {
+ cfg = picoclawconfig.DefaultConfig()
+ }
+
+ originalData, hasOriginal := loadOriginalConfig(path)
+ backupPath := path + ".bak"
+ if hasOriginal {
+ _ = writeBackupConfig(backupPath, originalData)
+ }
+
+ logPath := filepath.Join(filepath.Dir(path), "gateway.log")
+ state := &appState{
+ app: tview.NewApplication(),
+ pages: tview.NewPages(),
+ config: cfg,
+ configPath: path,
+ menus: map[string]*Menu{},
+ original: originalData,
+ hasOriginal: hasOriginal,
+ backupPath: backupPath,
+ logPath: logPath,
+ }
+
+ state.push("main", state.mainMenu())
+
+ root := tview.NewFlex().SetDirection(tview.FlexRow)
+ root.AddItem(bannerView(), 6, 0, false)
+ root.AddItem(state.pages, 0, 1, true)
+
+ if err := state.app.SetRoot(root, true).EnableMouse(false).Run(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *appState) push(name string, primitive tview.Primitive) {
+ s.pages.AddPage(name, primitive, true, true)
+ s.stack = append(s.stack, name)
+ s.pages.SwitchToPage(name)
+ if menu, ok := primitive.(*Menu); ok {
+ s.menus[name] = menu
+ }
+}
+
+func (s *appState) pop() {
+ if len(s.stack) == 0 {
+ return
+ }
+ last := s.stack[len(s.stack)-1]
+ s.pages.RemovePage(last)
+ s.stack = s.stack[:len(s.stack)-1]
+ if len(s.stack) == 0 {
+ s.app.Stop()
+ return
+ }
+ current := s.stack[len(s.stack)-1]
+ s.pages.SwitchToPage(current)
+ if menu, ok := s.menus[current]; ok {
+ s.refreshMenu(current, menu)
+ }
+}
+
+func (s *appState) mainMenu() tview.Primitive {
+ menu := NewMenu("Config Menu", nil)
+ refreshMainMenu(menu, s)
+ menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ switch event.Key() {
+ case tcell.KeyEsc:
+ s.requestExit()
+ return nil
+ }
+ if event.Rune() == 'q' {
+ s.requestExit()
+ return nil
+ }
+ return event
+ })
+
+ return menu
+}
+
+func (s *appState) refreshMenu(name string, menu *Menu) {
+ switch name {
+ case "main":
+ refreshMainMenu(menu, s)
+ case "model":
+ refreshModelMenuFromState(menu, s)
+ case "channel":
+ refreshChannelMenuFromState(menu, s)
+ }
+}
+
+func refreshMainMenuIfPresent(s *appState) {
+ if menu, ok := s.menus["main"]; ok {
+ refreshMainMenu(menu, s)
+ }
+}
+
+func refreshMainMenu(menu *Menu, s *appState) {
+ selectedModel := s.selectedModelName()
+ modelReady := selectedModel != ""
+ channelReady := s.hasEnabledChannel()
+ gatewayRunning := s.gatewayCmd != nil || s.isGatewayRunning()
+
+ gatewayLabel := "Start Gateway"
+ gatewayDescription := "Launch gateway for channels"
+ if gatewayRunning {
+ gatewayLabel = "Stop Gateway"
+ gatewayDescription = "Gateway running"
+ }
+
+ items := []MenuItem{
+ {
+ Label: rootModelLabel(selectedModel),
+ Description: rootModelDescription(selectedModel),
+ Action: func() {
+ s.push("model", s.modelMenu())
+ },
+ MainColor: func() *tcell.Color {
+ if modelReady {
+ return nil
+ }
+ color := tcell.ColorGray
+ return &color
+ }(),
+ },
+ {
+ Label: rootChannelLabel(channelReady),
+ Description: rootChannelDescription(channelReady),
+ Action: func() {
+ s.push("channel", s.channelMenu())
+ },
+ MainColor: func() *tcell.Color {
+ if channelReady {
+ return nil
+ }
+ color := tcell.ColorGray
+ return &color
+ }(),
+ },
+ {
+ Label: "Start Talk",
+ Description: "Open picoclaw agent in terminal",
+ Action: func() {
+ s.requestStartTalk()
+ },
+ Disabled: !modelReady,
+ },
+ {
+ Label: gatewayLabel,
+ Description: gatewayDescription,
+ Action: func() {
+ if gatewayRunning {
+ s.stopGateway()
+ } else {
+ s.requestStartGateway()
+ }
+ refreshMainMenu(menu, s)
+ },
+ Disabled: !gatewayRunning && (!modelReady || !channelReady),
+ },
+ {
+ Label: "View Gateway Log",
+ Description: "Open gateway.log",
+ Action: func() {
+ s.viewGatewayLog()
+ },
+ },
+ {
+ Label: "Exit",
+ Description: "Exit the TUI",
+ Action: func() {
+ s.requestExit()
+ },
+ },
+ }
+ menu.applyItems(items)
+}
+
+func (s *appState) applyChangesValidated() bool {
+ if err := s.config.ValidateModelList(); err != nil {
+ s.showMessage("Validation failed", err.Error())
+ return false
+ }
+ if err := s.validateAgentModel(); err != nil {
+ s.showMessage("Validation failed", err.Error())
+ return false
+ }
+ if err := configstore.Save(s.config); err != nil {
+ s.showMessage("Save failed", err.Error())
+ return false
+ }
+ if data, err := os.ReadFile(s.configPath); err == nil {
+ s.original = data
+ s.hasOriginal = true
+ _ = writeBackupConfig(s.backupPath, data)
+ }
+ return true
+}
+
+func (s *appState) requestExit() {
+ if s.dirty {
+ s.confirmApplyOrDiscard(func() {
+ s.app.Stop()
+ }, func() {
+ s.discardChanges()
+ s.app.Stop()
+ })
+ return
+ }
+ s.app.Stop()
+}
+
+func (s *appState) requestStartTalk() {
+ if s.dirty {
+ s.confirmApplyOrDiscard(func() {
+ s.startTalk()
+ }, func() {
+ s.startTalk()
+ })
+ return
+ }
+ s.startTalk()
+}
+
+func (s *appState) requestStartGateway() {
+ if s.dirty {
+ s.confirmApplyOrDiscard(func() {
+ s.startGateway()
+ }, func() {
+ s.startGateway()
+ })
+ return
+ }
+ s.startGateway()
+}
+
+func (s *appState) viewGatewayLog() {
+ data, err := os.ReadFile(s.logPath)
+ if err != nil {
+ s.showMessage("Log not found", "gateway.log not found")
+ return
+ }
+ text := tview.NewTextView()
+ text.SetBorder(true).SetTitle("Gateway Log")
+ text.SetText(string(data))
+ text.SetDoneFunc(func(key tcell.Key) {
+ s.pages.RemovePage("log")
+ })
+ text.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ if event.Key() == tcell.KeyEsc {
+ s.pages.RemovePage("log")
+ return nil
+ }
+ return event
+ })
+ s.pages.AddPage("log", text, true, true)
+}
+
+func (s *appState) selectedModelName() string {
+ modelName := strings.TrimSpace(s.config.Agents.Defaults.Model)
+ if modelName == "" {
+ return ""
+ }
+ if !s.isActiveModelValid() {
+ return ""
+ }
+ return modelName
+}
+
+func rootModelLabel(selected string) string {
+ if selected == "" {
+ return "Model (no model selected)"
+ }
+ return "Model (" + selected + ")"
+}
+
+func rootModelDescription(selected string) string {
+ if selected == "" {
+ return "no model selected"
+ }
+ return "selected"
+}
+
+func rootChannelLabel(valid bool) string {
+ if !valid {
+ return "Channel (no channel enabled)"
+ }
+ return "Channel"
+}
+
+func rootChannelDescription(valid bool) string {
+ if !valid {
+ return "no channel enabled"
+ }
+ return "enabled"
+}
+
+func (s *appState) startTalk() {
+ if !s.isActiveModelValid() {
+ s.showMessage("Model required", "Select a valid model before starting talk")
+ return
+ }
+ if !s.applyChangesValidated() {
+ return
+ }
+ s.app.Suspend(func() {
+ cmd := exec.Command("picoclaw", "agent")
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ _ = cmd.Run()
+ })
+}
+
+func (s *appState) startGateway() {
+ if !s.isActiveModelValid() {
+ s.showMessage("Model required", "Select a valid model before starting gateway")
+ return
+ }
+ if !s.hasEnabledChannel() {
+ s.showMessage("Channel required", "Enable at least one channel before starting gateway")
+ return
+ }
+ if !s.applyChangesValidated() {
+ return
+ }
+ _ = stopGatewayProcess()
+ cmd := exec.Command("picoclaw", "gateway")
+ logFile, err := os.OpenFile(s.logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ s.showMessage("Gateway failed", err.Error())
+ return
+ }
+ cmd.Stdout = logFile
+ cmd.Stderr = logFile
+ if err := cmd.Start(); err != nil {
+ s.showMessage("Gateway failed", err.Error())
+ _ = logFile.Close()
+ return
+ }
+ _ = logFile.Close()
+ s.gatewayCmd = cmd
+}
+
+func (s *appState) stopGateway() {
+ _ = stopGatewayProcess()
+ if s.gatewayCmd != nil && s.gatewayCmd.Process != nil {
+ _ = s.gatewayCmd.Process.Kill()
+ }
+ s.gatewayCmd = nil
+}
+
+func (s *appState) isGatewayRunning() bool {
+ return isGatewayProcessRunning()
+}
+
+func (s *appState) validateAgentModel() error {
+ modelName := strings.TrimSpace(s.config.Agents.Defaults.Model)
+ if modelName == "" {
+ return nil
+ }
+ _, err := s.config.GetModelConfig(modelName)
+ return err
+}
+
+func (s *appState) isActiveModelValid() bool {
+ modelName := strings.TrimSpace(s.config.Agents.Defaults.Model)
+ if modelName == "" {
+ return false
+ }
+ cfg, err := s.config.GetModelConfig(modelName)
+ if err != nil {
+ return false
+ }
+ hasKey := strings.TrimSpace(cfg.APIKey) != "" || strings.TrimSpace(cfg.AuthMethod) == "oauth"
+ hasModel := strings.TrimSpace(cfg.Model) != ""
+ return hasKey && hasModel
+}
+
+func (s *appState) hasEnabledChannel() bool {
+ c := s.config.Channels
+ return c.Telegram.Enabled || c.Discord.Enabled || c.QQ.Enabled || c.MaixCam.Enabled ||
+ c.WhatsApp.Enabled || c.Feishu.Enabled || c.DingTalk.Enabled || c.Slack.Enabled ||
+ c.LINE.Enabled || c.OneBot.Enabled || c.WeCom.Enabled || c.WeComApp.Enabled
+}
+
+func (s *appState) confirmApplyOrDiscard(onApply func(), onDiscard func()) {
+ if s.pages.HasPage("apply") {
+ return
+ }
+ modal := tview.NewModal().
+ SetText("Apply changes or discard before continuing?").
+ AddButtons([]string{"Cancel", "Discard", "Apply"}).
+ SetDoneFunc(func(buttonIndex int, buttonLabel string) {
+ s.pages.RemovePage("apply")
+ switch buttonLabel {
+ case "Discard":
+ s.discardChanges()
+ if onDiscard != nil {
+ onDiscard()
+ }
+ case "Apply":
+ if s.applyChangesValidated() {
+ s.dirty = false
+ if onApply != nil {
+ onApply()
+ }
+ }
+ }
+ })
+ modal.SetBorder(true)
+ s.pages.AddPage("apply", modal, true, true)
+}
+
+func (s *appState) discardChanges() {
+ if s.hasOriginal {
+ _ = writeOriginalConfig(s.configPath, s.original)
+ } else {
+ _ = os.Remove(s.configPath)
+ }
+ _ = os.Remove(s.backupPath)
+ if cfg, err := configstore.Load(); err == nil && cfg != nil {
+ s.config = cfg
+ }
+ s.dirty = false
+ refreshMainMenuIfPresent(s)
+}
+
+func (s *appState) showMessage(title, message string) {
+ if s.pages.HasPage("message") {
+ return
+ }
+ modal := tview.NewModal().
+ SetText(strings.TrimSpace(message)).
+ AddButtons([]string{"OK"}).
+ SetDoneFunc(func(_ int, _ string) {
+ s.pages.RemovePage("message")
+ })
+ modal.SetTitle(title).SetBorder(true)
+ modal.SetBackgroundColor(tview.Styles.ContrastBackgroundColor)
+ modal.SetTextColor(tview.Styles.PrimaryTextColor)
+ modal.SetButtonBackgroundColor(tcell.NewRGBColor(112, 102, 255))
+ modal.SetButtonTextColor(tview.Styles.PrimaryTextColor)
+ s.pages.AddPage("message", modal, true, true)
+}
+
+func loadOriginalConfig(path string) ([]byte, bool) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, false
+ }
+ return nil, false
+ }
+ return data, true
+}
+
+func writeOriginalConfig(path string, data []byte) error {
+ return os.WriteFile(path, data, 0o600)
+}
+
+func writeBackupConfig(path string, data []byte) error {
+ return os.WriteFile(path, data, 0o600)
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/channel.go b/cmd/picoclaw-launcher-tui/internal/ui/channel.go
new file mode 100644
index 000000000..49a6ccc5d
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/channel.go
@@ -0,0 +1,410 @@
+package ui
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/gdamore/tcell/v2"
+ "github.com/rivo/tview"
+
+ picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func (s *appState) buildChannelMenuItems() []MenuItem {
+ return []MenuItem{
+ {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
+ channelItem(
+ "Telegram",
+ "Telegram bot settings",
+ s.config.Channels.Telegram.Enabled,
+ func() { s.push("channel-telegram", s.telegramForm()) },
+ ),
+ channelItem(
+ "Discord",
+ "Discord bot settings",
+ s.config.Channels.Discord.Enabled,
+ func() { s.push("channel-discord", s.discordForm()) },
+ ),
+ channelItem(
+ "QQ",
+ "QQ bot settings",
+ s.config.Channels.QQ.Enabled,
+ func() { s.push("channel-qq", s.qqForm()) },
+ ),
+ channelItem(
+ "MaixCam",
+ "MaixCam gateway",
+ s.config.Channels.MaixCam.Enabled,
+ func() { s.push("channel-maixcam", s.maixcamForm()) },
+ ),
+ channelItem(
+ "WhatsApp",
+ "WhatsApp bridge",
+ s.config.Channels.WhatsApp.Enabled,
+ func() { s.push("channel-whatsapp", s.whatsappForm()) },
+ ),
+ channelItem(
+ "Feishu",
+ "Feishu bot settings",
+ s.config.Channels.Feishu.Enabled,
+ func() { s.push("channel-feishu", s.feishuForm()) },
+ ),
+ channelItem(
+ "DingTalk",
+ "DingTalk bot settings",
+ s.config.Channels.DingTalk.Enabled,
+ func() { s.push("channel-dingtalk", s.dingtalkForm()) },
+ ),
+ channelItem(
+ "Slack",
+ "Slack bot settings",
+ s.config.Channels.Slack.Enabled,
+ func() { s.push("channel-slack", s.slackForm()) },
+ ),
+ channelItem(
+ "LINE",
+ "LINE bot settings",
+ s.config.Channels.LINE.Enabled,
+ func() { s.push("channel-line", s.lineForm()) },
+ ),
+ channelItem(
+ "OneBot",
+ "OneBot settings",
+ s.config.Channels.OneBot.Enabled,
+ func() { s.push("channel-onebot", s.onebotForm()) },
+ ),
+ channelItem(
+ "WeCom",
+ "WeCom bot settings",
+ s.config.Channels.WeCom.Enabled,
+ func() { s.push("channel-wecom", s.wecomForm()) },
+ ),
+ channelItem(
+ "WeCom App",
+ "WeCom App settings",
+ s.config.Channels.WeComApp.Enabled,
+ func() { s.push("channel-wecomapp", s.wecomAppForm()) },
+ ),
+ }
+}
+
+func (s *appState) channelMenu() tview.Primitive {
+ menu := NewMenu("Channels", s.buildChannelMenuItems())
+ menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ if event.Key() == tcell.KeyEsc {
+ s.pop()
+ return nil
+ }
+ if event.Rune() == 'q' {
+ s.pop()
+ return nil
+ }
+ return event
+ })
+ return menu
+}
+
+func refreshChannelMenuFromState(menu *Menu, s *appState) {
+ menu.applyItems(s.buildChannelMenuItems())
+}
+
+func (s *appState) telegramForm() tview.Primitive {
+ cfg := &s.config.Channels.Telegram
+ form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
+ cfg.Token = strings.TrimSpace(text)
+ })
+ form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) {
+ cfg.Proxy = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) discordForm() tview.Primitive {
+ cfg := &s.config.Channels.Discord
+ form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
+ cfg.Token = strings.TrimSpace(text)
+ })
+ form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) {
+ cfg.MentionOnly = checked
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) qqForm() tview.Primitive {
+ cfg := &s.config.Channels.QQ
+ form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
+ cfg.AppID = strings.TrimSpace(text)
+ })
+ form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) {
+ cfg.AppSecret = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) maixcamForm() tview.Primitive {
+ cfg := &s.config.Channels.MaixCam
+ form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Host", cfg.Host, 64, nil, func(text string) {
+ cfg.Host = strings.TrimSpace(text)
+ })
+ addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) whatsappForm() tview.Primitive {
+ cfg := &s.config.Channels.WhatsApp
+ form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) {
+ cfg.BridgeURL = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) feishuForm() tview.Primitive {
+ cfg := &s.config.Channels.Feishu
+ form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) {
+ cfg.AppID = strings.TrimSpace(text)
+ })
+ form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text string) {
+ cfg.AppSecret = strings.TrimSpace(text)
+ })
+ form.AddInputField("Encrypt Key", cfg.EncryptKey, 128, nil, func(text string) {
+ cfg.EncryptKey = strings.TrimSpace(text)
+ })
+ form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) {
+ cfg.VerificationToken = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) dingtalkForm() tview.Primitive {
+ cfg := &s.config.Channels.DingTalk
+ form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) {
+ cfg.ClientID = strings.TrimSpace(text)
+ })
+ form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) {
+ cfg.ClientSecret = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) slackForm() tview.Primitive {
+ cfg := &s.config.Channels.Slack
+ form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Bot Token", cfg.BotToken, 128, nil, func(text string) {
+ cfg.BotToken = strings.TrimSpace(text)
+ })
+ form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) {
+ cfg.AppToken = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) lineForm() tview.Primitive {
+ cfg := &s.config.Channels.LINE
+ form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) {
+ cfg.ChannelSecret = strings.TrimSpace(text)
+ })
+ form.AddInputField("Channel Access Token", cfg.ChannelAccessToken, 128, nil, func(text string) {
+ cfg.ChannelAccessToken = strings.TrimSpace(text)
+ })
+ form.AddInputField("Webhook Host", cfg.WebhookHost, 64, nil, func(text string) {
+ cfg.WebhookHost = strings.TrimSpace(text)
+ })
+ addIntField(form, "Webhook Port", cfg.WebhookPort, func(value int) { cfg.WebhookPort = value })
+ form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
+ cfg.WebhookPath = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) onebotForm() tview.Primitive {
+ cfg := &s.config.Channels.OneBot
+ form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) {
+ cfg.WSUrl = strings.TrimSpace(text)
+ })
+ form.AddInputField("Access Token", cfg.AccessToken, 128, nil, func(text string) {
+ cfg.AccessToken = strings.TrimSpace(text)
+ })
+ addIntField(
+ form,
+ "Reconnect Interval",
+ cfg.ReconnectInterval,
+ func(value int) { cfg.ReconnectInterval = value },
+ )
+ form.AddInputField(
+ "Group Trigger Prefix",
+ strings.Join(cfg.GroupTriggerPrefix, ","),
+ 128,
+ nil,
+ func(text string) {
+ cfg.GroupTriggerPrefix = splitCSV(text)
+ },
+ )
+ addAllowFromField(form, &cfg.AllowFrom)
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) wecomForm() tview.Primitive {
+ cfg := &s.config.Channels.WeCom
+ form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
+ cfg.Token = strings.TrimSpace(text)
+ })
+ form.AddInputField("Encoding AES Key", cfg.EncodingAESKey, 128, nil, func(text string) {
+ cfg.EncodingAESKey = strings.TrimSpace(text)
+ })
+ form.AddInputField("Webhook URL", cfg.WebhookURL, 128, nil, func(text string) {
+ cfg.WebhookURL = strings.TrimSpace(text)
+ })
+ form.AddInputField("Webhook Host", cfg.WebhookHost, 64, nil, func(text string) {
+ cfg.WebhookHost = strings.TrimSpace(text)
+ })
+ addIntField(form, "Webhook Port", cfg.WebhookPort, func(value int) { cfg.WebhookPort = value })
+ form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
+ cfg.WebhookPath = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ addIntField(
+ form,
+ "Reply Timeout",
+ cfg.ReplyTimeout,
+ func(value int) { cfg.ReplyTimeout = value },
+ )
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) wecomAppForm() tview.Primitive {
+ cfg := &s.config.Channels.WeComApp
+ form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled))
+ form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) {
+ cfg.CorpID = strings.TrimSpace(text)
+ })
+ form.AddInputField("Corp Secret", cfg.CorpSecret, 128, nil, func(text string) {
+ cfg.CorpSecret = strings.TrimSpace(text)
+ })
+ addInt64Field(form, "Agent ID", cfg.AgentID, func(value int64) { cfg.AgentID = value })
+ form.AddInputField("Token", cfg.Token, 128, nil, func(text string) {
+ cfg.Token = strings.TrimSpace(text)
+ })
+ form.AddInputField("Encoding AES Key", cfg.EncodingAESKey, 128, nil, func(text string) {
+ cfg.EncodingAESKey = strings.TrimSpace(text)
+ })
+ form.AddInputField("Webhook Host", cfg.WebhookHost, 64, nil, func(text string) {
+ cfg.WebhookHost = strings.TrimSpace(text)
+ })
+ addIntField(form, "Webhook Port", cfg.WebhookPort, func(value int) { cfg.WebhookPort = value })
+ form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) {
+ cfg.WebhookPath = strings.TrimSpace(text)
+ })
+ addAllowFromField(form, &cfg.AllowFrom)
+ addIntField(
+ form,
+ "Reply Timeout",
+ cfg.ReplyTimeout,
+ func(value int) { cfg.ReplyTimeout = value },
+ )
+ return wrapWithBack(form, s)
+}
+
+func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) {
+ return func(v bool) {
+ *enabledPtr = v
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["channel"]; ok {
+ refreshChannelMenuFromState(menu, s)
+ }
+ }
+}
+
+func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) {
+ form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) {
+ *allowFrom = splitCSV(text)
+ })
+}
+
+func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form {
+ form := tview.NewForm()
+ form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title))
+ form.SetButtonBackgroundColor(tcell.NewRGBColor(80, 250, 123))
+ form.SetButtonTextColor(tcell.NewRGBColor(12, 13, 22))
+ form.AddCheckbox("Enabled", enabled, func(checked bool) {
+ onEnabled(checked)
+ })
+ return form
+}
+
+func wrapWithBack(form *tview.Form, s *appState) tview.Primitive {
+ form.AddButton("Back", func() {
+ s.pop()
+ })
+ form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ if event.Key() == tcell.KeyEsc {
+ s.pop()
+ return nil
+ }
+ return event
+ })
+ return form
+}
+
+func splitCSV(input string) picoclawconfig.FlexibleStringSlice {
+ parts := strings.Split(strings.TrimSpace(input), ",")
+ cleaned := make([]string, 0, len(parts))
+ for _, part := range parts {
+ value := strings.TrimSpace(part)
+ if value == "" {
+ continue
+ }
+ cleaned = append(cleaned, value)
+ }
+ return cleaned
+}
+
+func addIntField(form *tview.Form, label string, value int, onChange func(int)) {
+ form.AddInputField(label, fmt.Sprintf("%d", value), 16, nil, func(text string) {
+ var parsed int
+ if _, err := fmt.Sscanf(strings.TrimSpace(text), "%d", &parsed); err == nil {
+ onChange(parsed)
+ }
+ })
+}
+
+func addInt64Field(form *tview.Form, label string, value int64, onChange func(int64)) {
+ form.AddInputField(label, fmt.Sprintf("%d", value), 16, nil, func(text string) {
+ var parsed int64
+ if _, err := fmt.Sscanf(strings.TrimSpace(text), "%d", &parsed); err == nil {
+ onChange(parsed)
+ }
+ })
+}
+
+func channelItem(label, description string, enabled bool, action MenuAction) MenuItem {
+ item := MenuItem{
+ Label: label,
+ Description: description,
+ Action: action,
+ }
+ if !enabled {
+ color := tcell.ColorGray
+ item.MainColor = &color
+ }
+ return item
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/gateway_posix.go b/cmd/picoclaw-launcher-tui/internal/ui/gateway_posix.go
new file mode 100644
index 000000000..bc874f7f2
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/gateway_posix.go
@@ -0,0 +1,16 @@
+//go:build !windows
+// +build !windows
+
+package ui
+
+import "os/exec"
+
+func isGatewayProcessRunning() bool {
+ cmd := exec.Command("sh", "-c", "pgrep -f 'picoclaw\\s+gateway' >/dev/null 2>&1")
+ return cmd.Run() == nil
+}
+
+func stopGatewayProcess() error {
+ cmd := exec.Command("sh", "-c", "pkill -f 'picoclaw\\s+gateway' >/dev/null 2>&1")
+ return cmd.Run()
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/gateway_windows.go b/cmd/picoclaw-launcher-tui/internal/ui/gateway_windows.go
new file mode 100644
index 000000000..7067a5c13
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/gateway_windows.go
@@ -0,0 +1,16 @@
+//go:build windows
+// +build windows
+
+package ui
+
+import "os/exec"
+
+func isGatewayProcessRunning() bool {
+ cmd := exec.Command("tasklist", "/FI", "IMAGENAME eq picoclaw.exe")
+ return cmd.Run() == nil
+}
+
+func stopGatewayProcess() error {
+ cmd := exec.Command("taskkill", "/F", "/IM", "picoclaw.exe")
+ return cmd.Run()
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/menu.go b/cmd/picoclaw-launcher-tui/internal/ui/menu.go
new file mode 100644
index 000000000..9f2132c5a
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/menu.go
@@ -0,0 +1,72 @@
+package ui
+
+import (
+ "github.com/gdamore/tcell/v2"
+ "github.com/rivo/tview"
+)
+
+type MenuAction func()
+
+type MenuItem struct {
+ Label string
+ Description string
+ Action MenuAction
+ Disabled bool
+ MainColor *tcell.Color
+ DescColor *tcell.Color
+}
+
+type Menu struct {
+ *tview.Table
+ items []MenuItem
+}
+
+func NewMenu(title string, items []MenuItem) *Menu {
+ table := tview.NewTable().SetSelectable(true, false)
+ table.SetBorder(true).SetTitle(title)
+ table.SetBorders(false)
+ menu := &Menu{Table: table, items: items}
+ menu.applyItems(items)
+ menu.SetSelectedFunc(func(row, _ int) {
+ if row < 0 || row >= len(menu.items) {
+ return
+ }
+ item := menu.items[row]
+ if item.Disabled || item.Action == nil {
+ return
+ }
+ item.Action()
+ })
+ menu.SetSelectedStyle(
+ tcell.StyleDefault.Foreground(tview.Styles.InverseTextColor).
+ Background(tcell.NewRGBColor(189, 147, 249)),
+ )
+ return menu
+}
+
+func (m *Menu) applyItems(items []MenuItem) {
+ m.items = items
+ m.Clear()
+ for row, item := range items {
+ label := item.Label
+ if item.Disabled && label != "" {
+ label = label + " (disabled)"
+ }
+ left := tview.NewTableCell(label)
+ right := tview.NewTableCell(item.Description).SetAlign(tview.AlignRight)
+ if item.MainColor != nil {
+ left.SetTextColor(*item.MainColor)
+ }
+ if item.DescColor != nil {
+ right.SetTextColor(*item.DescColor)
+ } else {
+ right.SetTextColor(tview.Styles.TertiaryTextColor)
+ }
+ if item.Disabled {
+ left.SetTextColor(tcell.ColorGray)
+ right.SetTextColor(tcell.ColorGray)
+ }
+ m.SetCell(row, 0, left)
+ m.SetCell(row, 1, right)
+ }
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/model.go b/cmd/picoclaw-launcher-tui/internal/ui/model.go
new file mode 100644
index 000000000..ba91f5b09
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/model.go
@@ -0,0 +1,343 @@
+package ui
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gdamore/tcell/v2"
+ "github.com/rivo/tview"
+
+ picoclawconfig "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func (s *appState) modelMenu() tview.Primitive {
+ items := make([]MenuItem, 0, 2+len(s.config.ModelList))
+ items = append(items,
+ MenuItem{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
+ MenuItem{
+ Label: "Add model",
+ Description: "Append a new model entry",
+ Action: func() {
+ s.addModel(
+ picoclawconfig.ModelConfig{ModelName: "new-model", Model: "openai/gpt-5.2"},
+ )
+ s.push(
+ fmt.Sprintf("model-%d", len(s.config.ModelList)-1),
+ s.modelForm(len(s.config.ModelList)-1),
+ )
+ },
+ },
+ )
+ currentModel := strings.TrimSpace(s.config.Agents.Defaults.Model)
+ for i := range s.config.ModelList {
+ index := i
+ model := s.config.ModelList[i]
+ isValid := isModelValid(model)
+ desc := model.APIBase
+ if desc == "" {
+ desc = model.AuthMethod
+ }
+ if desc == "" {
+ desc = "api_key required"
+ }
+ label := fmt.Sprintf("%s (%s)", model.ModelName, model.Model)
+ if model.ModelName == currentModel && currentModel != "" {
+ label = "* " + label
+ }
+ isSelected := model.ModelName == currentModel && currentModel != ""
+ items = append(items, MenuItem{
+ Label: label,
+ Description: desc,
+ MainColor: modelStatusColor(isValid, isSelected),
+ Action: func() {
+ s.push(fmt.Sprintf("model-%d", index), s.modelForm(index))
+ },
+ })
+ }
+
+ menu := NewMenu("Models", items)
+ menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ if event.Key() == tcell.KeyEsc {
+ s.pop()
+ return nil
+ }
+ if event.Rune() == 'q' {
+ s.pop()
+ return nil
+ }
+ if event.Rune() == ' ' {
+ row, _ := menu.GetSelection()
+ if row > 0 && row <= len(s.config.ModelList) {
+ model := s.config.ModelList[row-1]
+ if !isModelValid(model) {
+ s.showMessage(
+ "Invalid model",
+ "Select a model with api_key or oauth auth_method",
+ )
+ return nil
+ }
+ s.config.Agents.Defaults.Model = model.ModelName
+ s.dirty = true
+ refreshModelMenu(menu, s.config.Agents.Defaults.Model, s.config.ModelList)
+ refreshMainMenuIfPresent(s)
+ }
+ return nil
+ }
+ return event
+ })
+ return menu
+}
+
+func (s *appState) modelForm(index int) tview.Primitive {
+ model := &s.config.ModelList[index]
+ form := tview.NewForm()
+ form.SetBorder(true).SetTitle(fmt.Sprintf("Model: %s", model.ModelName))
+ form.SetButtonBackgroundColor(tcell.NewRGBColor(80, 250, 123))
+ form.SetButtonTextColor(tcell.NewRGBColor(12, 13, 22))
+
+ addInput(form, "Model Name", model.ModelName, func(value string) {
+ model.ModelName = value
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["model"]; ok {
+ refreshModelMenuFromState(menu, s)
+ }
+ })
+ addInput(form, "Model", model.Model, func(value string) {
+ model.Model = value
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["model"]; ok {
+ refreshModelMenuFromState(menu, s)
+ }
+ })
+ addInput(form, "API Base", model.APIBase, func(value string) {
+ model.APIBase = value
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["model"]; ok {
+ refreshModelMenuFromState(menu, s)
+ }
+ })
+ addInput(form, "API Key", model.APIKey, func(value string) {
+ model.APIKey = value
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["model"]; ok {
+ refreshModelMenuFromState(menu, s)
+ }
+ })
+ addInput(form, "Proxy", model.Proxy, func(value string) {
+ model.Proxy = value
+ })
+ addInput(form, "Auth Method", model.AuthMethod, func(value string) {
+ model.AuthMethod = value
+ s.dirty = true
+ refreshMainMenuIfPresent(s)
+ if menu, ok := s.menus["model"]; ok {
+ refreshModelMenuFromState(menu, s)
+ }
+ })
+ addInput(form, "Connect Mode", model.ConnectMode, func(value string) {
+ model.ConnectMode = value
+ })
+ addInput(form, "Workspace", model.Workspace, func(value string) {
+ model.Workspace = value
+ })
+ addInput(form, "Max Tokens Field", model.MaxTokensField, func(value string) {
+ model.MaxTokensField = value
+ })
+ addIntInput(form, "RPM", model.RPM, func(value int) {
+ model.RPM = value
+ })
+ addIntInput(form, "Request Timeout", model.RequestTimeout, func(value int) {
+ model.RequestTimeout = value
+ })
+
+ form.AddButton("Delete", func() {
+ s.deleteModel(index)
+ })
+ form.AddButton("Test", func() {
+ s.testModel(model)
+ })
+ form.AddButton("Back", func() {
+ s.pop()
+ })
+
+ form.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey {
+ if event.Key() == tcell.KeyEsc {
+ s.pop()
+ return nil
+ }
+ return event
+ })
+ return form
+}
+
+func addInput(form *tview.Form, label, value string, onChange func(string)) {
+ form.AddInputField(label, value, 128, nil, func(text string) {
+ onChange(strings.TrimSpace(text))
+ })
+}
+
+func addIntInput(form *tview.Form, label string, value int, onChange func(int)) {
+ form.AddInputField(label, fmt.Sprintf("%d", value), 16, nil, func(text string) {
+ var parsed int
+ if _, err := fmt.Sscanf(strings.TrimSpace(text), "%d", &parsed); err == nil {
+ onChange(parsed)
+ }
+ })
+}
+
+func (s *appState) addModel(model picoclawconfig.ModelConfig) {
+ s.config.ModelList = append(s.config.ModelList, model)
+}
+
+func (s *appState) deleteModel(index int) {
+ if index < 0 || index >= len(s.config.ModelList) {
+ return
+ }
+ s.config.ModelList = append(s.config.ModelList[:index], s.config.ModelList[index+1:]...)
+ s.pop()
+}
+
+func modelStatusColor(valid bool, selected bool) *tcell.Color {
+ if valid {
+ color := tview.Styles.PrimaryTextColor
+ return &color
+ }
+ color := tcell.ColorGray
+ return &color
+}
+
+func refreshModelMenu(menu *Menu, currentModel string, models []picoclawconfig.ModelConfig) {
+ for i, model := range models {
+ row := i + 1
+ label := fmt.Sprintf("%s (%s)", model.ModelName, model.Model)
+ isValid := isModelValid(model)
+ if model.ModelName == currentModel && currentModel != "" {
+ label = "* " + label
+ }
+ cell := menu.GetCell(row, 0)
+ if cell != nil {
+ cell.SetText(label)
+ isSelected := model.ModelName == currentModel && currentModel != ""
+ color := modelStatusColor(isValid, isSelected)
+ if color != nil {
+ cell.SetTextColor(*color)
+ }
+ }
+ }
+}
+
+func refreshModelMenuFromState(menu *Menu, s *appState) {
+ items := make([]MenuItem, 0, 2+len(s.config.ModelList))
+ items = append(items,
+ MenuItem{Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }},
+ MenuItem{
+ Label: "Add model",
+ Description: "Append a new model entry",
+ Action: func() {
+ s.addModel(
+ picoclawconfig.ModelConfig{ModelName: "new-model", Model: "openai/gpt-5.2"},
+ )
+ s.push(
+ fmt.Sprintf("model-%d", len(s.config.ModelList)-1),
+ s.modelForm(len(s.config.ModelList)-1),
+ )
+ },
+ },
+ )
+ currentModel := strings.TrimSpace(s.config.Agents.Defaults.Model)
+ for i := range s.config.ModelList {
+ index := i
+ model := s.config.ModelList[i]
+ isValid := isModelValid(model)
+ desc := model.APIBase
+ if desc == "" {
+ desc = model.AuthMethod
+ }
+ if desc == "" {
+ desc = "api_key required"
+ }
+ label := fmt.Sprintf("%s (%s)", model.ModelName, model.Model)
+ if model.ModelName == currentModel && currentModel != "" {
+ label = "* " + label
+ }
+ isSelected := model.ModelName == currentModel && currentModel != ""
+ items = append(items, MenuItem{
+ Label: label,
+ Description: desc,
+ MainColor: modelStatusColor(isValid, isSelected),
+ Action: func() {
+ s.push(fmt.Sprintf("model-%d", index), s.modelForm(index))
+ },
+ })
+ }
+ menu.applyItems(items)
+}
+
+func isModelValid(model picoclawconfig.ModelConfig) bool {
+ hasKey := strings.TrimSpace(model.APIKey) != "" ||
+ strings.TrimSpace(model.AuthMethod) == "oauth"
+ hasModel := strings.TrimSpace(model.Model) != ""
+ return hasKey && hasModel
+}
+
+func (s *appState) testModel(model *picoclawconfig.ModelConfig) {
+ if model == nil {
+ return
+ }
+ if strings.TrimSpace(model.APIKey) == "" {
+ s.showMessage("Missing API Key", "Set api_key before testing")
+ return
+ }
+ base := strings.TrimSpace(model.APIBase)
+ if base == "" {
+ s.showMessage("Missing API Base", "Set api_base before testing")
+ return
+ }
+ modelID := strings.TrimSpace(model.Model)
+ if modelID == "" {
+ s.showMessage("Missing Model", "Set model before testing")
+ return
+ }
+ if !strings.HasPrefix(modelID, "openai/") {
+ s.showMessage("Unsupported model", "Only openai/* models are supported for test")
+ return
+ }
+ modelName := strings.TrimPrefix(modelID, "openai/")
+ endpoint := strings.TrimRight(base, "/") + "/chat/completions"
+
+ payload := fmt.Sprintf(
+ `{"model":"%s","messages":[{"role":"user","content":"ping"}],"max_tokens":1}`,
+ modelName,
+ )
+ client := &http.Client{Timeout: 10 * time.Second}
+ request, err := http.NewRequest("POST", endpoint, strings.NewReader(payload))
+ if err != nil {
+ s.showMessage("Test failed", err.Error())
+ return
+ }
+ request.Header.Set("Content-Type", "application/json")
+ request.Header.Set("Authorization", "Bearer "+strings.TrimSpace(model.APIKey))
+
+ resp, err := client.Do(request)
+ if err != nil {
+ s.showMessage("Test failed", err.Error())
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ s.showMessage("Test OK", resp.Status)
+ return
+ }
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
+ s.showMessage(
+ "Test failed",
+ fmt.Sprintf("%s: %s", resp.Status, strings.TrimSpace(string(body))),
+ )
+}
diff --git a/cmd/picoclaw-launcher-tui/internal/ui/style.go b/cmd/picoclaw-launcher-tui/internal/ui/style.go
new file mode 100644
index 000000000..68cdd60b9
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/internal/ui/style.go
@@ -0,0 +1,43 @@
+package ui
+
+import (
+ "github.com/gdamore/tcell/v2"
+ "github.com/rivo/tview"
+)
+
+const (
+ colorBlue = "[#3e5db9]"
+ colorRed = "[#d54646]"
+ banner = "\r\n[::b]" +
+ colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
+ colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
+ colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
+ colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
+ colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
+ colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
+ "[:]"
+)
+
+func applyStyles() {
+ tview.Styles.PrimitiveBackgroundColor = tcell.NewRGBColor(12, 13, 22)
+ tview.Styles.ContrastBackgroundColor = tcell.NewRGBColor(34, 19, 53)
+ tview.Styles.MoreContrastBackgroundColor = tcell.NewRGBColor(18, 18, 32)
+ tview.Styles.BorderColor = tcell.NewRGBColor(112, 102, 255)
+ tview.Styles.TitleColor = tcell.NewRGBColor(255, 121, 198)
+ tview.Styles.GraphicsColor = tcell.NewRGBColor(139, 233, 253)
+ tview.Styles.PrimaryTextColor = tcell.NewRGBColor(241, 250, 255)
+ tview.Styles.SecondaryTextColor = tcell.NewRGBColor(80, 250, 123)
+ tview.Styles.TertiaryTextColor = tcell.NewRGBColor(139, 233, 253)
+ tview.Styles.InverseTextColor = tcell.NewRGBColor(12, 13, 22)
+ tview.Styles.ContrastSecondaryTextColor = tcell.NewRGBColor(189, 147, 249)
+}
+
+func bannerView() *tview.TextView {
+ text := tview.NewTextView()
+ text.SetDynamicColors(true)
+ text.SetTextAlign(tview.AlignCenter)
+ text.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor)
+ text.SetText(banner)
+ text.SetBorder(false)
+ return text
+}
diff --git a/cmd/picoclaw-launcher-tui/main.go b/cmd/picoclaw-launcher-tui/main.go
new file mode 100644
index 000000000..0e8cce415
--- /dev/null
+++ b/cmd/picoclaw-launcher-tui/main.go
@@ -0,0 +1,15 @@
+package main
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw-launcher-tui/internal/ui"
+)
+
+func main() {
+ if err := ui.Run(); err != nil {
+ fmt.Fprintln(os.Stderr, err)
+ os.Exit(1)
+ }
+}
diff --git a/cmd/picoclaw-launcher/README.md b/cmd/picoclaw-launcher/README.md
new file mode 100644
index 000000000..641279bb1
--- /dev/null
+++ b/cmd/picoclaw-launcher/README.md
@@ -0,0 +1,290 @@
+# PicoClaw Launcher
+
+> [!WARNING]
+> This project is a temporary solution and will be refactored in the future to provide a complete web service. Therefore, the APIs in this directory are not stable.
+
+A standalone launcher for PicoClaw, providing visual JSON editing and OAuth provider authentication management.
+
+## Features
+
+- 📝 **Config Editor** — Sidebar-based settings UI with model management, channel configuration forms, and a raw JSON editor
+- 🤖 **Model Management** — Model card grid with availability status (grayed out without API key), primary model selection, add/edit/delete with required/optional field separation
+- 📡 **Channel Configuration** — Form-based settings for 12 channel types (Telegram, Discord, Slack, WeCom, DingTalk, Feishu, LINE, WhatsApp, QQ, OneBot, MaixCAM, etc.) with documentation links
+- 🔐 **Provider Auth** — Login to OpenAI (Device Code), Anthropic (API Token), Google Antigravity (Browser OAuth)
+- 🌐 **Embedded Frontend** — Compiles to a single binary with no external dependencies
+- 🌍 **i18n** — Chinese/English language switching with browser auto-detection
+- 🎨 **Theme** — Light / Dark / System theme toggle with localStorage persistence
+
+## Quick Start
+
+```bash
+# Build
+go build -o picoclaw-launcher ./cmd/picoclaw-launcher/
+
+# Run with default config path (~/.picoclaw/config.json)
+./picoclaw-launcher
+
+# Specify a config file
+./picoclaw-launcher ./config.json
+
+# Allow LAN access
+./picoclaw-launcher -public
+```
+
+Open `http://localhost:18800` in your browser.
+
+## CLI Options
+
+```
+Usage: picoclaw-config [options] [config.json]
+
+Arguments:
+ config.json Path to the configuration file (default: ~/.picoclaw/config.json)
+
+Options:
+ -public Listen on all interfaces (0.0.0.0), allowing access from other devices
+```
+
+## API Reference
+
+Base URL: `http://localhost:18800`
+
+---
+
+### Static Files
+
+#### GET /
+
+Serves the embedded frontend (`index.html`).
+
+---
+
+### Config API
+
+#### GET /api/config
+
+Reads the current configuration file.
+
+**Response** `200 OK`
+
+```json
+{
+ "config": { ... },
+ "path": "/Users/xiao/.picoclaw/config.json"
+}
+```
+
+---
+
+#### PUT /api/config
+
+Saves the configuration. The request body must be a complete Config JSON object.
+
+**Request Body** — `application/json`
+
+```json
+{
+ "agents": { "defaults": { "model_name": "gpt-5.2" } },
+ "model_list": [
+ {
+ "model_name": "gpt-5.2",
+ "model": "openai/gpt-5.2",
+ "auth_method": "oauth"
+ }
+ ]
+}
+```
+
+**Response** `200 OK`
+
+```json
+{ "status": "ok" }
+```
+
+**Error** `400 Bad Request` — Invalid JSON
+
+---
+
+### Auth API
+
+#### GET /api/auth/status
+
+Returns the authentication status of all providers and any in-progress device code login.
+
+**Response** `200 OK`
+
+```json
+{
+ "providers": [
+ {
+ "provider": "openai",
+ "auth_method": "oauth",
+ "status": "active",
+ "account_id": "user-xxx",
+ "expires_at": "2026-03-01T00:00:00Z"
+ }
+ ],
+ "pending_device": {
+ "provider": "openai",
+ "status": "pending",
+ "device_url": "https://auth.openai.com/activate",
+ "user_code": "ABCD-1234"
+ }
+}
+```
+
+`status` values: `active` | `expired` | `needs_refresh`
+
+`pending_device` is only present when a device code login is in progress.
+
+---
+
+#### POST /api/auth/login
+
+Initiates a provider login.
+
+**Request Body** — `application/json`
+
+```json
+{ "provider": "openai" }
+```
+
+Supported `provider` values: `openai` | `anthropic` | `google-antigravity`
+
+##### OpenAI (Device Code Flow)
+
+Returns device code info. The server polls for completion in the background.
+
+```json
+{
+ "status": "pending",
+ "device_url": "https://auth.openai.com/activate",
+ "user_code": "ABCD-1234",
+ "message": "Open the URL and enter the code to authenticate."
+}
+```
+
+The user opens `device_url` in a browser and enters `user_code`. Once authenticated, `GET /api/auth/status` will show `pending_device.status` as `success`.
+
+##### Anthropic (API Token)
+
+Requires a `token` field in the request:
+
+```json
+{ "provider": "anthropic", "token": "sk-ant-xxx" }
+```
+
+**Response:**
+
+```json
+{ "status": "success", "message": "Anthropic token saved" }
+```
+
+##### Google Antigravity (Browser OAuth)
+
+Returns an authorization URL for the frontend to open in a new tab:
+
+```json
+{
+ "status": "redirect",
+ "auth_url": "https://accounts.google.com/o/oauth2/auth?...",
+ "message": "Open the URL to authenticate with Google."
+}
+```
+
+After authentication, Google redirects to `GET /auth/callback`, which saves the credentials and redirects back to the picoclaw-config UI.
+
+---
+
+#### POST /api/auth/logout
+
+Logs out from a provider.
+
+**Request Body** — `application/json`
+
+```json
+{ "provider": "openai" }
+```
+
+Omit or leave `provider` empty to log out from all providers.
+
+**Response** `200 OK`
+
+```json
+{ "status": "ok" }
+```
+
+---
+
+#### GET /auth/callback
+
+OAuth browser callback endpoint (used by Google Antigravity). Called by the OAuth provider's redirect — **not invoked directly by the frontend**.
+
+**Query Parameters:**
+- `state` — OAuth state for CSRF validation
+- `code` — Authorization code
+
+On success, redirects to `/#auth`.
+
+
+### Process API
+
+#### GET /api/process/status
+
+Gets the running status of the `picoclaw gateway` process.
+
+**Response** `200 OK` (Running)
+
+```json
+{
+ "process_status": "running",
+ "status": "ok",
+ "uptime": "1.010814s"
+}
+```
+
+**Response** `200 OK` (Stopped)
+
+```json
+{
+ "process_status": "stopped",
+ "error": "Get \"http://localhost:18790/health\": dial tcp [::1]:18790: connect: connection refused"
+}
+```
+
+---
+
+#### POST /api/process/start
+
+Starts the `picoclaw gateway` process in the background.
+
+**Response** `200 OK`
+
+```json
+{
+ "status": "ok",
+ "pid": 12345
+}
+```
+
+---
+
+#### POST /api/process/stop
+
+Stops the running `picoclaw gateway` process.
+
+**Response** `200 OK`
+
+```json
+{
+ "status": "ok"
+}
+```
+
+---
+
+## Testing
+
+```bash
+go test -v ./cmd/picoclaw-launcher/
+```
diff --git a/cmd/picoclaw-launcher/README.zh.md b/cmd/picoclaw-launcher/README.zh.md
new file mode 100644
index 000000000..320de75a5
--- /dev/null
+++ b/cmd/picoclaw-launcher/README.zh.md
@@ -0,0 +1,287 @@
+# PicoClaw Launcher
+
+> [!WARNING]
+> 该项目属于临时解决方案,后续会重构并提供完整的 Web 服务,因此该目录下的接口并不稳定。
+
+PicoClaw 的独立启动器,提供可视化 JSON 配置编辑和 OAuth Provider 认证管理。
+
+## 功能
+
+- 📝 **配置编辑** — 侧边栏式设置 UI,支持模型管理、通道配置表单和原始 JSON 编辑器
+- 🤖 **模型管理** — 模型卡片网格,可用性状态显示(无 API Key 时灰色),主模型选择,增删改查,必填/选填字段分离
+- 📡 **通道配置** — 12 种通道类型(Telegram、Discord、Slack、企业微信、钉钉、飞书、LINE、WhatsApp、QQ、OneBot、MaixCAM 等)的表单化配置,附带文档链接
+- 🔐 **Provider 认证** — 支持 OpenAI (Device Code)、Anthropic (API Token)、Google Antigravity (Browser OAuth) 登录
+- 🌐 **嵌入式前端** — 编译为单一二进制文件,无需额外依赖
+- 🌍 **国际化** — 中英文切换,首次访问自动检测浏览器语言
+- 🎨 **主题** — 亮色 / 暗色 / 跟随系统,偏好保存在 localStorage
+
+## 快速开始
+
+```bash
+# 编译
+go build -o picoclaw-launcher ./cmd/picoclaw-launcher/
+
+# 运行(使用默认配置路径 ~/.picoclaw/config.json)
+./picoclaw-launcher
+
+# 指定配置文件
+./picoclaw-launcher ./config.json
+
+# 允许局域网访问
+./picoclaw-launcher -public
+```
+
+启动后在浏览器中打开 `http://localhost:18800`。
+
+## 命令行参数
+
+```
+Usage: picoclaw-launcher [options] [config.json]
+
+Arguments:
+ config.json 配置文件路径(默认: ~/.picoclaw/config.json)
+
+Options:
+ -public 监听所有网络接口(0.0.0.0),允许局域网设备访问
+```
+
+## API 文档
+
+Base URL: `http://localhost:18800`
+
+### 静态文件
+
+#### GET /
+
+提供嵌入式前端页面(`index.html`)。
+
+---
+
+### Config API
+
+#### GET /api/config
+
+读取当前配置文件内容。
+
+**Response** `200 OK`
+
+```json
+{
+ "config": { ... },
+ "path": "/Users/xiao/.picoclaw/config.json"
+}
+```
+
+---
+
+#### PUT /api/config
+
+保存配置。请求体为完整的 Config JSON。
+
+**Request Body** — `application/json`
+
+```json
+{
+ "agents": { "defaults": { "model_name": "gpt-5.2" } },
+ "model_list": [
+ {
+ "model_name": "gpt-5.2",
+ "model": "openai/gpt-5.2",
+ "auth_method": "oauth"
+ }
+ ]
+}
+```
+
+**Response** `200 OK`
+
+```json
+{ "status": "ok" }
+```
+
+**Error** `400 Bad Request` — 无效 JSON
+
+---
+
+### Auth API
+
+#### GET /api/auth/status
+
+获取所有 Provider 的认证状态和进行中的 Device Code 登录信息。
+
+**Response** `200 OK`
+
+```json
+{
+ "providers": [
+ {
+ "provider": "openai",
+ "auth_method": "oauth",
+ "status": "active",
+ "account_id": "user-xxx",
+ "expires_at": "2026-03-01T00:00:00Z"
+ }
+ ],
+ "pending_device": {
+ "provider": "openai",
+ "status": "pending",
+ "device_url": "https://auth.openai.com/activate",
+ "user_code": "ABCD-1234"
+ }
+}
+```
+
+`status` 可选值: `active` | `expired` | `needs_refresh`
+
+`pending_device` 仅在有进行中的 Device Code 登录时返回。
+
+---
+
+#### POST /api/auth/login
+
+发起 Provider 登录。
+
+**Request Body** — `application/json`
+
+```json
+{ "provider": "openai" }
+```
+
+支持的 `provider` 值: `openai` | `anthropic` | `google-antigravity`
+
+##### OpenAI (Device Code Flow)
+
+返回 Device Code 信息,后台自动轮询认证结果:
+
+```json
+{
+ "status": "pending",
+ "device_url": "https://auth.openai.com/activate",
+ "user_code": "ABCD-1234",
+ "message": "Open the URL and enter the code to authenticate."
+}
+```
+
+用户在浏览器中打开 `device_url` 并输入 `user_code`。认证完成后通过 `GET /api/auth/status` 的 `pending_device.status` 变为 `success` 通知前端。
+
+##### Anthropic (API Token)
+
+需在请求中附带 token:
+
+```json
+{ "provider": "anthropic", "token": "sk-ant-xxx" }
+```
+
+**Response:**
+
+```json
+{ "status": "success", "message": "Anthropic token saved" }
+```
+
+##### Google Antigravity (Browser OAuth)
+
+返回授权 URL,前端打开新标签页:
+
+```json
+{
+ "status": "redirect",
+ "auth_url": "https://accounts.google.com/o/oauth2/auth?...",
+ "message": "Open the URL to authenticate with Google."
+}
+```
+
+认证完成后 Google 回调至 `GET /auth/callback`,自动保存凭据并重定向回 picoclaw-config 页面。
+
+---
+
+#### POST /api/auth/logout
+
+登出 Provider。
+
+**Request Body** — `application/json`
+
+```json
+{ "provider": "openai" }
+```
+
+传空字符串或省略 `provider` 则登出所有 Provider。
+
+**Response** `200 OK`
+
+```json
+{ "status": "ok" }
+```
+
+---
+
+#### GET /auth/callback
+
+OAuth Browser 回调端点(Google Antigravity 专用),由 OAuth Provider 重定向调用,**非前端直接使用**。
+
+**Query Parameters:**
+- `state` — OAuth state 校验
+- `code` — 授权码
+
+认证成功后重定向到 `/#auth`。
+
+### Process API
+
+#### GET /api/process/status
+
+获取 `picoclaw gateway` 进程的运行状态。
+
+**Response** `200 OK` (运行中)
+
+```json
+{
+ "process_status": "running",
+ "status": "ok",
+ "uptime": "1.010814s"
+}
+```
+
+**Response** `200 OK` (未运行)
+
+```json
+{
+ "process_status": "stopped",
+ "error": "Get \"http://localhost:18790/health\": dial tcp [::1]:18790: connect: connection refused"
+}
+```
+
+---
+
+#### POST /api/process/start
+
+在后台启动 `picoclaw gateway` 进程。
+
+**Response** `200 OK`
+
+```json
+{
+ "status": "ok",
+ "pid": 12345
+}
+```
+
+---
+
+#### POST /api/process/stop
+
+停止正在运行的 `picoclaw gateway` 进程。
+
+**Response** `200 OK`
+
+```json
+{
+ "status": "ok"
+}
+```
+
+---
+
+## 测试
+
+```bash
+go test -v ./cmd/picoclaw-launcher/
+```
diff --git a/cmd/picoclaw-launcher/icon.ico b/cmd/picoclaw-launcher/icon.ico
new file mode 100644
index 000000000..4f6539414
Binary files /dev/null and b/cmd/picoclaw-launcher/icon.ico differ
diff --git a/cmd/picoclaw-launcher/internal/server/auth_config.go b/cmd/picoclaw-launcher/internal/server/auth_config.go
new file mode 100644
index 000000000..f75e8fff0
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/auth_config.go
@@ -0,0 +1,147 @@
+package server
+
+import (
+ "log"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/auth"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// updateConfigAfterLogin updates config.json after a successful provider login.
+func updateConfigAfterLogin(configPath, provider string, cred *auth.AuthCredential) {
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ log.Printf("Warning: could not load config to update auth_method: %v", err)
+ return
+ }
+
+ switch provider {
+ case "openai":
+ cfg.Providers.OpenAI.AuthMethod = "oauth"
+ found := false
+ for i := range cfg.ModelList {
+ if isOpenAIModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = "oauth"
+ found = true
+ break
+ }
+ }
+ if !found {
+ cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
+ ModelName: "gpt-5.2",
+ Model: "openai/gpt-5.2",
+ AuthMethod: "oauth",
+ })
+ }
+ cfg.Agents.Defaults.ModelName = "gpt-5.2"
+
+ case "anthropic":
+ cfg.Providers.Anthropic.AuthMethod = "token"
+ found := false
+ for i := range cfg.ModelList {
+ if isAnthropicModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = "token"
+ found = true
+ break
+ }
+ }
+ if !found {
+ cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
+ ModelName: "claude-sonnet-4.6",
+ Model: "anthropic/claude-sonnet-4.6",
+ AuthMethod: "token",
+ })
+ }
+ cfg.Agents.Defaults.ModelName = "claude-sonnet-4.6"
+
+ case "google-antigravity":
+ cfg.Providers.Antigravity.AuthMethod = "oauth"
+ found := false
+ for i := range cfg.ModelList {
+ if isAntigravityModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = "oauth"
+ found = true
+ break
+ }
+ }
+ if !found {
+ cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
+ ModelName: "gemini-flash",
+ Model: "antigravity/gemini-3-flash",
+ AuthMethod: "oauth",
+ })
+ }
+ cfg.Agents.Defaults.ModelName = "gemini-flash"
+ }
+
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ log.Printf("Warning: could not update config: %v", err)
+ }
+}
+
+// clearAuthMethodInConfig clears auth_method for a specific provider in config.json.
+func clearAuthMethodInConfig(configPath, provider string) {
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ return
+ }
+
+ for i := range cfg.ModelList {
+ switch provider {
+ case "openai":
+ if isOpenAIModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = ""
+ }
+ case "anthropic":
+ if isAnthropicModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = ""
+ }
+ case "google-antigravity", "antigravity":
+ if isAntigravityModel(cfg.ModelList[i].Model) {
+ cfg.ModelList[i].AuthMethod = ""
+ }
+ }
+ }
+
+ switch provider {
+ case "openai":
+ cfg.Providers.OpenAI.AuthMethod = ""
+ case "anthropic":
+ cfg.Providers.Anthropic.AuthMethod = ""
+ case "google-antigravity", "antigravity":
+ cfg.Providers.Antigravity.AuthMethod = ""
+ }
+
+ config.SaveConfig(configPath, cfg)
+}
+
+// clearAllAuthMethodsInConfig clears auth_method for all providers in config.json.
+func clearAllAuthMethodsInConfig(configPath string) {
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ return
+ }
+ for i := range cfg.ModelList {
+ cfg.ModelList[i].AuthMethod = ""
+ }
+ cfg.Providers.OpenAI.AuthMethod = ""
+ cfg.Providers.Anthropic.AuthMethod = ""
+ cfg.Providers.Antigravity.AuthMethod = ""
+ config.SaveConfig(configPath, cfg)
+}
+
+// ── Model identification helpers ─────────────────────────────────
+
+func isOpenAIModel(model string) bool {
+ return model == "openai" || strings.HasPrefix(model, "openai/")
+}
+
+func isAnthropicModel(model string) bool {
+ return model == "anthropic" || strings.HasPrefix(model, "anthropic/")
+}
+
+func isAntigravityModel(model string) bool {
+ return model == "antigravity" || model == "google-antigravity" ||
+ strings.HasPrefix(model, "antigravity/") || strings.HasPrefix(model, "google-antigravity/")
+}
diff --git a/cmd/picoclaw-launcher/internal/server/auth_config_test.go b/cmd/picoclaw-launcher/internal/server/auth_config_test.go
new file mode 100644
index 000000000..92158d011
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/auth_config_test.go
@@ -0,0 +1,222 @@
+package server
+
+import (
+ "path/filepath"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/auth"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// ── Model identification helpers ─────────────────────────────────
+
+func TestIsOpenAIModel(t *testing.T) {
+ tests := []struct {
+ model string
+ want bool
+ }{
+ {"openai", true},
+ {"openai/gpt-4o", true},
+ {"openai/gpt-5.2", true},
+ {"anthropic", false},
+ {"anthropic/claude-sonnet-4.6", false},
+ {"openai-compatible", false},
+ {"", false},
+ }
+ for _, tt := range tests {
+ if got := isOpenAIModel(tt.model); got != tt.want {
+ t.Errorf("isOpenAIModel(%q) = %v, want %v", tt.model, got, tt.want)
+ }
+ }
+}
+
+func TestIsAnthropicModel(t *testing.T) {
+ tests := []struct {
+ model string
+ want bool
+ }{
+ {"anthropic", true},
+ {"anthropic/claude-sonnet-4.6", true},
+ {"openai", false},
+ {"openai/gpt-4o", false},
+ {"", false},
+ }
+ for _, tt := range tests {
+ if got := isAnthropicModel(tt.model); got != tt.want {
+ t.Errorf("isAnthropicModel(%q) = %v, want %v", tt.model, got, tt.want)
+ }
+ }
+}
+
+func TestIsAntigravityModel(t *testing.T) {
+ tests := []struct {
+ model string
+ want bool
+ }{
+ {"antigravity", true},
+ {"google-antigravity", true},
+ {"antigravity/gemini-3-flash", true},
+ {"google-antigravity/gemini-3-flash", true},
+ {"openai", false},
+ {"antigravity-custom", false},
+ {"", false},
+ }
+ for _, tt := range tests {
+ if got := isAntigravityModel(tt.model); got != tt.want {
+ t.Errorf("isAntigravityModel(%q) = %v, want %v", tt.model, got, tt.want)
+ }
+ }
+}
+
+// ── Config update helpers ────────────────────────────────────────
+
+func writeTempConfigViaSave(t *testing.T, cfg *config.Config) string {
+ t.Helper()
+ dir := t.TempDir()
+ path := filepath.Join(dir, "config.json")
+ if err := config.SaveConfig(path, cfg); err != nil {
+ t.Fatalf("save config: %v", err)
+ }
+ return path
+}
+
+func loadTempConfig(t *testing.T, path string) *config.Config {
+ t.Helper()
+ cfg, err := config.LoadConfig(path)
+ if err != nil {
+ t.Fatalf("load config: %v", err)
+ }
+ return cfg
+}
+
+func TestUpdateConfigAfterLogin_OpenAI_ExistingModel(t *testing.T) {
+ cfg := &config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "gpt-4o", Model: "openai/gpt-4o"},
+ },
+ }
+ path := writeTempConfigViaSave(t, cfg)
+
+ cred := &auth.AuthCredential{AuthMethod: "oauth"}
+ updateConfigAfterLogin(path, "openai", cred)
+
+ result := loadTempConfig(t, path)
+
+ // Model-level auth_method persists through serialization
+ if len(result.ModelList) != 1 {
+ t.Fatalf("expected 1 model, got %d", len(result.ModelList))
+ }
+ if result.ModelList[0].AuthMethod != "oauth" {
+ t.Errorf("expected model auth_method=oauth, got %q", result.ModelList[0].AuthMethod)
+ }
+}
+
+func TestUpdateConfigAfterLogin_OpenAI_NoExistingModel(t *testing.T) {
+ cfg := &config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "claude", Model: "anthropic/claude-sonnet-4.6"},
+ },
+ }
+ path := writeTempConfigViaSave(t, cfg)
+
+ cred := &auth.AuthCredential{AuthMethod: "oauth"}
+ updateConfigAfterLogin(path, "openai", cred)
+
+ result := loadTempConfig(t, path)
+
+ if len(result.ModelList) != 2 {
+ t.Fatalf("expected 2 models (original + added), got %d", len(result.ModelList))
+ }
+ if result.ModelList[1].Model != "openai/gpt-5.2" {
+ t.Errorf("expected added model openai/gpt-5.2, got %q", result.ModelList[1].Model)
+ }
+ if result.Agents.Defaults.ModelName != "gpt-5.2" {
+ t.Errorf("expected default model_name=gpt-5.2, got %q", result.Agents.Defaults.ModelName)
+ }
+}
+
+func TestUpdateConfigAfterLogin_Anthropic(t *testing.T) {
+ cfg := &config.Config{}
+ path := writeTempConfigViaSave(t, cfg)
+
+ cred := &auth.AuthCredential{AuthMethod: "token"}
+ updateConfigAfterLogin(path, "anthropic", cred)
+
+ result := loadTempConfig(t, path)
+
+ // Model should be added with correct auth_method
+ if len(result.ModelList) != 1 {
+ t.Fatalf("expected 1 model added, got %d", len(result.ModelList))
+ }
+ if result.ModelList[0].Model != "anthropic/claude-sonnet-4.6" {
+ t.Errorf("expected model anthropic/claude-sonnet-4.6, got %q", result.ModelList[0].Model)
+ }
+ if result.ModelList[0].AuthMethod != "token" {
+ t.Errorf("expected model auth_method=token, got %q", result.ModelList[0].AuthMethod)
+ }
+}
+
+func TestUpdateConfigAfterLogin_GoogleAntigravity(t *testing.T) {
+ cfg := &config.Config{}
+ path := writeTempConfigViaSave(t, cfg)
+
+ cred := &auth.AuthCredential{AuthMethod: "oauth"}
+ updateConfigAfterLogin(path, "google-antigravity", cred)
+
+ result := loadTempConfig(t, path)
+
+ // Model should be added with correct auth_method
+ if len(result.ModelList) != 1 {
+ t.Fatalf("expected 1 model added, got %d", len(result.ModelList))
+ }
+ if result.ModelList[0].Model != "antigravity/gemini-3-flash" {
+ t.Errorf("expected model antigravity/gemini-3-flash, got %q", result.ModelList[0].Model)
+ }
+ if result.ModelList[0].AuthMethod != "oauth" {
+ t.Errorf("expected model auth_method=oauth, got %q", result.ModelList[0].AuthMethod)
+ }
+}
+
+func TestClearAuthMethodInConfig(t *testing.T) {
+ cfg := &config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "gpt-4o", Model: "openai/gpt-4o", AuthMethod: "oauth"},
+ {ModelName: "claude", Model: "anthropic/claude-sonnet-4.6", AuthMethod: "token"},
+ },
+ }
+ path := writeTempConfigViaSave(t, cfg)
+
+ clearAuthMethodInConfig(path, "openai")
+
+ result := loadTempConfig(t, path)
+
+ // Openai model auth_method should be cleared
+ if result.ModelList[0].AuthMethod != "" {
+ t.Errorf("expected openai model auth_method cleared, got %q", result.ModelList[0].AuthMethod)
+ }
+ // Anthropic model should be unchanged
+ if result.ModelList[1].AuthMethod != "token" {
+ t.Errorf("expected anthropic model auth_method unchanged, got %q", result.ModelList[1].AuthMethod)
+ }
+}
+
+func TestClearAllAuthMethodsInConfig(t *testing.T) {
+ cfg := &config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "gpt-4o", Model: "openai/gpt-4o", AuthMethod: "oauth"},
+ {ModelName: "claude", Model: "anthropic/claude-sonnet-4.6", AuthMethod: "token"},
+ {ModelName: "gemini", Model: "antigravity/gemini-3-flash", AuthMethod: "oauth"},
+ },
+ }
+ path := writeTempConfigViaSave(t, cfg)
+
+ clearAllAuthMethodsInConfig(path)
+
+ result := loadTempConfig(t, path)
+
+ for i, m := range result.ModelList {
+ if m.AuthMethod != "" {
+ t.Errorf("model[%d] auth_method not cleared, got %q", i, m.AuthMethod)
+ }
+ }
+}
diff --git a/cmd/picoclaw-launcher/internal/server/auth_handlers.go b/cmd/picoclaw-launcher/internal/server/auth_handlers.go
new file mode 100644
index 000000000..1e9b8be0a
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/auth_handlers.go
@@ -0,0 +1,312 @@
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/auth"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// oauthSession stores in-flight OAuth state for browser-based flows.
+type oauthSession struct {
+ Provider string
+ PKCE auth.PKCECodes
+ State string
+ RedirectURI string
+ OAuthCfg auth.OAuthProviderConfig
+ ConfigPath string
+}
+
+// deviceCodeSession stores in-flight device code flow state.
+type deviceCodeSession struct {
+ mu sync.Mutex
+ Provider string
+ Info *auth.DeviceCodeInfo
+ OAuthCfg auth.OAuthProviderConfig
+ ConfigPath string
+ Status string // "pending", "success", "error"
+ Error string
+ Done bool
+}
+
+var (
+ oauthSessions = map[string]*oauthSession{} // keyed by state
+ oauthSessionsMu sync.Mutex
+
+ activeDeviceSession *deviceCodeSession
+ activeDeviceSessionMu sync.Mutex
+)
+
+// handleOpenAILogin starts the OpenAI device code flow and returns device code info to the frontend.
+func handleOpenAILogin(w http.ResponseWriter, configPath string) {
+ // Check if there's already a pending device code session
+ activeDeviceSessionMu.Lock()
+ if activeDeviceSession != nil {
+ activeDeviceSession.mu.Lock()
+ if !activeDeviceSession.Done {
+ resp := map[string]any{
+ "status": "pending",
+ "device_url": activeDeviceSession.Info.VerifyURL,
+ "user_code": activeDeviceSession.Info.UserCode,
+ "message": "Device code flow already in progress. Enter the code in your browser.",
+ }
+ activeDeviceSession.mu.Unlock()
+ activeDeviceSessionMu.Unlock()
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ return
+ }
+ activeDeviceSession.mu.Unlock()
+ }
+ activeDeviceSessionMu.Unlock()
+
+ // Request a device code
+ oauthCfg := auth.OpenAIOAuthConfig()
+ info, err := auth.RequestDeviceCode(oauthCfg)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to request device code: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ session := &deviceCodeSession{
+ Provider: "openai",
+ Info: info,
+ OAuthCfg: oauthCfg,
+ ConfigPath: configPath,
+ Status: "pending",
+ }
+
+ activeDeviceSessionMu.Lock()
+ activeDeviceSession = session
+ activeDeviceSessionMu.Unlock()
+
+ // Start background polling
+ go func() {
+ deadline := time.After(15 * time.Minute)
+ ticker := time.NewTicker(time.Duration(info.Interval) * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-deadline:
+ session.mu.Lock()
+ session.Status = "error"
+ session.Error = "Authentication timed out after 15 minutes"
+ session.Done = true
+ session.mu.Unlock()
+ return
+ case <-ticker.C:
+ cred, err := auth.PollDeviceCodeOnce(oauthCfg, info.DeviceAuthID, info.UserCode)
+ if err != nil {
+ continue // Still pending
+ }
+ if cred != nil {
+ if saveErr := auth.SetCredential("openai", cred); saveErr != nil {
+ session.mu.Lock()
+ session.Status = "error"
+ session.Error = saveErr.Error()
+ session.Done = true
+ session.mu.Unlock()
+ return
+ }
+ updateConfigAfterLogin(configPath, "openai", cred)
+ session.mu.Lock()
+ session.Status = "success"
+ session.Done = true
+ session.mu.Unlock()
+ log.Printf("OpenAI device code login successful (account: %s)", cred.AccountID)
+ return
+ }
+ }
+ }
+ }()
+
+ // Return device code info to frontend
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]any{
+ "status": "pending",
+ "device_url": info.VerifyURL,
+ "user_code": info.UserCode,
+ "message": "Open the URL and enter the code to authenticate.",
+ })
+}
+
+// handleAnthropicLogin saves a pasted API token for Anthropic.
+func handleAnthropicLogin(w http.ResponseWriter, token, configPath string) {
+ if token == "" {
+ http.Error(w, "Token is required for Anthropic login", http.StatusBadRequest)
+ return
+ }
+
+ cred := &auth.AuthCredential{
+ AccessToken: token,
+ Provider: "anthropic",
+ AuthMethod: "token",
+ }
+
+ if err := auth.SetCredential("anthropic", cred); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to save credentials: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ updateConfigAfterLogin(configPath, "anthropic", cred)
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "success",
+ "message": "Anthropic token saved",
+ })
+}
+
+// handleGoogleAntigravityLogin generates a PKCE + auth URL and returns it to the frontend.
+func handleGoogleAntigravityLogin(w http.ResponseWriter, r *http.Request, configPath string) {
+ oauthCfg := auth.GoogleAntigravityOAuthConfig()
+
+ pkce, err := auth.GeneratePKCE()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to generate PKCE: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ state, err := auth.GenerateState()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to generate state: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ // Build redirect URI pointing to picoclaw-launcher's own callback
+ scheme := "http"
+ redirectURI := fmt.Sprintf("%s://%s/auth/callback", scheme, r.Host)
+
+ authURL := auth.BuildAuthorizeURL(oauthCfg, pkce, state, redirectURI)
+
+ // Store session for callback
+ oauthSessionsMu.Lock()
+ oauthSessions[state] = &oauthSession{
+ Provider: "google-antigravity",
+ PKCE: pkce,
+ State: state,
+ RedirectURI: redirectURI,
+ OAuthCfg: oauthCfg,
+ ConfigPath: configPath,
+ }
+ oauthSessionsMu.Unlock()
+
+ // Clean up stale sessions after 10 minutes
+ go func() {
+ time.Sleep(10 * time.Minute)
+ oauthSessionsMu.Lock()
+ delete(oauthSessions, state)
+ oauthSessionsMu.Unlock()
+ }()
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "redirect",
+ "auth_url": authURL,
+ "message": "Open the URL to authenticate with Google.",
+ })
+}
+
+// handleOAuthCallback processes the OAuth callback from Google Antigravity.
+func handleOAuthCallback(w http.ResponseWriter, r *http.Request) {
+ state := r.URL.Query().Get("state")
+ code := r.URL.Query().Get("code")
+
+ oauthSessionsMu.Lock()
+ session, ok := oauthSessions[state]
+ if ok {
+ delete(oauthSessions, state)
+ }
+ oauthSessionsMu.Unlock()
+
+ if !ok {
+ http.Error(w, "Invalid or expired OAuth state", http.StatusBadRequest)
+ return
+ }
+
+ if code == "" {
+ errMsg := r.URL.Query().Get("error")
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprintf(
+ w,
+ `Authentication failed
%s
You can close this window.
`,
+ errMsg,
+ )
+ return
+ }
+
+ cred, err := auth.ExchangeCodeForTokens(session.OAuthCfg, code, session.PKCE.CodeVerifier, session.RedirectURI)
+ if err != nil {
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprintf(
+ w,
+ `Authentication failed
%s
You can close this window.
`,
+ err.Error(),
+ )
+ return
+ }
+
+ cred.Provider = session.Provider
+
+ // Fetch user info for Google Antigravity
+ if session.Provider == "google-antigravity" {
+ if email, err := fetchGoogleUserEmail(cred.AccessToken); err == nil {
+ cred.Email = email
+ }
+ if projectID, err := providers.FetchAntigravityProjectID(cred.AccessToken); err == nil {
+ cred.ProjectID = projectID
+ }
+ }
+
+ if err := auth.SetCredential(session.Provider, cred); err != nil {
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprintf(w, `Failed to save credentials
%s
`, err.Error())
+ return
+ }
+
+ updateConfigAfterLogin(session.ConfigPath, session.Provider, cred)
+
+ // Redirect back to picoclaw-launcher UI
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprintf(w, `
+ Authentication successful!
+ Redirecting back to Config Editor...
+
+ `)
+}
+
+// fetchGoogleUserEmail retrieves the user's email from Google's userinfo endpoint.
+func fetchGoogleUserEmail(accessToken string) (string, error) {
+ req, err := http.NewRequest("GET", "https://www.googleapis.com/oauth2/v2/userinfo", nil)
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ client := &http.Client{Timeout: 10 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ body, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("userinfo request failed: %s", string(body))
+ }
+
+ var userInfo struct {
+ Email string `json:"email"`
+ }
+ if err := json.Unmarshal(body, &userInfo); err != nil {
+ return "", err
+ }
+ return userInfo.Email, nil
+}
diff --git a/cmd/picoclaw-launcher/internal/server/logbuffer.go b/cmd/picoclaw-launcher/internal/server/logbuffer.go
new file mode 100644
index 000000000..4d70f6466
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/logbuffer.go
@@ -0,0 +1,99 @@
+package server
+
+import "sync"
+
+// LogBuffer is a thread-safe ring buffer that stores the most recent N log lines.
+// It supports incremental reads via LinesSince and tracks a runID that increments
+// on each Reset (used to detect gateway restarts).
+type LogBuffer struct {
+ mu sync.RWMutex
+ lines []string
+ cap int
+ total int // total lines ever appended in current run
+ runID int
+}
+
+// NewLogBuffer creates a LogBuffer with the given capacity.
+func NewLogBuffer(capacity int) *LogBuffer {
+ return &LogBuffer{
+ lines: make([]string, 0, capacity),
+ cap: capacity,
+ }
+}
+
+// Append adds a line to the buffer. If the buffer is full, the oldest line is evicted.
+func (b *LogBuffer) Append(line string) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if len(b.lines) < b.cap {
+ b.lines = append(b.lines, line)
+ } else {
+ b.lines[b.total%b.cap] = line
+ }
+
+ b.total++
+}
+
+// Reset clears the buffer and increments the runID. Call this when starting a new gateway process.
+func (b *LogBuffer) Reset() {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ b.lines = b.lines[:0]
+ b.total = 0
+ b.runID++
+}
+
+// LinesSince returns lines appended after the given offset, the current total count, and the runID.
+// If offset >= total, no lines are returned. If offset is too old (evicted), all buffered lines are returned.
+func (b *LogBuffer) LinesSince(offset int) (lines []string, total int, runID int) {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ total = b.total
+ runID = b.runID
+
+ if offset >= b.total {
+ return nil, total, runID
+ }
+
+ buffered := len(b.lines)
+
+ // How many new lines since offset
+ newCount := b.total - offset
+ if newCount > buffered {
+ newCount = buffered
+ }
+
+ result := make([]string, newCount)
+
+ if b.total <= b.cap {
+ // Buffer hasn't wrapped yet — simple slice
+ copy(result, b.lines[buffered-newCount:])
+ } else {
+ // Buffer has wrapped — read from ring
+ start := (b.total - newCount) % b.cap
+ for i := range newCount {
+ result[i] = b.lines[(start+i)%b.cap]
+ }
+ }
+
+ return result, total, runID
+}
+
+// RunID returns the current run identifier.
+func (b *LogBuffer) RunID() int {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ return b.runID
+}
+
+// Total returns the total number of lines appended in the current run.
+func (b *LogBuffer) Total() int {
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ return b.total
+}
diff --git a/cmd/picoclaw-launcher/internal/server/logbuffer_test.go b/cmd/picoclaw-launcher/internal/server/logbuffer_test.go
new file mode 100644
index 000000000..dc525be16
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/logbuffer_test.go
@@ -0,0 +1,116 @@
+package server
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestLogBuffer_Basic(t *testing.T) {
+ buf := NewLogBuffer(5)
+
+ // Empty buffer
+ lines, total, runID := buf.LinesSince(0)
+ assert.Nil(t, lines)
+ assert.Equal(t, 0, total)
+ assert.Equal(t, 0, runID)
+
+ // Append some lines
+ buf.Append("line1")
+ buf.Append("line2")
+ buf.Append("line3")
+
+ lines, total, runID = buf.LinesSince(0)
+ assert.Equal(t, []string{"line1", "line2", "line3"}, lines)
+ assert.Equal(t, 3, total)
+ assert.Equal(t, 0, runID)
+
+ // Incremental read
+ lines, total, _ = buf.LinesSince(2)
+ assert.Equal(t, []string{"line3"}, lines)
+ assert.Equal(t, 3, total)
+
+ // No new lines
+ lines, total, _ = buf.LinesSince(3)
+ assert.Nil(t, lines)
+ assert.Equal(t, 3, total)
+}
+
+func TestLogBuffer_Wrap(t *testing.T) {
+ buf := NewLogBuffer(3)
+
+ buf.Append("a")
+ buf.Append("b")
+ buf.Append("c")
+ buf.Append("d") // evicts "a"
+ buf.Append("e") // evicts "b"
+
+ lines, total, _ := buf.LinesSince(0)
+ assert.Equal(t, []string{"c", "d", "e"}, lines)
+ assert.Equal(t, 5, total)
+
+ // Incremental after wrap
+ lines, total, _ = buf.LinesSince(3)
+ assert.Equal(t, []string{"d", "e"}, lines)
+ assert.Equal(t, 5, total)
+
+ // Offset too old (before buffer start), get all buffered
+ lines, total, _ = buf.LinesSince(1)
+ assert.Equal(t, []string{"c", "d", "e"}, lines)
+ assert.Equal(t, 5, total)
+}
+
+func TestLogBuffer_Reset(t *testing.T) {
+ buf := NewLogBuffer(5)
+
+ buf.Append("before")
+ assert.Equal(t, 0, buf.RunID())
+
+ buf.Reset()
+ assert.Equal(t, 1, buf.RunID())
+ assert.Equal(t, 0, buf.Total())
+
+ lines, total, runID := buf.LinesSince(0)
+ assert.Nil(t, lines)
+ assert.Equal(t, 0, total)
+ assert.Equal(t, 1, runID)
+
+ buf.Append("after")
+ lines, total, runID = buf.LinesSince(0)
+ assert.Equal(t, []string{"after"}, lines)
+ assert.Equal(t, 1, total)
+ assert.Equal(t, 1, runID)
+}
+
+func TestLogBuffer_Concurrent(t *testing.T) {
+ buf := NewLogBuffer(100)
+ var wg sync.WaitGroup
+
+ // 10 writers
+ for i := range 10 {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ for j := range 50 {
+ buf.Append(fmt.Sprintf("writer-%d-line-%d", id, j))
+ }
+ }(i)
+ }
+
+ // 5 readers
+ for range 5 {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for range 100 {
+ buf.LinesSince(0)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ assert.Equal(t, 500, buf.Total())
+}
diff --git a/cmd/picoclaw-launcher/internal/server/process.go b/cmd/picoclaw-launcher/internal/server/process.go
new file mode 100644
index 000000000..bc2129bf5
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/process.go
@@ -0,0 +1,232 @@
+package server
+
+import (
+ "bufio"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// gatewayLogs stores captured stdout/stderr from the gateway process launched by the launcher.
+var gatewayLogs = NewLogBuffer(200)
+
+// RegisterProcessAPI registers endpoints to start, stop and check status of the picoclaw gateway.
+func RegisterProcessAPI(mux *http.ServeMux, absPath string) {
+ mux.HandleFunc("GET /api/process/status", func(w http.ResponseWriter, r *http.Request) {
+ handleStatusGateway(w, r, absPath)
+ })
+ mux.HandleFunc("POST /api/process/start", handleStartGateway)
+ mux.HandleFunc("POST /api/process/stop", handleStopGateway)
+}
+
+func handleStartGateway(w http.ResponseWriter, r *http.Request) {
+ // Locate picoclaw executable:
+ // 1. Try same directory as current executable
+ // 2. Fallback to just "picoclaw" (relies on $PATH)
+ execPath := "picoclaw"
+
+ if exe, err := os.Executable(); err == nil {
+ dir := filepath.Dir(exe)
+ candidate := filepath.Join(dir, "picoclaw")
+ if runtime.GOOS == "windows" {
+ candidate += ".exe"
+ }
+
+ if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
+ execPath = candidate
+ }
+ }
+
+ cmd := exec.Command(execPath, "gateway")
+
+ stdoutPipe, err := cmd.StdoutPipe()
+ if err != nil {
+ log.Printf("Failed to create stdout pipe: %v\n", err)
+ http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ stderrPipe, err := cmd.StderrPipe()
+ if err != nil {
+ log.Printf("Failed to create stderr pipe: %v\n", err)
+ http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ // Clear old logs and increment runID before starting
+ gatewayLogs.Reset()
+
+ if err := cmd.Start(); err != nil {
+ log.Printf("Failed to start picoclaw gateway: %v\n", err)
+ http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ // Read stdout and stderr into the log buffer
+ go scanPipe(stdoutPipe, gatewayLogs)
+ go scanPipe(stderrPipe, gatewayLogs)
+
+ // Wait for the process to exit in the background to avoid zombies
+ go func() {
+ if err := cmd.Wait(); err != nil {
+ log.Printf("Gateway process exited: %v\n", err)
+ }
+ }()
+
+ log.Printf("Started picoclaw gateway (PID: %d) from %s\n", cmd.Process.Pid, execPath)
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]any{
+ "status": "ok",
+ "pid": cmd.Process.Pid,
+ })
+}
+
+// scanPipe reads lines from r and appends them to buf. It returns when r reaches EOF.
+func scanPipe(r io.Reader, buf *LogBuffer) {
+ scanner := bufio.NewScanner(r)
+ scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) // up to 1MB per line
+
+ for scanner.Scan() {
+ buf.Append(scanner.Text())
+ }
+}
+
+func handleStopGateway(w http.ResponseWriter, r *http.Request) {
+ var err error
+ if runtime.GOOS == "windows" {
+ // Kill via taskkill finding picoclaw.exe (though it might kill this config tool if it's named picoclaw-launcher.exe...? No, /IM does exact match usually, but just to be safe let's stop exactly picoclaw.exe)
+ // Alternatively, we use powershell to kill processes with commandline containing 'gateway'
+ psCmd := `Get-WmiObject Win32_Process | Where-Object { $_.CommandLine -match 'picoclaw.*gateway' } | ForEach-Object { Stop-Process $_.ProcessId -Force }`
+ err = exec.Command("powershell", "-Command", psCmd).Run()
+ } else {
+ // Linux/macOS
+ err = exec.Command("pkill", "-f", "picoclaw gateway").Run()
+ }
+
+ if err != nil {
+ log.Printf("Warning: Failed to stop gateway (perhaps not running?): %v\n", err)
+ // We still return 200 OK because pkill returns an error if no process was found
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]any{
+ "status": "ok", // or "not_found"
+ "msg": "Stop command executed, but returned error (process might not be running).",
+ "error": err.Error(),
+ })
+ return
+ }
+
+ log.Printf("Stopped picoclaw gateway processes.\n")
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "ok",
+ })
+}
+
+func handleStatusGateway(w http.ResponseWriter, r *http.Request, absPath string) {
+ cfg, cfgErr := config.LoadConfig(absPath)
+ host := "127.0.0.1"
+ port := 18790
+ if cfgErr == nil && cfg != nil {
+ if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" {
+ host = cfg.Gateway.Host
+ }
+ if cfg.Gateway.Port != 0 {
+ port = cfg.Gateway.Port
+ }
+ }
+
+ url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
+ client := http.Client{Timeout: 2 * time.Second}
+ resp, err := client.Get(url)
+
+ // Build the response data map
+ data := map[string]any{}
+
+ if err != nil {
+ data["process_status"] = "stopped"
+ data["error"] = err.Error()
+ } else {
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ data["process_status"] = "error"
+ data["status_code"] = resp.StatusCode
+ } else {
+ var healthData map[string]any
+ if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
+ data["process_status"] = "error"
+ data["error"] = "invalid response from gateway"
+ } else {
+ // Gateway is running and responded properly — merge health data
+ for k, v := range healthData {
+ data[k] = v
+ }
+ data["process_status"] = "running"
+ }
+ }
+ }
+
+ // Append log data from the buffer
+ appendLogData(r, data)
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(data)
+}
+
+// appendLogData reads log_offset and log_run_id query params from the request and
+// populates the response data map with incremental log lines.
+func appendLogData(r *http.Request, data map[string]any) {
+ clientOffset := 0
+ clientRunID := -1
+
+ if v := r.URL.Query().Get("log_offset"); v != "" {
+ if n, err := strconv.Atoi(v); err == nil {
+ clientOffset = n
+ }
+ }
+
+ if v := r.URL.Query().Get("log_run_id"); v != "" {
+ if n, err := strconv.Atoi(v); err == nil {
+ clientRunID = n
+ }
+ }
+
+ runID := gatewayLogs.RunID()
+
+ // If runID is 0 (never reset = never launched from this launcher), report no source
+ if runID == 0 {
+ data["logs"] = []string{}
+ data["log_total"] = 0
+ data["log_run_id"] = 0
+ data["log_source"] = "none"
+ return
+ }
+
+ // If the client's runID doesn't match, send all buffered lines (gateway restarted)
+ offset := clientOffset
+ if clientRunID != runID {
+ offset = 0
+ }
+
+ lines, total, runID := gatewayLogs.LinesSince(offset)
+ if lines == nil {
+ lines = []string{}
+ }
+
+ data["logs"] = lines
+ data["log_total"] = total
+ data["log_run_id"] = runID
+ data["log_source"] = "launcher"
+}
diff --git a/cmd/picoclaw-launcher/internal/server/server.go b/cmd/picoclaw-launcher/internal/server/server.go
new file mode 100644
index 000000000..4fc68f04c
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/server.go
@@ -0,0 +1,196 @@
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/auth"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+const DefaultPort = "18800"
+
+// providerStatus represents the auth status of a single provider in API responses.
+type providerStatus struct {
+ Provider string `json:"provider"`
+ AuthMethod string `json:"auth_method"`
+ Status string `json:"status"`
+ AccountID string `json:"account_id,omitempty"`
+ Email string `json:"email,omitempty"`
+ ProjectID string `json:"project_id,omitempty"`
+ ExpiresAt string `json:"expires_at,omitempty"`
+}
+
+// ── Route registration ───────────────────────────────────────────
+
+func RegisterConfigAPI(mux *http.ServeMux, absPath string) {
+ // GET /api/config — read config
+ mux.HandleFunc("GET /api/config", func(w http.ResponseWriter, r *http.Request) {
+ cfg, err := config.LoadConfig(absPath)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ resp := map[string]any{
+ "config": cfg,
+ "path": absPath,
+ }
+ enc := json.NewEncoder(w)
+ enc.SetIndent("", " ")
+ if err := enc.Encode(resp); err != nil {
+ log.Printf("Failed to encode response: %v", err)
+ }
+ })
+
+ // PUT /api/config — save config
+ mux.HandleFunc("PUT /api/config", func(w http.ResponseWriter, r *http.Request) {
+ body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
+ if err != nil {
+ http.Error(w, "Failed to read request body", http.StatusBadRequest)
+ return
+ }
+ defer r.Body.Close()
+
+ var cfg config.Config
+ if err := json.Unmarshal(body, &cfg); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ if err := config.SaveConfig(absPath, &cfg); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+ })
+}
+
+func RegisterAuthAPI(mux *http.ServeMux, absPath string) {
+ // GET /api/auth/status — all authenticated providers + pending login state
+ mux.HandleFunc("GET /api/auth/status", func(w http.ResponseWriter, r *http.Request) {
+ store, err := auth.LoadStore()
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to load auth store: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ result := []providerStatus{}
+ for name, cred := range store.Credentials {
+ status := "active"
+ if cred.IsExpired() {
+ status = "expired"
+ } else if cred.NeedsRefresh() {
+ status = "needs_refresh"
+ }
+ ps := providerStatus{
+ Provider: name,
+ AuthMethod: cred.AuthMethod,
+ Status: status,
+ AccountID: cred.AccountID,
+ Email: cred.Email,
+ ProjectID: cred.ProjectID,
+ }
+ if !cred.ExpiresAt.IsZero() {
+ ps.ExpiresAt = cred.ExpiresAt.Format(time.RFC3339)
+ }
+ result = append(result, ps)
+ }
+
+ // Include pending device code state
+ var pendingDevice map[string]any
+ activeDeviceSessionMu.Lock()
+ if activeDeviceSession != nil {
+ activeDeviceSession.mu.Lock()
+ pendingDevice = map[string]any{
+ "provider": activeDeviceSession.Provider,
+ "status": activeDeviceSession.Status,
+ "device_url": activeDeviceSession.Info.VerifyURL,
+ "user_code": activeDeviceSession.Info.UserCode,
+ }
+ if activeDeviceSession.Error != "" {
+ pendingDevice["error"] = activeDeviceSession.Error
+ }
+ if activeDeviceSession.Done {
+ activeDeviceSession.mu.Unlock()
+ activeDeviceSession = nil
+ } else {
+ activeDeviceSession.mu.Unlock()
+ }
+ }
+ activeDeviceSessionMu.Unlock()
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]any{
+ "providers": result,
+ "pending_device": pendingDevice,
+ })
+ })
+
+ // POST /api/auth/login — initiate provider login
+ mux.HandleFunc("POST /api/auth/login", func(w http.ResponseWriter, r *http.Request) {
+ var req struct {
+ Provider string `json:"provider"`
+ Token string `json:"token,omitempty"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ switch req.Provider {
+ case "openai":
+ handleOpenAILogin(w, absPath)
+ case "anthropic":
+ handleAnthropicLogin(w, req.Token, absPath)
+ case "google-antigravity", "antigravity":
+ handleGoogleAntigravityLogin(w, r, absPath)
+ default:
+ http.Error(
+ w,
+ fmt.Sprintf(
+ "Unsupported provider: %s (supported: openai, anthropic, google-antigravity)",
+ req.Provider,
+ ),
+ http.StatusBadRequest,
+ )
+ }
+ })
+
+ // POST /api/auth/logout — logout a provider
+ mux.HandleFunc("POST /api/auth/logout", func(w http.ResponseWriter, r *http.Request) {
+ var req struct {
+ Provider string `json:"provider"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, "Invalid request body", http.StatusBadRequest)
+ return
+ }
+
+ if req.Provider == "" {
+ if err := auth.DeleteAllCredentials(); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to logout: %v", err), http.StatusInternalServerError)
+ return
+ }
+ clearAllAuthMethodsInConfig(absPath)
+ } else {
+ if err := auth.DeleteCredential(req.Provider); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to logout: %v", err), http.StatusInternalServerError)
+ return
+ }
+ clearAuthMethodInConfig(absPath, req.Provider)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
+ })
+
+ // GET /auth/callback — OAuth browser callback for Google Antigravity
+ mux.HandleFunc("GET /auth/callback", handleOAuthCallback)
+}
diff --git a/cmd/picoclaw-launcher/internal/server/server_test.go b/cmd/picoclaw-launcher/internal/server/server_test.go
new file mode 100644
index 000000000..c87e93d8c
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/server_test.go
@@ -0,0 +1,247 @@
+package server
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// ── Config API tests ─────────────────────────────────────────────
+
+func setupConfigMux(t *testing.T, cfg *config.Config) (*http.ServeMux, string) {
+ t.Helper()
+ dir := t.TempDir()
+ path := filepath.Join(dir, "config.json")
+ data, err := json.MarshalIndent(cfg, "", " ")
+ if err != nil {
+ t.Fatalf("marshal config: %v", err)
+ }
+ if err := os.WriteFile(path, data, 0o600); err != nil {
+ t.Fatalf("write config: %v", err)
+ }
+
+ mux := http.NewServeMux()
+ RegisterConfigAPI(mux, path)
+ RegisterAuthAPI(mux, path)
+ return mux, path
+}
+
+func TestGetConfig(t *testing.T) {
+ cfg := &config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "gpt-4o", Model: "openai/gpt-4o"},
+ },
+ }
+ mux, path := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("GET", "/api/config", nil)
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("GET /api/config: expected 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ var resp struct {
+ Config config.Config `json:"config"`
+ Path string `json:"path"`
+ }
+ if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+
+ if resp.Path != path {
+ t.Errorf("expected path %q, got %q", path, resp.Path)
+ }
+ if len(resp.Config.ModelList) != 1 {
+ t.Errorf("expected 1 model, got %d", len(resp.Config.ModelList))
+ }
+}
+
+func TestGetConfig_MissingFile_ReturnsDefault(t *testing.T) {
+ mux := http.NewServeMux()
+ RegisterConfigAPI(mux, "/tmp/nonexistent-picoclaw-launcher-test/config.json")
+
+ req := httptest.NewRequest("GET", "/api/config", nil)
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ // LoadConfig returns a default empty config when file is missing
+ if w.Code != http.StatusOK {
+ t.Errorf("expected 200 for missing file (default config), got %d", w.Code)
+ }
+}
+
+func TestPutConfig(t *testing.T) {
+ cfg := &config.Config{}
+ mux, path := setupConfigMux(t, cfg)
+
+ newCfg := config.Config{
+ ModelList: []config.ModelConfig{
+ {ModelName: "claude", Model: "anthropic/claude-sonnet-4.6", AuthMethod: "token"},
+ },
+ }
+ body, _ := json.Marshal(newCfg)
+
+ req := httptest.NewRequest("PUT", "/api/config", strings.NewReader(string(body)))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("PUT /api/config: expected 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ saved, err := config.LoadConfig(path)
+ if err != nil {
+ t.Fatalf("load saved config: %v", err)
+ }
+ if len(saved.ModelList) != 1 {
+ t.Fatalf("expected 1 model saved, got %d", len(saved.ModelList))
+ }
+ if saved.ModelList[0].Model != "anthropic/claude-sonnet-4.6" {
+ t.Errorf("expected model anthropic/claude-sonnet-4.6, got %q", saved.ModelList[0].Model)
+ }
+}
+
+func TestPutConfig_InvalidJSON(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("PUT", "/api/config", strings.NewReader("{invalid"))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for invalid JSON, got %d", w.Code)
+ }
+}
+
+// ── Auth API tests ───────────────────────────────────────────────
+
+func TestAuthStatus(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("GET", "/api/auth/status", nil)
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("GET /api/auth/status: expected 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ var resp struct {
+ Providers []providerStatus `json:"providers"`
+ PendingDevice map[string]any `json:"pending_device"`
+ }
+ if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("decode response: %v", err)
+ }
+
+ // providers should be a non-nil list (could be empty)
+ if resp.Providers == nil {
+ t.Error("providers should not be nil")
+ }
+}
+
+func TestAuthLogin_UnsupportedProvider(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ body := `{"provider": "unsupported"}`
+ req := httptest.NewRequest("POST", "/api/auth/login", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for unsupported provider, got %d", w.Code)
+ }
+}
+
+func TestAuthLogin_AnthropicNoToken(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ body := `{"provider": "anthropic"}`
+ req := httptest.NewRequest("POST", "/api/auth/login", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for anthropic without token, got %d", w.Code)
+ }
+}
+
+func TestAuthLogin_InvalidBody(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("POST", "/api/auth/login", strings.NewReader("{bad"))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for invalid JSON body, got %d", w.Code)
+ }
+}
+
+func TestAuthLogout_InvalidBody(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("POST", "/api/auth/logout", strings.NewReader("{bad"))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for invalid body, got %d", w.Code)
+ }
+}
+
+func TestOAuthCallback_InvalidState(t *testing.T) {
+ cfg := &config.Config{}
+ mux, _ := setupConfigMux(t, cfg)
+
+ req := httptest.NewRequest("GET", "/auth/callback?state=invalid&code=test", nil)
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for invalid state, got %d", w.Code)
+ }
+}
+
+// ── Utility tests ────────────────────────────────────────────────
+
+func TestDefaultConfigPath(t *testing.T) {
+ path := DefaultConfigPath()
+ if path == "" {
+ t.Error("defaultConfigPath should not return empty")
+ }
+ if !strings.HasSuffix(path, filepath.Join(".picoclaw", "config.json")) {
+ t.Errorf("expected path ending with .picoclaw/config.json, got %q", path)
+ }
+}
+
+func TestGetLocalIP(t *testing.T) {
+ // Just ensure it doesn't panic; IP may or may not be available
+ ip := GetLocalIP()
+ if ip != "" {
+ // If returned, should look like an IP
+ if !strings.Contains(ip, ".") {
+ t.Errorf("getLocalIP returned non-IPv4 looking string: %q", ip)
+ }
+ }
+}
diff --git a/cmd/picoclaw-launcher/internal/server/utils.go b/cmd/picoclaw-launcher/internal/server/utils.go
new file mode 100644
index 000000000..a46adbece
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/server/utils.go
@@ -0,0 +1,28 @@
+package server
+
+import (
+ "net"
+ "os"
+ "path/filepath"
+)
+
+func DefaultConfigPath() string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "config.json"
+ }
+ return filepath.Join(home, ".picoclaw", "config.json")
+}
+
+func GetLocalIP() string {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ return ""
+ }
+ for _, a := range addrs {
+ if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil {
+ return ipnet.IP.String()
+ }
+ }
+ return ""
+}
diff --git a/cmd/picoclaw-launcher/internal/ui/index.html b/cmd/picoclaw-launcher/internal/ui/index.html
new file mode 100644
index 000000000..93893fd75
--- /dev/null
+++ b/cmd/picoclaw-launcher/internal/ui/index.html
@@ -0,0 +1,1999 @@
+
+
+
+
+
+
+
+ PicoClaw Config
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Models
+
Manage LLM model configurations. Models without an API key are grayed out. Only available models can be set as primary.
+
+
+
+
+
+
+
+
Provider Authentication
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Gateway Logs
+
Real-time output from the gateway process.
+
+
+
No logs available. Start the gateway to see output here.
+
+
+
+
+
+
Raw JSON
+
Directly edit the configuration file.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/cmd/picoclaw-launcher/main.go b/cmd/picoclaw-launcher/main.go
new file mode 100644
index 000000000..3323c31a8
--- /dev/null
+++ b/cmd/picoclaw-launcher/main.go
@@ -0,0 +1,127 @@
+// PicoClaw Launcher - Standalone HTTP service
+//
+// Provides a web-based JSON editor for picoclaw config files,
+// with OAuth provider authentication support.
+//
+// Usage:
+//
+// go build -o picoclaw-launcher ./cmd/picoclaw-launcher/
+// ./picoclaw-launcher [config.json]
+// ./picoclaw-launcher -public config.json
+
+package main
+
+import (
+ "embed"
+ "flag"
+ "fmt"
+ "io/fs"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "time"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw-launcher/internal/server"
+)
+
+//go:embed internal/ui/index.html
+var staticFiles embed.FS
+
+func main() {
+ public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "PicoClaw Launcher - A web-based configuration editor\n\n")
+ fmt.Fprintf(os.Stderr, "Usage: %s [options] [config.json]\n\n", os.Args[0])
+ fmt.Fprintf(os.Stderr, "Arguments:\n")
+ fmt.Fprintf(os.Stderr, " config.json Path to the configuration file (default: ~/.picoclaw/config.json)\n\n")
+ fmt.Fprintf(os.Stderr, "Options:\n")
+ flag.PrintDefaults()
+ fmt.Fprintf(os.Stderr, "\nExamples:\n")
+ fmt.Fprintf(os.Stderr, " %s Use default config path\n", os.Args[0])
+ fmt.Fprintf(os.Stderr, " %s ./config.json Specify a config file\n", os.Args[0])
+ fmt.Fprintf(
+ os.Stderr,
+ " %s -public ./config.json Allow access from other devices on the network\n",
+ os.Args[0],
+ )
+ }
+ flag.Parse()
+
+ configPath := server.DefaultConfigPath()
+ if flag.NArg() > 0 {
+ configPath = flag.Arg(0)
+ }
+
+ absPath, err := filepath.Abs(configPath)
+ if err != nil {
+ log.Fatalf("Failed to resolve config path: %v", err)
+ }
+
+ var addr string
+ if *public {
+ addr = "0.0.0.0:" + server.DefaultPort
+ } else {
+ addr = "127.0.0.1:" + server.DefaultPort
+ }
+
+ mux := http.NewServeMux()
+ server.RegisterConfigAPI(mux, absPath)
+ server.RegisterAuthAPI(mux, absPath)
+ server.RegisterProcessAPI(mux, absPath)
+
+ staticFS, err := fs.Sub(staticFiles, "internal/ui")
+ if err != nil {
+ log.Fatalf("Failed to create sub filesystem: %v", err)
+ }
+ mux.Handle("/", http.FileServer(http.FS(staticFS)))
+
+ // Print startup banner
+ fmt.Println("=============================================")
+ fmt.Println(" PicoClaw Launcher")
+ fmt.Println("=============================================")
+ fmt.Printf(" Config file : %s\n", absPath)
+ fmt.Printf(" Listen addr : %s\n\n", addr)
+ fmt.Println(" Open the following URL in your browser")
+ fmt.Println(" to view and edit the configuration:")
+ fmt.Println()
+ fmt.Printf(" >> http://localhost:%s <<\n", server.DefaultPort)
+ if *public {
+ if ip := server.GetLocalIP(); ip != "" {
+ fmt.Printf(" >> http://%s:%s <<\n", ip, server.DefaultPort)
+ }
+ }
+ fmt.Println()
+ // fmt.Println("=============================================")
+
+ go func() {
+ // Wait briefly to ensure the server is ready before opening the browser
+ time.Sleep(500 * time.Millisecond)
+ url := "http://localhost:" + server.DefaultPort
+ if err := openBrowser(url); err != nil {
+ log.Printf("Warning: Failed to auto-open browser: %v\n", err)
+ }
+ }()
+
+ if err := http.ListenAndServe(addr, mux); err != nil {
+ log.Fatalf("Server failed: %v", err)
+ }
+}
+
+// openBrowser automatically opens the given URL in the default browser.
+func openBrowser(url string) error {
+ var err error
+ switch runtime.GOOS {
+ case "linux":
+ err = exec.Command("xdg-open", url).Start()
+ case "windows":
+ err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ case "darwin":
+ err = exec.Command("open", url).Start()
+ default:
+ err = fmt.Errorf("unsupported platform")
+ }
+ return err
+}
diff --git a/cmd/picoclaw-launcher/winres/winres.json b/cmd/picoclaw-launcher/winres/winres.json
new file mode 100644
index 000000000..01ea7364c
--- /dev/null
+++ b/cmd/picoclaw-launcher/winres/winres.json
@@ -0,0 +1,22 @@
+{
+ "RT_GROUP_ICON": {
+ "APP": {
+ "0000": "../icon.ico"
+ }
+ },
+ "RT_MANIFEST": {
+ "#1": {
+ "0409": {
+ "identity": {
+ "name": "PicoClaw Launcher",
+ "version": "0.0.0.0"
+ },
+ "description": "PicoClaw Launcher - Web-based configuration editor",
+ "minimum-os": "win7",
+ "execution-level": "asInvoker",
+ "dpi-awareness": "system",
+ "use-common-controls-v6": true
+ }
+ }
+ }
+}
diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go
index 746e9755e..f754abc65 100644
--- a/cmd/picoclaw/internal/agent/helpers.go
+++ b/cmd/picoclaw/internal/agent/helpers.go
@@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error {
}
msgBus := bus.NewMessageBus()
+ defer msgBus.Close()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
// Print agent startup info (only for interactive mode)
diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go
index a06625dc9..5225340c7 100644
--- a/cmd/picoclaw/internal/gateway/helpers.go
+++ b/cmd/picoclaw/internal/gateway/helpers.go
@@ -2,25 +2,37 @@ package gateway
import (
"context"
- "errors"
"fmt"
- "net/http"
+ "log"
"os"
"os/signal"
"path/filepath"
- "strings"
"time"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
+ _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
+ _ "github.com/sipeed/picoclaw/pkg/channels/discord"
+ _ "github.com/sipeed/picoclaw/pkg/channels/feishu"
+ _ "github.com/sipeed/picoclaw/pkg/channels/line"
+ _ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
+ _ "github.com/sipeed/picoclaw/pkg/channels/onebot"
+ _ "github.com/sipeed/picoclaw/pkg/channels/pico"
+ _ "github.com/sipeed/picoclaw/pkg/channels/qq"
+ _ "github.com/sipeed/picoclaw/pkg/channels/slack"
+ _ "github.com/sipeed/picoclaw/pkg/channels/telegram"
+ _ "github.com/sipeed/picoclaw/pkg/channels/wecom"
+ _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
+ _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp_native"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -105,48 +117,28 @@ func gatewayCmd(debug bool) error {
return tools.SilentResult(response)
})
- channelManager, err := channels.NewManager(cfg, msgBus)
+ // Create media store for file lifecycle management with TTL cleanup
+ mediaStore := media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
+ Enabled: cfg.Tools.MediaCleanup.Enabled,
+ MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
+ Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
+ })
+ mediaStore.Start()
+
+ channelManager, err := channels.NewManager(cfg, msgBus, mediaStore)
if err != nil {
+ mediaStore.Stop()
return fmt.Errorf("error creating channel manager: %w", err)
}
- // Inject channel manager into agent loop for command handling
+ // Inject channel manager and media store into agent loop
agentLoop.SetChannelManager(channelManager)
+ agentLoop.SetMediaStore(mediaStore)
- var transcriber *voice.GroqTranscriber
- groqAPIKey := cfg.Providers.Groq.APIKey
- if groqAPIKey == "" {
- for _, mc := range cfg.ModelList {
- if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" {
- groqAPIKey = mc.APIKey
- break
- }
- }
- }
- if groqAPIKey != "" {
- transcriber = voice.NewGroqTranscriber(groqAPIKey)
- logger.InfoC("voice", "Groq voice transcription enabled")
- }
-
- if transcriber != nil {
- if telegramChannel, ok := channelManager.GetChannel("telegram"); ok {
- if tc, ok := telegramChannel.(*channels.TelegramChannel); ok {
- tc.SetTranscriber(transcriber)
- logger.InfoC("voice", "Groq transcription attached to Telegram channel")
- }
- }
- if discordChannel, ok := channelManager.GetChannel("discord"); ok {
- if dc, ok := discordChannel.(*channels.DiscordChannel); ok {
- dc.SetTranscriber(transcriber)
- logger.InfoC("voice", "Groq transcription attached to Discord channel")
- }
- }
- if slackChannel, ok := channelManager.GetChannel("slack"); ok {
- if sc, ok := slackChannel.(*channels.SlackChannel); ok {
- sc.SetTranscriber(transcriber)
- logger.InfoC("voice", "Groq transcription attached to Slack channel")
- }
- }
+ // Wire up voice transcription if a supported provider is configured.
+ if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
+ agentLoop.SetTranscriber(transcriber)
+ logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -184,16 +176,16 @@ func gatewayCmd(debug bool) error {
fmt.Println("✓ Device event service started")
}
+ // Setup shared HTTP server with health endpoints and webhook handlers
+ healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
+ addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
+ channelManager.SetupHTTPServer(addr, healthServer)
+
if err := channelManager.StartAll(ctx); err != nil {
fmt.Printf("Error starting channels: %v\n", err)
+ return err
}
- healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
- go func() {
- if err := healthServer.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
- logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()})
- }
- }()
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
go agentLoop.Run(ctx)
@@ -207,12 +199,19 @@ func gatewayCmd(debug bool) error {
cp.Close()
}
cancel()
- healthServer.Stop(context.Background())
+ msgBus.Close()
+
+ // Use a fresh context with timeout for graceful shutdown,
+ // since the original ctx is already canceled.
+ shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer shutdownCancel()
+
+ channelManager.StopAll(shutdownCtx)
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
+ mediaStore.Stop()
agentLoop.Stop()
- channelManager.StopAll(ctx)
fmt.Println("✓ Gateway stopped")
return nil
@@ -232,7 +231,11 @@ func setupCronTool(
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
- cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
+ cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
+ if err != nil {
+ log.Fatalf("Critical error during CronTool initialization: %v", err)
+ }
+
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go
index 1f52df5dd..9655d3c08 100644
--- a/cmd/picoclaw/internal/helpers.go
+++ b/cmd/picoclaw/internal/helpers.go
@@ -19,6 +19,9 @@ var (
)
func GetConfigPath() string {
+ if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" {
+ return configPath
+ }
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "config.json")
}
diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go
index 9342d141d..47e2f8c07 100644
--- a/cmd/picoclaw/internal/helpers_test.go
+++ b/cmd/picoclaw/internal/helpers_test.go
@@ -95,3 +95,13 @@ func TestGetConfigPath_Windows(t *testing.T) {
func TestGetVersion(t *testing.T) {
assert.Equal(t, "dev", GetVersion())
}
+
+func TestGetConfigPath_WithEnv(t *testing.T) {
+ t.Setenv("PICOCLAW_CONFIG", "/tmp/custom/config.json")
+ t.Setenv("HOME", "/tmp/home") // Also set home to ensure env is preferred
+
+ got := GetConfigPath()
+ want := "/tmp/custom/config.json"
+
+ assert.Equal(t, want, got)
+}
diff --git a/cmd/picoclaw/internal/migrate/command.go b/cmd/picoclaw/internal/migrate/command.go
index fb1cee164..76352c9db 100644
--- a/cmd/picoclaw/internal/migrate/command.go
+++ b/cmd/picoclaw/internal/migrate/command.go
@@ -11,19 +11,21 @@ func NewMigrateCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "migrate",
- Short: "Migrate from OpenClaw to PicoClaw",
+ Short: "Migrate from xxxclaw(openclaw, etc.) to picoclaw",
Args: cobra.NoArgs,
Example: ` picoclaw migrate
+ picoclaw migrate --from openclaw
picoclaw migrate --dry-run
picoclaw migrate --refresh
picoclaw migrate --force`,
RunE: func(cmd *cobra.Command, _ []string) error {
- result, err := migrate.Run(opts)
+ m := migrate.NewMigrateInstance(opts)
+ result, err := m.Run(opts)
if err != nil {
return err
}
if !opts.DryRun {
- migrate.PrintSummary(result)
+ m.PrintSummary(result)
}
return nil
},
@@ -31,6 +33,8 @@ func NewMigrateCommand() *cobra.Command {
cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false,
"Show what would be migrated without making changes")
+ cmd.Flags().StringVar(&opts.Source, "from", "openclaw",
+ "Source to migrate from (e.g., openclaw)")
cmd.Flags().BoolVar(&opts.Refresh, "refresh", false,
"Re-sync workspace files from OpenClaw (repeatable)")
cmd.Flags().BoolVar(&opts.ConfigOnly, "config-only", false,
@@ -39,10 +43,10 @@ func NewMigrateCommand() *cobra.Command {
"Only migrate workspace files, skip config")
cmd.Flags().BoolVar(&opts.Force, "force", false,
"Skip confirmation prompts")
- cmd.Flags().StringVar(&opts.OpenClawHome, "openclaw-home", "",
- "Override OpenClaw home directory (default: ~/.openclaw)")
- cmd.Flags().StringVar(&opts.PicoClawHome, "picoclaw-home", "",
- "Override PicoClaw home directory (default: ~/.picoclaw)")
+ cmd.Flags().StringVar(&opts.SourceHome, "source-home", "",
+ "Override source home directory (default: ~/.openclaw)")
+ cmd.Flags().StringVar(&opts.TargetHome, "target-home", "",
+ "Override target home directory (default: ~/.picoclaw)")
return cmd
}
diff --git a/cmd/picoclaw/internal/migrate/command_test.go b/cmd/picoclaw/internal/migrate/command_test.go
index 1948aa327..5110249a2 100644
--- a/cmd/picoclaw/internal/migrate/command_test.go
+++ b/cmd/picoclaw/internal/migrate/command_test.go
@@ -13,7 +13,7 @@ func TestNewMigrateCommand(t *testing.T) {
require.NotNil(t, cmd)
assert.Equal(t, "migrate", cmd.Use)
- assert.Equal(t, "Migrate from OpenClaw to PicoClaw", cmd.Short)
+ assert.Equal(t, "Migrate from xxxclaw(openclaw, etc.) to picoclaw", cmd.Short)
assert.Len(t, cmd.Aliases, 0)
@@ -33,6 +33,6 @@ func TestNewMigrateCommand(t *testing.T) {
assert.NotNil(t, cmd.Flags().Lookup("config-only"))
assert.NotNil(t, cmd.Flags().Lookup("workspace-only"))
assert.NotNil(t, cmd.Flags().Lookup("force"))
- assert.NotNil(t, cmd.Flags().Lookup("openclaw-home"))
- assert.NotNil(t, cmd.Flags().Lookup("picoclaw-home"))
+ assert.NotNil(t, cmd.Flags().Lookup("source-home"))
+ assert.NotNil(t, cmd.Flags().Lookup("target-home"))
}
diff --git a/cmd/picoclaw/internal/onboard/helpers_test.go b/cmd/picoclaw/internal/onboard/helpers_test.go
new file mode 100644
index 000000000..f3e0c92e0
--- /dev/null
+++ b/cmd/picoclaw/internal/onboard/helpers_test.go
@@ -0,0 +1,25 @@
+package onboard
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestCopyEmbeddedToTargetUsesAgentsMarkdown(t *testing.T) {
+ targetDir := t.TempDir()
+
+ if err := copyEmbeddedToTarget(targetDir); err != nil {
+ t.Fatalf("copyEmbeddedToTarget() error = %v", err)
+ }
+
+ agentsPath := filepath.Join(targetDir, "AGENTS.md")
+ if _, err := os.Stat(agentsPath); err != nil {
+ t.Fatalf("expected %s to exist: %v", agentsPath, err)
+ }
+
+ legacyPath := filepath.Join(targetDir, "AGENT.md")
+ if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
+ t.Fatalf("expected legacy file %s to be absent, got err=%v", legacyPath, err)
+ }
+}
diff --git a/cmd/picoclaw/internal/skills/command.go b/cmd/picoclaw/internal/skills/command.go
index 7f8bd011d..65eb127b9 100644
--- a/cmd/picoclaw/internal/skills/command.go
+++ b/cmd/picoclaw/internal/skills/command.go
@@ -71,7 +71,7 @@ func NewSkillsCommand() *cobra.Command {
newInstallBuiltinCommand(workspaceFn),
newListBuiltinCommand(),
newRemoveCommand(installerFn),
- newSearchCommand(installerFn),
+ newSearchCommand(),
newShowCommand(loaderFn),
)
diff --git a/cmd/picoclaw/internal/skills/helpers.go b/cmd/picoclaw/internal/skills/helpers.go
index 439b81a4f..a59a2013a 100644
--- a/cmd/picoclaw/internal/skills/helpers.go
+++ b/cmd/picoclaw/internal/skills/helpers.go
@@ -15,6 +15,8 @@ import (
"github.com/sipeed/picoclaw/pkg/utils"
)
+const skillsSearchMaxResults = 20
+
func skillsListCmd(loader *skills.SkillsLoader) {
allSkills := loader.ListSkills()
@@ -215,34 +217,43 @@ func skillsListBuiltinCmd() {
}
}
-func skillsSearchCmd(installer *skills.SkillInstaller) {
+func skillsSearchCmd(query string) {
fmt.Println("Searching for available skills...")
+ cfg, err := internal.LoadConfig()
+ if err != nil {
+ fmt.Printf("✗ Failed to load config: %v\n", err)
+ return
+ }
+
+ registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
+ MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
+ ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
+ })
+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
- availableSkills, err := installer.ListAvailableSkills(ctx)
+ results, err := registryMgr.SearchAll(ctx, query, skillsSearchMaxResults)
if err != nil {
fmt.Printf("✗ Failed to fetch skills list: %v\n", err)
return
}
- if len(availableSkills) == 0 {
+ if len(results) == 0 {
fmt.Println("No skills available.")
return
}
- fmt.Printf("\nAvailable Skills (%d):\n", len(availableSkills))
+ fmt.Printf("\nAvailable Skills (%d):\n", len(results))
fmt.Println("--------------------")
- for _, skill := range availableSkills {
- fmt.Printf(" 📦 %s\n", skill.Name)
- fmt.Printf(" %s\n", skill.Description)
- fmt.Printf(" Repo: %s\n", skill.Repository)
- if skill.Author != "" {
- fmt.Printf(" Author: %s\n", skill.Author)
- }
- if len(skill.Tags) > 0 {
- fmt.Printf(" Tags: %v\n", skill.Tags)
+ for _, result := range results {
+ fmt.Printf(" 📦 %s\n", result.DisplayName)
+ fmt.Printf(" %s\n", result.Summary)
+ fmt.Printf(" Slug: %s\n", result.Slug)
+ fmt.Printf(" Registry: %s\n", result.RegistryName)
+ if result.Version != "" {
+ fmt.Printf(" Version: %s\n", result.Version)
}
fmt.Println()
}
diff --git a/cmd/picoclaw/internal/skills/search.go b/cmd/picoclaw/internal/skills/search.go
index 53bc99109..54f72259f 100644
--- a/cmd/picoclaw/internal/skills/search.go
+++ b/cmd/picoclaw/internal/skills/search.go
@@ -2,20 +2,19 @@ package skills
import (
"github.com/spf13/cobra"
-
- "github.com/sipeed/picoclaw/pkg/skills"
)
-func newSearchCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command {
+func newSearchCommand() *cobra.Command {
cmd := &cobra.Command{
- Use: "search",
+ Use: "search [query]",
Short: "Search available skills",
- RunE: func(_ *cobra.Command, _ []string) error {
- installer, err := installerFn()
- if err != nil {
- return err
+ Args: cobra.MaximumNArgs(1),
+ RunE: func(_ *cobra.Command, args []string) error {
+ query := ""
+ if len(args) == 1 {
+ query = args[0]
}
- skillsSearchCmd(installer)
+ skillsSearchCmd(query)
return nil
},
}
diff --git a/cmd/picoclaw/internal/skills/search_test.go b/cmd/picoclaw/internal/skills/search_test.go
index 19f63a9ff..ed92e25cc 100644
--- a/cmd/picoclaw/internal/skills/search_test.go
+++ b/cmd/picoclaw/internal/skills/search_test.go
@@ -8,11 +8,11 @@ import (
)
func TestNewSearchSubcommand(t *testing.T) {
- cmd := newSearchCommand(nil)
+ cmd := newSearchCommand()
require.NotNil(t, cmd)
- assert.Equal(t, "search", cmd.Use)
+ assert.Equal(t, "search [query]", cmd.Use)
assert.Equal(t, "Search available skills", cmd.Short)
assert.Nil(t, cmd.Run)
diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go
index 6db69c990..d9263462e 100644
--- a/cmd/picoclaw/main.go
+++ b/cmd/picoclaw/main.go
@@ -48,7 +48,21 @@ func NewPicoclawCommand() *cobra.Command {
return cmd
}
+const (
+ colorBlue = "\033[1;38;2;62;93;185m"
+ colorRed = "\033[1;38;2;213;70;70m"
+ banner = "\r\n" +
+ colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
+ colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
+ colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
+ colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
+ colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
+ colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
+ "\033[0m\r\n"
+)
+
func main() {
+ fmt.Printf("%s", banner)
cmd := NewPicoclawCommand()
if err := cmd.Execute(); err != nil {
os.Exit(1)
diff --git a/config/config.example.json b/config/config.example.json
index 605f9dc1d..283ca2bef 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -6,7 +6,9 @@
"model_name": "gpt4",
"max_tokens": 8192,
"temperature": 0.7,
- "max_tool_iterations": 20
+ "max_tool_iterations": 20,
+ "summarize_message_threshold": 20,
+ "summarize_token_percent": 75
}
},
"model_list": [
@@ -49,33 +51,44 @@
"telegram": {
"enabled": false,
"token": "YOUR_TELEGRAM_BOT_TOKEN",
+ "base_url": "",
"proxy": "",
"allow_from": [
"YOUR_USER_ID"
- ]
+ ],
+ "reasoning_channel_id": ""
},
"discord": {
"enabled": false,
"token": "YOUR_DISCORD_BOT_TOKEN",
+ "proxy": "",
"allow_from": [],
- "mention_only": false
+ "group_trigger": {
+ "mention_only": false
+ },
+ "reasoning_channel_id": ""
},
"qq": {
"enabled": false,
"app_id": "YOUR_QQ_APP_ID",
"app_secret": "YOUR_QQ_APP_SECRET",
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"maixcam": {
"enabled": false,
"host": "0.0.0.0",
"port": 18790,
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"whatsapp": {
"enabled": false,
"bridge_url": "ws://localhost:3001",
- "allow_from": []
+ "use_native": false,
+ "session_store_path": "",
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"feishu": {
"enabled": false,
@@ -83,28 +96,30 @@
"app_secret": "",
"encrypt_key": "",
"verification_token": "",
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"dingtalk": {
"enabled": false,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"slack": {
"enabled": false,
"bot_token": "xoxb-YOUR-BOT-TOKEN",
"app_token": "xapp-YOUR-APP-TOKEN",
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"line": {
"enabled": false,
"channel_secret": "YOUR_LINE_CHANNEL_SECRET",
"channel_access_token": "YOUR_LINE_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"onebot": {
"enabled": false,
@@ -112,33 +127,42 @@
"access_token": "",
"reconnect_interval": 5,
"group_trigger_prefix": [],
- "allow_from": []
+ "allow_from": [],
+ "reasoning_channel_id": ""
},
"wecom": {
- "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats",
+ "_comment": "WeCom Bot - Easier setup, supports group chats",
"enabled": false,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": [],
- "reply_timeout": 5
+ "reply_timeout": 5,
+ "reasoning_channel_id": ""
},
"wecom_app": {
- "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md",
+ "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.",
"enabled": false,
"corp_id": "YOUR_CORP_ID",
"corp_secret": "YOUR_CORP_SECRET",
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": [],
- "reply_timeout": 5
+ "reply_timeout": 5,
+ "reasoning_channel_id": ""
+ },
+ "wecom_aibot": {
+ "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.",
+ "enabled": false,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "max_steps": 10,
+ "welcome_message": "Hello! I'm your AI assistant. How can I help you today?",
+ "reasoning_channel_id": ""
}
},
"providers": {
@@ -228,6 +252,71 @@
"cron": {
"exec_timeout_minutes": 5
},
+ "mcp": {
+ "enabled": false,
+ "servers": {
+ "context7": {
+ "enabled": false,
+ "type": "http",
+ "url": "https://mcp.context7.com/mcp",
+ "headers": {
+ "CONTEXT7_API_KEY": "ctx7sk-xx"
+ }
+ },
+ "filesystem": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-filesystem",
+ "/tmp"
+ ]
+ },
+ "github": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-github"
+ ],
+ "env": {
+ "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
+ }
+ },
+ "brave-search": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-brave-search"
+ ],
+ "env": {
+ "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
+ }
+ },
+ "postgres": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-postgres",
+ "postgresql://user:password@localhost/dbname"
+ ]
+ },
+ "slack": {
+ "enabled": false,
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-slack"
+ ],
+ "env": {
+ "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN",
+ "SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID"
+ }
+ }
+ }
+ },
"exec": {
"enable_deny_patterns": false,
"custom_deny_patterns": []
diff --git a/Dockerfile b/docker/Dockerfile
similarity index 100%
rename from Dockerfile
rename to docker/Dockerfile
diff --git a/docker/Dockerfile.full b/docker/Dockerfile.full
new file mode 100644
index 000000000..30e1680d5
--- /dev/null
+++ b/docker/Dockerfile.full
@@ -0,0 +1,44 @@
+# ============================================================
+# Stage 1: Build the picoclaw binary
+# ============================================================
+FROM golang:1.26.0-alpine AS builder
+
+RUN apk add --no-cache git make
+
+WORKDIR /src
+
+# Cache dependencies
+COPY go.mod go.sum ./
+RUN go mod download
+
+# Copy source and build
+COPY . .
+RUN make build
+
+# ============================================================
+# Stage 2: Node.js-based runtime with full MCP support
+# ============================================================
+FROM node:24-alpine3.23
+
+# Install runtime dependencies
+RUN apk add --no-cache \
+ ca-certificates \
+ curl \
+ git \
+ python3 \
+ py3-pip
+
+# Install uv and symlink to system path
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+ ln -s /root/.local/bin/uv /usr/local/bin/uv && \
+ ln -s /root/.local/bin/uvx /usr/local/bin/uvx && \
+ uv --version
+
+# Copy binary
+COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
+
+# Create picoclaw home directory
+RUN /usr/local/bin/picoclaw onboard
+
+ENTRYPOINT ["picoclaw"]
+CMD ["gateway"]
diff --git a/Dockerfile.goreleaser b/docker/Dockerfile.goreleaser
similarity index 58%
rename from Dockerfile.goreleaser
rename to docker/Dockerfile.goreleaser
index 0cdc8c6bd..68a02aae8 100644
--- a/Dockerfile.goreleaser
+++ b/docker/Dockerfile.goreleaser
@@ -5,6 +5,8 @@ ARG TARGETPLATFORM
RUN apk add --no-cache ca-certificates tzdata
COPY $TARGETPLATFORM/picoclaw /usr/local/bin/picoclaw
+COPY docker/entrypoint.sh /entrypoint.sh
-ENTRYPOINT ["picoclaw"]
-CMD ["gateway"]
+RUN chmod +x /entrypoint.sh
+
+ENTRYPOINT ["/entrypoint.sh"]
diff --git a/docker/docker-compose.full.yml b/docker/docker-compose.full.yml
new file mode 100644
index 000000000..6f34448c4
--- /dev/null
+++ b/docker/docker-compose.full.yml
@@ -0,0 +1,44 @@
+services:
+ # ─────────────────────────────────────────────
+ # PicoClaw Agent (one-shot query) - Full MCP Support
+ # docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent -m "Hello"
+ # ─────────────────────────────────────────────
+ picoclaw-agent:
+ build:
+ context: ..
+ dockerfile: docker/Dockerfile.full
+ container_name: picoclaw-agent-full
+ profiles:
+ - agent
+ volumes:
+ - ../config/config.json:/root/.picoclaw/config.json:ro
+ - picoclaw-workspace:/root/.picoclaw/workspace
+ - picoclaw-npm-cache:/root/.npm # npm cache for faster MCP server installs
+ entrypoint: ["picoclaw", "agent"]
+ stdin_open: true
+ tty: true
+
+ # ─────────────────────────────────────────────
+ # PicoClaw Gateway (Long-running Bot) - Full MCP Support
+ # docker compose -f docker/docker-compose.full.yml --profile gateway up
+ # ─────────────────────────────────────────────
+ picoclaw-gateway:
+ build:
+ context: ..
+ dockerfile: docker/Dockerfile.full
+ container_name: picoclaw-gateway-full
+ restart: unless-stopped
+ profiles:
+ - gateway
+ volumes:
+ # Configuration file
+ - ../config/config.json:/root/.picoclaw/config.json:ro
+ # Persistent workspace (sessions, memory, logs)
+ - picoclaw-workspace:/root/.picoclaw/workspace
+ # NPM cache for faster MCP server installs
+ - picoclaw-npm-cache:/root/.npm
+ command: ["gateway"]
+
+volumes:
+ picoclaw-workspace:
+ picoclaw-npm-cache: # Cache npm packages to speed up MCP server installations
diff --git a/docker-compose.yml b/docker/docker-compose.yml
similarity index 64%
rename from docker-compose.yml
rename to docker/docker-compose.yml
index c268b01cd..9ec71abab 100644
--- a/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,12 +1,10 @@
services:
# ─────────────────────────────────────────────
# PicoClaw Agent (one-shot query)
- # docker compose run --rm picoclaw-agent -m "Hello"
+ # docker compose -f docker/docker-compose.yml run --rm picoclaw-agent -m "Hello"
# ─────────────────────────────────────────────
picoclaw-agent:
- build:
- context: .
- dockerfile: Dockerfile
+ image: docker.io/sipeed/picoclaw:latest
container_name: picoclaw-agent
profiles:
- agent
@@ -14,33 +12,23 @@ services:
#extra_hosts:
# - "host.docker.internal:host-gateway"
volumes:
- - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
- - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
+ - ./data:/root/.picoclaw
entrypoint: ["picoclaw", "agent"]
stdin_open: true
tty: true
# ─────────────────────────────────────────────
# PicoClaw Gateway (Long-running Bot)
- # docker compose up picoclaw-gateway
+ # docker compose -f docker/docker-compose.yml up picoclaw-gateway
# ─────────────────────────────────────────────
picoclaw-gateway:
- build:
- context: .
- dockerfile: Dockerfile
+ image: docker.io/sipeed/picoclaw:latest
container_name: picoclaw-gateway
- restart: unless-stopped
+ restart: on-failure
profiles:
- gateway
# Uncomment to access host network; leave commented unless needed.
#extra_hosts:
# - "host.docker.internal:host-gateway"
volumes:
- # Configuration file
- - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
- # Persistent workspace (sessions, memory, logs)
- - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
- command: ["gateway"]
-
-volumes:
- picoclaw-workspace:
+ - ./data:/root/.picoclaw
diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh
new file mode 100644
index 000000000..b6fc724b5
--- /dev/null
+++ b/docker/entrypoint.sh
@@ -0,0 +1,15 @@
+#!/bin/sh
+set -e
+
+# First-run: neither config nor workspace exists.
+# If config.json is already mounted but workspace is missing we skip onboard to
+# avoid the interactive "Overwrite? (y/n)" prompt hanging in a non-TTY container.
+if [ ! -d "${HOME}/.picoclaw/workspace" ] && [ ! -f "${HOME}/.picoclaw/config.json" ]; then
+ picoclaw onboard
+ echo ""
+ echo "First-run setup complete."
+ echo "Edit ${HOME}/.picoclaw/config.json (add your API key, etc.) then restart the container."
+ exit 0
+fi
+
+exec picoclaw gateway "$@"
diff --git a/docs/channels/discord/README.zh.md b/docs/channels/discord/README.zh.md
index 5b597eced..6d3c502cf 100644
--- a/docs/channels/discord/README.zh.md
+++ b/docs/channels/discord/README.zh.md
@@ -11,7 +11,9 @@ Discord 是一个专为社区设计的免费语音、视频和文本聊天应用
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allow_from": ["YOUR_USER_ID"],
- "mention_only": false
+ "group_trigger": {
+ "mention_only": false
+ }
}
}
}
@@ -22,7 +24,7 @@ Discord 是一个专为社区设计的免费语音、视频和文本聊天应用
| enabled | bool | 是 | 是否启用 Discord 频道 |
| token | string | 是 | Discord 机器人 Token |
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
-| mention_only | bool | 否 | 是否仅响应提及机器人的消息 |
+| group_trigger | object | 否 | 群组触发设置(示例: { "mention_only": false }) |
## 设置流程
diff --git a/docs/channels/line/README.zh.md b/docs/channels/line/README.zh.md
index fd3aa80da..a36f622c2 100644
--- a/docs/channels/line/README.zh.md
+++ b/docs/channels/line/README.zh.md
@@ -11,8 +11,6 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
@@ -25,9 +23,7 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
| enabled | bool | 是 | 是否启用 LINE Channel |
| channel_secret | string | 是 | LINE Messaging API 的 Channel Secret |
| channel_access_token | string | 是 | LINE Messaging API 的 Channel Access Token |
-| webhook_host | string | 是 | Webhook 监听的主机地址 (通常为 0.0.0.0) |
-| webhook_port | int | 是 | Webhook 监听的端口 (默认为 18791) |
-| webhook_path | string | 是 | Webhook 的路径 (默认为 /webhook/line) |
+| webhook_path | string | 否 | Webhook 的路径 (默认为 /webhook/line) |
| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
## 设置流程
@@ -35,7 +31,8 @@ PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的
1. 前往 [LINE Developers Console](https://developers.line.biz/console/) 创建一个服务提供商和一个 Messaging API Channel
2. 获取 Channel Secret 和 Channel Access Token
3. 配置Webhook:
- - Line要求Webhook必须使用HTTPS协议,因此需要部署一个支持HTTPS的服务器,或者使用反向代理工具如ngrok将本地服务器暴露到公网
- - 将 Webhook URL 设置为 `https://your-domain.com/webhook/line`
+ - LINE 要求 Webhook 必须使用 HTTPS 协议,因此需要部署一个支持 HTTPS 的服务器,或者使用反向代理工具如 ngrok 将本地服务器暴露到公网
+ - PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790
+ - 将 Webhook URL 设置为 `https://your-domain.com/webhook/line`,然后将外部域名反向代理到本机的 Gateway(默认端口 18790)
- 启用 Webhook 并验证 URL
4. 将 Channel Secret 和 Channel Access Token 填入配置文件中
diff --git a/docs/channels/wecom/wecom_aibot/README.zh.md b/docs/channels/wecom/wecom_aibot/README.zh.md
new file mode 100644
index 000000000..d210528af
--- /dev/null
+++ b/docs/channels/wecom/wecom_aibot/README.zh.md
@@ -0,0 +1,116 @@
+# 企业微信智能机器人 (AI Bot)
+
+企业微信智能机器人(AI Bot)是企业微信官方提供的 AI 对话接入方式,支持私聊与群聊,内置流式响应协议,并支持超时后通过 `response_url` 主动推送最终回复。
+
+## 与其他 WeCom 通道的对比
+
+| 特性 | WeCom Bot | WeCom App | **WeCom AI Bot** |
+|------|-----------|-----------|-----------------|
+| 私聊 | ✅ | ✅ | ✅ |
+| 群聊 | ✅ | ❌ | ✅ |
+| 流式输出 | ❌ | ❌ | ✅ |
+| 超时主动推送 | ❌ | ✅ | ✅ |
+| 配置复杂度 | 低 | 高 | 中 |
+
+## 配置
+
+```json
+{
+ "channels": {
+ "wecom_aibot": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_path": "/webhook/wecom-aibot",
+ "allow_from": [],
+ "welcome_message": "你好!有什么可以帮助你的吗?",
+ "max_steps": 10
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------------- | ------ | ---- | -------------------------------------------------- |
+| token | string | 是 | 回调验证令牌,在 AI Bot 管理页面配置 |
+| encoding_aes_key | string | 是 | 43 字符 AES 密钥,在 AI Bot 管理页面随机生成 |
+| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-aibot) |
+| allow_from | array | 否 | 用户 ID 白名单,空数组表示允许所有用户 |
+| welcome_message | string | 否 | 用户进入聊天时发送的欢迎语,留空则不发送 |
+| reply_timeout | int | 否 | 回复超时时间(秒,默认:5) |
+| max_steps | int | 否 | Agent 最大执行步骤数(默认:10) |
+
+## 设置流程
+
+1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin)
+2. 进入"应用管理" → "智能机器人",创建或选择一个 AI Bot
+3. 在 AI Bot 配置页面,填写"消息接收"信息:
+ - **URL**:`http://:18791/webhook/wecom-aibot`
+ - **Token**:随机生成或自定义
+ - **EncodingAESKey**:点击"随机生成",得到 43 字符密钥
+4. 将 Token 和 EncodingAESKey 填入 PicoClaw 配置文件,启动服务后回到管理后台保存(企业微信会发送验证请求)
+
+> [!TIP]
+> 服务器需要能被企业微信服务器访问。如在内网/本地开发,可使用 [ngrok](https://ngrok.com) 或 frp 做内网穿透。
+
+## 流式响应协议
+
+WeCom AI Bot 使用"流式拉取"协议,区别于普通 Webhook 的一次性回复:
+
+```
+用户发消息
+ │
+ ▼
+PicoClaw 立即返回 {finish: false}(Agent 开始处理)
+ │
+ ▼
+企业微信每隔约 1 秒拉取一次 {msgtype: "stream", stream: {id: "..."}}
+ │
+ ├─ Agent 未完成 → 返回 {finish: false}(继续等待)
+ │
+ └─ Agent 完成 → 返回 {finish: true, content: "回答内容"}
+```
+
+**超时处理**(任务超过 30 秒):
+
+若 Agent 处理时间超过约 30 秒(企业微信最大轮询窗口为 6 分钟),PicoClaw 会:
+
+1. 立即关闭流,向用户显示「⏳ 正在处理中,请稍候,结果将稍后发送。」
+2. Agent 继续在后台运行
+3. Agent 完成后,通过消息中携带的 `response_url` 将最终回复主动推送给用户
+
+> `response_url` 由企业微信颁发,有效期 1 小时,只可使用一次,无需加密,直接 POST markdown 消息体即可。
+
+## 欢迎语
+
+配置 `welcome_message` 后,当用户打开与 AI Bot 的聊天窗口时(`enter_chat` 事件),PicoClaw 会自动回复该欢迎语。留空则静默忽略。
+
+```json
+"welcome_message": "你好!我是 PicoClaw AI 助手,有什么可以帮你?"
+```
+
+## 常见问题
+
+### 回调 URL 验证失败
+
+- 确认服务器防火墙已开放对应端口(默认 18791)
+- 确认 `token` 与 `encoding_aes_key` 填写正确
+- 检查 PicoClaw 日志是否收到了来自企业微信的 GET 请求
+
+### 消息没有回复
+
+- 检查 `allow_from` 是否意外限制了发送者
+- 查看日志中是否出现 `context canceled` 或 Agent 错误
+- 确认 Agent 配置(`model_name` 等)正确
+
+### 超长任务没有收到最终推送
+
+- 确认消息回调中携带了 `response_url`(仅企业微信新版 AI Bot 支持)
+- 确认服务器能主动访问外网(需向 `response_url` POST 请求)
+- 查看日志关键词 `response_url mode` 和 `Sending reply via response_url`
+
+## 参考文档
+
+- [企业微信 AI Bot 接入文档](https://developer.work.weixin.qq.com/document/path/100719)
+- [流式响应协议说明](https://developer.work.weixin.qq.com/document/path/100719)
+- [response_url 主动回复](https://developer.work.weixin.qq.com/document/path/101138)
diff --git a/docs/channels/wecom/wecom_app/README.zh.md b/docs/channels/wecom/wecom_app/README.zh.md
index 1e6a0e2b3..0a9858107 100644
--- a/docs/channels/wecom/wecom_app/README.zh.md
+++ b/docs/channels/wecom/wecom_app/README.zh.md
@@ -14,8 +14,6 @@
"agent_id": 1000002,
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
"webhook_path": "/webhook/wecom-app",
"allow_from": [],
"reply_timeout": 5
@@ -31,8 +29,6 @@
| agent_id | int | 是 | 应用程序代理 ID |
| token | string | 是 | 回调验证令牌 |
| encoding_aes_key | string | 是 | 43 字符 AES 密钥 |
-| webhook_host | string | 否 | HTTP 服务器绑定地址 |
-| webhook_port | int | 否 | HTTP 服务器端口(默认:18792) |
| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-app) |
| allow_from | array | 否 | 用户 ID 白名单 |
| reply_timeout | int | 否 | 回复超时时间(秒) |
@@ -45,3 +41,5 @@
4. 在应用设置中配置“接收消息”,获取 Token 和 EncodingAESKey
5. 设置回调 URL 为 `http://:/webhook/wecom-app`
6. 将 CorpID, Secret, AgentID 等信息填入配置文件
+
+ 注意: PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790。如需从公网接收回调,请把外部域名反向代理到 Gateway(默认端口 18790)。
diff --git a/docs/channels/wecom/wecom_bot/README.zh.md b/docs/channels/wecom/wecom_bot/README.zh.md
index c4bb1c87e..63d9b84d6 100644
--- a/docs/channels/wecom/wecom_bot/README.zh.md
+++ b/docs/channels/wecom/wecom_bot/README.zh.md
@@ -12,8 +12,6 @@
"token": "YOUR_TOKEN",
"encoding_aes_key": "YOUR_ENCODING_AES_KEY",
"webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
- "webhook_host": "0.0.0.0",
- "webhook_port": 18793,
"webhook_path": "/webhook/wecom",
"allow_from": [],
"reply_timeout": 5
@@ -27,8 +25,6 @@
| token | string | 是 | 签名验证代币 |
| encoding_aes_key | string | 是 | 用于解密的 43 字符 AES 密钥 |
| webhook_url | string | 是 | 用于发送回复的企业微信群聊机器人 Webhook URL |
-| webhook_host | string | 否 | HTTP 服务器绑定地址(默认:0.0.0.0) |
-| webhook_port | int | 否 | HTTP 服务器端口(默认:18793) |
| webhook_path | string | 否 | Webhook 端点路径(默认:/webhook/wecom) |
| allow_from | array | 否 | 用户 ID 白名单(空值 = 允许所有用户) |
| reply_timeout | int | 否 | 回复超时时间(单位:秒,默认值:5) |
@@ -39,3 +35,5 @@
2. 获取 Webhook URL
3. (如需接收消息) 在机器人配置页面设置接收消息的 API 地址(回调地址)以及 Token 和 EncodingAESKey
4. 将相关信息填入配置文件
+
+ 注意: PicoClaw 现在使用共享的 Gateway HTTP 服务器来接收所有渠道的 webhook 回调,默认监听地址为 127.0.0.1:18790。如需从公网接收回调,请把外部域名反向代理到 Gateway(默认端口 18790)。
diff --git a/docs/design/issue-783-investigation-and-fix-plan.zh.md b/docs/design/issue-783-investigation-and-fix-plan.zh.md
new file mode 100644
index 000000000..1c9fc1e70
--- /dev/null
+++ b/docs/design/issue-783-investigation-and-fix-plan.zh.md
@@ -0,0 +1,61 @@
+# Issue #783 调研与修复执行文档
+
+## 1. 问题澄清(已确认)
+
+- 现象:当 `agents.*.model.primary/fallbacks` 使用 `model_name` 别名(如 `step-3.5-flash`)时,fallback 链路将别名当作真实 `provider/model` 解析,导致 `provider` 可能为空、`model` 可能错误。
+- 根因:`ResolveCandidates` 仅对字符串做 `ParseModelRef`,未先通过 `model_list` 将别名映射到真实 `model` 字段。
+- 影响:
+ - fallback 执行可能把别名直接发给 OpenAI-compatible provider,触发 `Unknown Model`。
+ - `defaults.provider` 为空时,日志出现 `provider=` 空值。
+
+## 2. 本次目标
+
+- 修复 fallback 候选解析:优先通过 `model_list` 解析别名。
+- 兼容旧行为:若未命中 `model_list`,继续走原有 `ParseModelRef` 兜底。
+- 补充测试:覆盖别名、嵌套路径模型(如 `openrouter/stepfun/...`)、空默认 provider。
+- 验证代码风格:与当前仓库风格保持一致(命名、错误处理、测试结构)。
+
+## 3. 联网最佳实践调研结论(已完成)
+
+- [x] 查阅 OpenAI-compatible 网关(如 OpenRouter)对 `model` 字段的推荐处理。
+- [x] 查阅多 provider/fallback 设计最佳实践(候选解析、日志可观测性)。
+- [x] 将外部建议映射为本仓库可执行约束。
+
+外部参考要点(来自 OpenRouter/LiteLLM/Cloudflare AI Gateway 等官方文档):
+
+- 优先显式配置,不依赖字符串切分推断 provider。
+- 对网关模型标识应保留完整路径语义,避免截断导致 Unknown Model。
+- fallback 与 primary 应复用同一解析策略,避免“主路径正确、降级路径错误”。
+
+参考链接:
+
+- OpenRouter Provider Routing: https://openrouter.ai/docs/guides/routing/provider-selection
+- OpenRouter Model Fallbacks: https://openrouter.ai/docs/guides/routing/model-fallbacks
+- OpenRouter Chat Completion API: https://openrouter.ai/docs/api-reference/chat-completion
+- LiteLLM Router Architecture: https://docs.litellm.ai/docs/router_architecture
+- Cloudflare AI Gateway Chat Completion: https://developers.cloudflare.com/ai-gateway/usage/chat-completion/
+
+与本仓库对应的可执行约束:
+
+- 在 fallback candidate 构建阶段先做 `model_name -> model_list.model` 映射。
+- 未命中映射时保留旧解析行为,保证兼容性。
+- 用新增测试锁定“别名 + 嵌套模型路径 + 空默认 provider”场景。
+
+## 4. 实施步骤(顺序执行)
+
+- [x] Step 1: 对齐现有代码模式,定位最小改动点(`pkg/agent` + `pkg/providers`)。
+- [x] Step 2: 实现“基于 model_list 的 fallback 候选解析”。
+- [x] Step 3: 增加/更新单元测试,覆盖 issue 场景。
+- [x] Step 4: 代码风格一致性复核(与现有文件风格对照)。
+- [x] Step 5: 运行质量门禁(LSP + `make check`)。
+
+## 5. 执行记录
+
+- 状态:已完成
+- 已完成改动:
+ - `pkg/providers/fallback.go`:新增 `ResolveCandidatesWithLookup`,并保持 `ResolveCandidates` 向后兼容。
+ - `pkg/agent/instance.go`:在构建 fallback candidates 前,优先通过 `model_list` 解析别名,并对无协议模型补齐默认 `openai/` 前缀后再解析。
+ - `pkg/providers/fallback_test.go`:新增别名解析与去重测试。
+ - `pkg/agent/instance_test.go`:新增 agent 侧别名解析到嵌套模型路径、无协议模型解析测试。
+- 风格对齐检查(完成):与 `pkg/providers/fallback_test.go`、`pkg/providers/model_ref_test.go` 现有模式一致。
+- 质量验证(完成):先 `make generate`,后 `make check` 全量通过。
diff --git a/docs/migration/model-list-migration.md b/docs/migration/model-list-migration.md
index 589dfc043..0d4af719c 100644
--- a/docs/migration/model-list-migration.md
+++ b/docs/migration/model-list-migration.md
@@ -117,6 +117,7 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier`
| `connect_mode` | No | Connection mode for CLI providers: `stdio`, `grpc` |
| `rpm` | No | Requests per minute limit |
| `max_tokens_field` | No | Field name for max tokens |
+| `request_timeout` | No | HTTP request timeout in seconds; `<=0` uses default `120s` |
*`api_key` is required for HTTP-based protocols unless `api_base` points to a local server.
diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md
index 8aba1aa91..6204fb0c8 100644
--- a/docs/tools_configuration.md
+++ b/docs/tools_configuration.md
@@ -8,6 +8,7 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`.
{
"tools": {
"web": { ... },
+ "mcp": { ... },
"exec": { ... },
"cron": { ... },
"skills": { ... }
@@ -21,35 +22,35 @@ Web tools are used for web search and fetching.
### Brave
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | false | Enable Brave search |
-| `api_key` | string | - | Brave Search API key |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ------ | ------- | ------------------------- |
+| `enabled` | bool | false | Enable Brave search |
+| `api_key` | string | - | Brave Search API key |
+| `max_results` | int | 5 | Maximum number of results |
### DuckDuckGo
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | true | Enable DuckDuckGo search |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ---- | ------- | ------------------------- |
+| `enabled` | bool | true | Enable DuckDuckGo search |
+| `max_results` | int | 5 | Maximum number of results |
### Perplexity
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | false | Enable Perplexity search |
-| `api_key` | string | - | Perplexity API key |
-| `max_results` | int | 5 | Maximum number of results |
+| Config | Type | Default | Description |
+| ------------- | ------ | ------- | ------------------------- |
+| `enabled` | bool | false | Enable Perplexity search |
+| `api_key` | string | - | Perplexity API key |
+| `max_results` | int | 5 | Maximum number of results |
## Exec Tool
The exec tool is used to execute shell commands.
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
-| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
+| Config | Type | Default | Description |
+| ---------------------- | ----- | ------- | ------------------------------------------ |
+| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
+| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
### Functionality
@@ -80,10 +81,7 @@ By default, PicoClaw blocks the following dangerous commands:
"tools": {
"exec": {
"enable_deny_patterns": true,
- "custom_deny_patterns": [
- "\\brm\\s+-r\\b",
- "\\bkillall\\s+python"
- ]
+ "custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"]
}
}
}
@@ -93,9 +91,84 @@ By default, PicoClaw blocks the following dangerous commands:
The cron tool is used for scheduling periodic tasks.
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
+| Config | Type | Default | Description |
+| ---------------------- | ---- | ------- | ---------------------------------------------- |
+| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
+
+## MCP Tool
+
+The MCP tool enables integration with external Model Context Protocol servers.
+
+### Global Config
+
+| Config | Type | Default | Description |
+| --------- | ------ | ------- | ----------------------------------- |
+| `enabled` | bool | false | Enable MCP integration globally |
+| `servers` | object | `{}` | Map of server name to server config |
+
+### Per-Server Config
+
+| Config | Type | Required | Description |
+| ---------- | ------ | -------- | ------------------------------------------ |
+| `enabled` | bool | yes | Enable this MCP server |
+| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
+| `command` | string | stdio | Executable command for stdio transport |
+| `args` | array | no | Command arguments for stdio transport |
+| `env` | object | no | Environment variables for stdio process |
+| `env_file` | string | no | Path to environment file for stdio process |
+| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport |
+| `headers` | object | no | HTTP headers for `sse`/`http` transport |
+
+### Transport Behavior
+
+- If `type` is omitted, transport is auto-detected:
+ - `url` is set → `sse`
+ - `command` is set → `stdio`
+- `http` and `sse` both use `url` + optional `headers`.
+- `env` and `env_file` are only applied to `stdio` servers.
+
+### Configuration Examples
+
+#### 1) Stdio MCP server
+
+```json
+{
+ "tools": {
+ "mcp": {
+ "enabled": true,
+ "servers": {
+ "filesystem": {
+ "enabled": true,
+ "command": "npx",
+ "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
+ }
+ }
+ }
+ }
+}
+```
+
+#### 2) Remote SSE/HTTP MCP server
+
+```json
+{
+ "tools": {
+ "mcp": {
+ "enabled": true,
+ "servers": {
+ "remote-mcp": {
+ "enabled": true,
+ "type": "sse",
+ "url": "https://example.com/mcp",
+ "headers": {
+ "Authorization": "Bearer YOUR_TOKEN"
+ }
+ }
+ }
+ }
+ }
+}
+```
## Skills Tool
@@ -103,13 +176,13 @@ The skills tool configures skill discovery and installation via registries like
### Registries
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
-| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
-| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
-| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
-| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
+| Config | Type | Default | Description |
+| ---------------------------------- | ------ | -------------------- | ----------------------- |
+| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
+| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
+| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
+| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
+| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
### Configuration Example
@@ -136,8 +209,10 @@ The skills tool configures skill discovery and installation via registries like
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS__`:
For example:
+
- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true`
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
+- `PICOCLAW_TOOLS_MCP_ENABLED=true`
-Note: Array-type environment variables are not currently supported and must be set via the config file.
+Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than environment variables.
diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md
new file mode 100644
index 000000000..219d2c6e3
--- /dev/null
+++ b/docs/troubleshooting.md
@@ -0,0 +1,43 @@
+# Troubleshooting
+
+## "model ... not found in model_list" or OpenRouter "free is not a valid model ID"
+
+**Symptom:** You see either:
+
+- `Error creating provider: model "openrouter/free" not found in model_list`
+- OpenRouter returns 400: `"free is not a valid model ID"`
+
+**Cause:** The `model` field in your `model_list` entry is what gets sent to the API. For OpenRouter you must use the **full** model ID, not a shorthand.
+
+- **Wrong:** `"model": "free"` → OpenRouter receives `free` and rejects it.
+- **Right:** `"model": "openrouter/free"` → OpenRouter receives `openrouter/free` (auto free-tier routing).
+
+**Fix:** In `~/.picoclaw/config.json` (or your config path):
+
+1. **agents.defaults.model** must match a `model_name` in `model_list` (e.g. `"openrouter-free"`).
+2. That entry’s **model** must be a valid OpenRouter model ID, for example:
+ - `"openrouter/free"` – auto free-tier
+ - `"google/gemini-2.0-flash-exp:free"`
+ - `"meta-llama/llama-3.1-8b-instruct:free"`
+
+Example snippet:
+
+```json
+{
+ "agents": {
+ "defaults": {
+ "model": "openrouter-free"
+ }
+ },
+ "model_list": [
+ {
+ "model_name": "openrouter-free",
+ "model": "openrouter/free",
+ "api_key": "sk-or-v1-YOUR_OPENROUTER_KEY",
+ "api_base": "https://openrouter.ai/api/v1"
+ }
+ ]
+}
+```
+
+Get your key at [OpenRouter Keys](https://openrouter.ai/keys).
diff --git a/docs/wecom-app-configuration.md b/docs/wecom-app-configuration.md
deleted file mode 100644
index 3b17d37a7..000000000
--- a/docs/wecom-app-configuration.md
+++ /dev/null
@@ -1,117 +0,0 @@
-# 企业微信自建应用 (WeCom App) 配置指南
-
-本文档介绍如何在 PicoClaw 中配置企业微信自建应用 (wecom-app) 通道。
-
-## 功能特性
-
-| 功能 | 支持状态 |
-|------|---------|
-| 被动接收消息 | ✅ |
-| 主动发送消息 | ✅ |
-| 私聊 | ✅ |
-| 群聊 | ❌ |
-
-## 配置步骤
-
-### 1. 企业微信后台配置
-
-1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin)
-2. 进入"应用管理" → 选择自建应用
-3. 记录以下信息:
- - **AgentId**: 应用详情页显示
- - **Secret**: 点击"查看"获取
-4. 进入"我的企业"页面,记录 **企业ID** (CorpID)
-
-### 2. 接收消息配置
-
-1. 在应用详情页,点击"接收消息"的"设置API接收"
-2. 填写以下信息:
- - **URL**: `http://your-server:18792/webhook/wecom-app`
- - **Token**: 随机生成或自定义(用于签名验证)
- - **EncodingAESKey**: 点击"随机生成"生成43字符的密钥
-3. 点击"保存"时,企业微信会发送验证请求
-
-### 3. PicoClaw 配置
-
-在 `config.json` 中添加以下配置:
-
-```json
-{
- "channels": {
- "wecom_app": {
- "enabled": true,
- "corp_id": "wwxxxxxxxxxxxxxxxx", // 企业ID
- "corp_secret": "xxxxxxxxxxxxxxxxxxxxxxxx", // 应用Secret
- "agent_id": 1000002, // 应用AgentId
- "token": "your_token", // 接收消息配置的Token
- "encoding_aes_key": "your_encoding_aes_key", // 接收消息配置的EncodingAESKey
- "webhook_host": "0.0.0.0",
- "webhook_port": 18792,
- "webhook_path": "/webhook/wecom-app",
- "allow_from": [],
- "reply_timeout": 5
- }
- }
-}
-```
-
-## 常见问题
-
-### 1. 回调URL验证失败
-
-**症状**: 企业微信保存API接收消息时提示验证失败
-
-**检查项**:
-- 确认服务器防火墙已开放 18792 端口
-- 确认 `corp_id`、`token`、`encoding_aes_key` 配置正确
-- 查看 PicoClaw 日志是否有请求到达
-
-### 2. 中文消息解密失败
-
-**症状**: 发送中文消息时出现 `invalid padding size` 错误
-
-**原因**: 企业微信使用非标准的 PKCS7 填充(32字节块大小)
-
-**解决**: 确保使用最新版本的 PicoClaw,已修复此问题。
-
-### 3. 端口冲突
-
-**症状**: 启动时提示端口已被占用
-
-**解决**: 修改 `webhook_port` 为其他端口,如 18794
-
-## 技术细节
-
-### 加密算法
-
-- **算法**: AES-256-CBC
-- **密钥**: EncodingAESKey Base64解码后的32字节
-- **IV**: AESKey的前16字节
-- **填充**: PKCS7(块大小为32字节,非标准16字节)
-- **消息格式**: XML
-
-### 消息结构
-
-解密后的消息格式:
-```
-random(16B) + msg_len(4B) + msg + receiveid
-```
-
-其中 `receiveid` 对于自建应用是 `corp_id`。
-
-## 调试
-
-启用调试模式查看详细日志:
-
-```bash
-picoclaw gateway --debug
-```
-
-关键日志标识:
-- `wecom_app`: WeCom App 通道相关日志
-- `wecom_common`: 加密解密相关日志
-
-## 参考文档
-
-- [企业微信官方文档 - 接收消息](https://developer.work.weixin.qq.com/document/path/96211)
-- [企业微信官方加解密库](https://github.com/sbzhu/weworkapi_golang)
diff --git a/go.mod b/go.mod
index 98e20d07d..c1172937c 100644
--- a/go.mod
+++ b/go.mod
@@ -8,25 +8,59 @@ require (
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
+ github.com/gdamore/tcell/v2 v2.13.8
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
+ github.com/mdp/qrterminal/v3 v3.2.1
+ github.com/modelcontextprotocol/go-sdk v1.3.0
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
+ github.com/rivo/tview v0.42.0
github.com/slack-go/slack v0.17.3
github.com/spf13/cobra v1.10.2
github.com/stretchr/testify v1.11.1
github.com/tencent-connect/botgo v0.2.1
+ go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4
golang.org/x/oauth2 v0.35.0
+ golang.org/x/time v0.14.0
+ google.golang.org/protobuf v1.36.11
+ modernc.org/sqlite v1.46.1
)
require (
+ filippo.io/edwards25519 v1.1.0 // indirect
+ github.com/beeper/argo-go v1.1.2 // indirect
+ github.com/coder/websocket v1.8.14 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
+ github.com/gdamore/encoding v1.0.1 // indirect
+ github.com/gdamore/tcell/v2 v2.13.8 // indirect
+ github.com/h2non/filetype v1.1.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
+ github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
+ github.com/mattn/go-colorable v0.1.14 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/ncruces/go-strftime v1.0.0 // indirect
+ github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/rivo/uniseg v0.4.7 // indirect
+ github.com/rs/zerolog v1.34.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
+ github.com/vektah/gqlparser/v2 v2.5.27 // indirect
+ go.mau.fi/libsignal v0.2.1 // indirect
+ go.mau.fi/util v0.9.6 // indirect
+ golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect
+ golang.org/x/term v0.40.0 // indirect
+ golang.org/x/text v0.34.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+ modernc.org/libc v1.67.6 // indirect
+ modernc.org/mathutil v1.7.1 // indirect
+ modernc.org/memory v1.11.0 // indirect
+ rsc.io/qr v0.2.0 // indirect
)
require (
@@ -50,6 +84,7 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
+ github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
diff --git a/go.sum b/go.sum
index abbb11cd6..060594d06 100644
--- a/go.sum
+++ b/go.sum
@@ -1,10 +1,20 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
+filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
+github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
+github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM=
+github.com/agnivade/levenshtein v1.2.1/go.mod h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU=
+github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
+github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
+github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
+github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
@@ -25,14 +35,25 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
+github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg=
+github.com/elliotchance/orderedmap/v3 v3.1.0/go.mod h1:G+Hc2RwaZvJMcS4JpGCOyViCnGeKf0bTYCGTO4uhjSo=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uhw=
+github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
+github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
+github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw=
github.com/github/copilot-sdk/go v0.1.23/go.mod h1:GdwwBfMbm9AABLEM3x5IZKw4ZfwCYxZ1BgyytmZenQ0=
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
@@ -42,8 +63,11 @@ github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2m
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
+github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
+github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
+github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
@@ -63,6 +87,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
+github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
+github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -72,6 +98,10 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
+github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg=
+github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
+github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
+github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -91,8 +121,25 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
+github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
+github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
+github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
+github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
+github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
+github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs=
+github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
+github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
+github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -105,13 +152,27 @@ github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixiyJ8ys=
github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14=
+github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rivo/tview v0.42.0 h1:b/ftp+RxtDsHSaynXTbJb+/n/BxDEi+W3UfF5jILK6c=
+github.com/rivo/tview v0.42.0/go.mod h1:cSfIYfhpSGCjp3r/ECJb+GKS7cGJnqV8vfjQPwoXyfY=
+github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
+github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
+github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
@@ -153,11 +214,21 @@ github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZy
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
+github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTdwFp0s=
+github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
+github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+go.mau.fi/libsignal v0.2.1 h1:vRZG4EzTn70XY6Oh/pVKrQGuMHBkAWlGRC22/85m9L0=
+go.mau.fi/libsignal v0.2.1/go.mod h1:iVvjrHyfQqWajOUaMEsIfo3IqgVMrhWcPiiEzk7NgoU=
+go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts=
+go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI=
+go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4 h1:hsmlwsM+VqfF70cpdZEeIUKer2XWCQmQPK0u0tHy3ZQ=
+go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4/go.mod h1:mXCRFyPEPn4jqWz6Afirn8vY7DpHCPnlKq6I2cWwFHM=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
@@ -171,10 +242,14 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o=
+golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
+golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -217,8 +292,11 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
@@ -227,6 +305,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
+golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
+golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -234,8 +314,10 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
-golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
+golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
+golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
@@ -243,6 +325,8 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
+golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
+golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -255,6 +339,8 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
+google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
+google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
@@ -269,3 +355,33 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
+modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
+modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
+modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
+modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
+modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
+modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
+modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
+modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
+modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
+modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
+modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
+modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
+modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
+modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
+modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
+modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
+modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
+modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
+modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
+modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
+modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
+rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
+rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
diff --git a/pkg/agent/context.go b/pkg/agent/context.go
index b7c6e1108..3aa903b3f 100644
--- a/pkg/agent/context.go
+++ b/pkg/agent/context.go
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strings"
"sync"
"time"
@@ -33,6 +34,11 @@ type ContextBuilder struct {
// created (didn't exist at cache time, now exist) or deleted (existed at
// cache time, now gone) — both of which should trigger a cache rebuild.
existedAtCache map[string]bool
+
+ // skillFilesAtCache snapshots the skill tree file set and mtimes at cache
+ // build time. This catches nested file creations/deletions/mtime changes
+ // that may not update the top-level skill root directory mtime.
+ skillFilesAtCache map[string]time.Time
}
func getGlobalConfigDir() string {
@@ -46,8 +52,11 @@ func getGlobalConfigDir() string {
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
- wd, _ := os.Getwd()
- builtinSkillsDir := filepath.Join(wd, "skills")
+ builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
+ if builtinSkillsDir == "" {
+ wd, _ := os.Getwd()
+ builtinSkillsDir = filepath.Join(wd, "skills")
+ }
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
return &ContextBuilder{
@@ -147,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
cb.cachedSystemPrompt = prompt
cb.cachedAt = baseline.maxMtime
cb.existedAtCache = baseline.existed
+ cb.skillFilesAtCache = baseline.skillFiles
logger.DebugCF("agent", "System prompt cached",
map[string]any{
@@ -166,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
cb.cachedSystemPrompt = ""
cb.cachedAt = time.Time{}
cb.existedAtCache = nil
+ cb.skillFilesAtCache = nil
logger.DebugCF("agent", "System prompt cache invalidated", nil)
}
-// sourcePaths returns the workspace source file paths tracked for cache
-// invalidation (bootstrap files + memory). The skills directory is handled
-// separately in sourceFilesChangedLocked because it requires both directory-
-// level and recursive file-level mtime checks.
+// sourcePaths returns non-skill workspace source files tracked for cache
+// invalidation (bootstrap files + memory). Skill roots are handled separately
+// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
return []string{
filepath.Join(cb.workspace, "AGENTS.md"),
@@ -184,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
}
}
+// skillRoots returns all skill root directories that can affect
+// BuildSkillsSummary output (workspace/global/builtin).
+func (cb *ContextBuilder) skillRoots() []string {
+ if cb.skillsLoader == nil {
+ return []string{filepath.Join(cb.workspace, "skills")}
+ }
+
+ roots := cb.skillsLoader.SkillRoots()
+ if len(roots) == 0 {
+ return []string{filepath.Join(cb.workspace, "skills")}
+ }
+ return roots
+}
+
// cacheBaseline holds the file existence snapshot and the latest observed
// mtime across all tracked paths. Used as the cache reference point.
type cacheBaseline struct {
- existed map[string]bool
- maxMtime time.Time
+ existed map[string]bool
+ skillFiles map[string]time.Time
+ maxMtime time.Time
}
// buildCacheBaseline records which tracked paths currently exist and computes
// the latest mtime across all tracked files + skills directory contents.
// Called under write lock when the cache is built.
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
- skillsDir := filepath.Join(cb.workspace, "skills")
+ skillRoots := cb.skillRoots()
- // All paths whose existence we track: source files + skills dir.
- allPaths := append(cb.sourcePaths(), skillsDir)
+ // All paths whose existence we track: source files + all skill roots.
+ allPaths := append(cb.sourcePaths(), skillRoots...)
existed := make(map[string]bool, len(allPaths))
+ skillFiles := make(map[string]time.Time)
var maxMtime time.Time
for _, p := range allPaths {
@@ -211,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
}
}
- // Walk skills files to capture their mtimes too.
- // Use os.Stat (not d.Info) to match the stat method used in
- // fileChangedSince / skillFilesModifiedSince for consistency.
- _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
- if walkErr == nil && !d.IsDir() {
- if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
- maxMtime = info.ModTime()
+ // Walk all skill roots recursively to snapshot skill files and mtimes.
+ // Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
+ for _, root := range skillRoots {
+ _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr == nil && !d.IsDir() {
+ if info, err := os.Stat(path); err == nil {
+ skillFiles[path] = info.ModTime()
+ if info.ModTime().After(maxMtime) {
+ maxMtime = info.ModTime()
+ }
+ }
}
- }
- return nil
- })
+ return nil
+ })
+ }
// If no tracked files exist yet (empty workspace), maxMtime is zero.
// Use a very old non-zero time so that:
@@ -233,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
maxMtime = time.Unix(1, 0)
}
- return cacheBaseline{existed: existed, maxMtime: maxMtime}
+ return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
}
// sourceFilesChangedLocked checks whether any workspace source file has been
@@ -249,27 +279,21 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
}
// Check tracked source files (bootstrap + memory).
- for _, p := range cb.sourcePaths() {
- if cb.fileChangedSince(p) {
- return true
- }
- }
-
- // --- Skills directory (handled separately from sourcePaths) ---
- //
- // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
- skillsDir := filepath.Join(cb.workspace, "skills")
- if cb.fileChangedSince(skillsDir) {
+ if slices.ContainsFunc(cb.sourcePaths(), cb.fileChangedSince) {
return true
}
- // 2. Structural changes (add/remove entries inside the dir) are reflected
- // in the directory's own mtime, which fileChangedSince already checks.
+ // --- Skill roots (workspace/global/builtin) ---
//
- // 3. Content-only edits to files inside skills/ do NOT update the parent
- // directory mtime on most filesystems, so we recursively walk to check
- // individual file mtimes at any nesting depth.
- if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
+ // For each root:
+ // 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
+ // 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
+ for _, root := range cb.skillRoots() {
+ if cb.fileChangedSince(root) {
+ return true
+ }
+ }
+ if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
return true
}
@@ -310,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
// if the callback returned nil when its err parameter is non-nil.
var errWalkStop = errors.New("walk stop")
-// skillFilesModifiedSince recursively walks the skills directory and checks
-// whether any file was modified after t. This catches content-only edits at
-// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
-// parent directory mtimes.
-func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
- changed := false
- err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
- if walkErr == nil && !d.IsDir() {
- if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
- changed = true
- return errWalkStop // stop walking
- }
- }
- return nil
- })
- // errWalkStop is expected (early exit on first changed file).
- // os.IsNotExist means the skills dir doesn't exist yet — not an error.
- // Any other error is unexpected and worth logging.
- if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
- logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
+// skillFilesChangedSince compares the current recursive skill file tree
+// against the cache-time snapshot. Any create/delete/mtime drift invalidates
+// the cache.
+func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
+ // Defensive: if the snapshot was never initialized, force rebuild.
+ if filesAtCache == nil {
+ return true
}
- return changed
+
+ // Check cached files still exist and keep the same mtime.
+ for path, cachedMtime := range filesAtCache {
+ info, err := os.Stat(path)
+ if err != nil {
+ // A previously tracked file disappeared (or became inaccessible):
+ // either way, cached skill summary may now be stale.
+ return true
+ }
+ if !info.ModTime().Equal(cachedMtime) {
+ return true
+ }
+ }
+
+ // Check no new files appeared under any skill root.
+ changed := false
+ for _, root := range skillRoots {
+ if strings.TrimSpace(root) == "" {
+ continue
+ }
+
+ err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ // Treat unexpected walk errors as changed to avoid stale cache.
+ if !os.IsNotExist(walkErr) {
+ changed = true
+ return errWalkStop
+ }
+ return nil
+ }
+ if d.IsDir() {
+ return nil
+ }
+ if _, ok := filesAtCache[path]; !ok {
+ changed = true
+ return errWalkStop
+ }
+ return nil
+ })
+
+ if changed {
+ return true
+ }
+ if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
+ logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
+ return true
+ }
+ }
+
+ return false
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
@@ -467,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
// Add current user message
if strings.TrimSpace(currentMessage) != "" {
- messages = append(messages, providers.Message{
+ msg := providers.Message{
Role: "user",
Content: currentMessage,
- })
+ }
+ if len(media) > 0 {
+ msg.Media = media
+ }
+ messages = append(messages, msg)
}
return messages
diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go
index ba70d4c0d..707510820 100644
--- a/pkg/agent/context_cache_test.go
+++ b/pkg/agent/context_cache_test.go
@@ -383,6 +383,162 @@ Updated content.`
}
}
+// TestGlobalSkillFileContentChange verifies that modifying a global skill
+// (~/.picoclaw/skills) invalidates the cached system prompt.
+func TestGlobalSkillFileContentChange(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
+ if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ v1 := `---
+name: global-skill
+description: global-v1
+---
+# Global Skill v1`
+ if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "global-v1") {
+ t.Fatal("expected initial prompt to contain global skill description")
+ }
+
+ v2 := `---
+name: global-skill
+description: global-v2
+---
+# Global Skill v2`
+ if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(globalSkillPath, future, future); err != nil {
+ t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, "global-v2") {
+ t.Error("rebuilt prompt should contain updated global skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when global skill file content changes")
+ }
+}
+
+// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
+// invalidates the cached system prompt.
+func TestBuiltinSkillFileContentChange(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ builtinRoot := t.TempDir()
+ t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
+
+ builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
+ if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ v1 := `---
+name: builtin-skill
+description: builtin-v1
+---
+# Builtin Skill v1`
+ if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "builtin-v1") {
+ t.Fatal("expected initial prompt to contain builtin skill description")
+ }
+
+ v2 := `---
+name: builtin-skill
+description: builtin-v2
+---
+# Builtin Skill v2`
+ if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
+ t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, "builtin-v2") {
+ t.Error("rebuilt prompt should contain updated builtin skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when builtin skill file content changes")
+ }
+}
+
+// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
+// file invalidates the cached system prompt.
+func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "skills/delete-me/SKILL.md": `---
+name: delete-me
+description: delete-me-v1
+---
+# Delete Me`,
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ sp1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp1, "delete-me-v1") {
+ t.Fatal("expected initial prompt to contain skill description")
+ }
+
+ skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
+ if err := os.Remove(skillPath); err != nil {
+ t.Fatal(err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if strings.Contains(sp2, "delete-me-v1") {
+ t.Error("rebuilt prompt should not contain deleted skill description")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should be invalidated when skill file is deleted")
+ }
+}
+
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
// can safely call BuildSystemPromptWithCache concurrently without producing
// empty results, panics, or data races.
@@ -404,11 +560,11 @@ func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
var wg sync.WaitGroup
errs := make(chan string, goroutines*iterations)
- for g := 0; g < goroutines; g++ {
+ for g := range goroutines {
wg.Add(1)
go func(id int) {
defer wg.Done()
- for i := 0; i < iterations; i++ {
+ for i := range iterations {
result := cb.BuildSystemPromptWithCache()
if result == "" {
errs <- "empty prompt returned"
diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go
index a6fd365c7..ed25f537f 100644
--- a/pkg/agent/instance.go
+++ b/pkg/agent/instance.go
@@ -1,8 +1,11 @@
package agent
import (
+ "fmt"
+ "log"
"os"
"path/filepath"
+ "regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
@@ -15,22 +18,24 @@ import (
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
- ID string
- Name string
- Model string
- Fallbacks []string
- Workspace string
- MaxIterations int
- MaxTokens int
- Temperature float64
- ContextWindow int
- Provider providers.LLMProvider
- Sessions *session.SessionManager
- ContextBuilder *ContextBuilder
- Tools *tools.ToolRegistry
- Subagents *config.SubagentsConfig
- SkillsFilter []string
- Candidates []providers.FallbackCandidate
+ ID string
+ Name string
+ Model string
+ Fallbacks []string
+ Workspace string
+ MaxIterations int
+ MaxTokens int
+ Temperature float64
+ ContextWindow int
+ SummarizeMessageThreshold int
+ SummarizeTokenPercent int
+ Provider providers.LLMProvider
+ Sessions *session.SessionManager
+ ContextBuilder *ContextBuilder
+ Tools *tools.ToolRegistry
+ Subagents *config.SubagentsConfig
+ SkillsFilter []string
+ Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
@@ -47,13 +52,24 @@ func NewAgentInstance(
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
restrict := defaults.RestrictToWorkspace
+ readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
+
+ // Compile path whitelist patterns from config.
+ allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
+ allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
+
toolsRegistry := tools.NewToolRegistry()
- toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
- toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
- toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
- toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
+ toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths))
+ toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
+ toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
+ execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
+ if err != nil {
+ log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
+ }
+ toolsRegistry.Register(execTool)
+
+ toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths))
+ toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths))
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
@@ -87,30 +103,82 @@ func NewAgentInstance(
temperature = *defaults.Temperature
}
+ summarizeMessageThreshold := defaults.SummarizeMessageThreshold
+ if summarizeMessageThreshold == 0 {
+ summarizeMessageThreshold = 20
+ }
+
+ summarizeTokenPercent := defaults.SummarizeTokenPercent
+ if summarizeTokenPercent == 0 {
+ summarizeTokenPercent = 75
+ }
+
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
Fallbacks: fallbacks,
}
- candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
+ resolveFromModelList := func(raw string) (string, bool) {
+ ensureProtocol := func(model string) string {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return ""
+ }
+ if strings.Contains(model, "/") {
+ return model
+ }
+ return "openai/" + model
+ }
+
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", false
+ }
+
+ if cfg != nil {
+ if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
+ return ensureProtocol(mc.Model), true
+ }
+
+ for i := range cfg.ModelList {
+ fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
+ if fullModel == "" {
+ continue
+ }
+ if fullModel == raw {
+ return ensureProtocol(fullModel), true
+ }
+ _, modelID := providers.ExtractProtocol(fullModel)
+ if modelID == raw {
+ return ensureProtocol(fullModel), true
+ }
+ }
+ }
+
+ return "", false
+ }
+
+ candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList)
return &AgentInstance{
- ID: agentID,
- Name: agentName,
- Model: model,
- Fallbacks: fallbacks,
- Workspace: workspace,
- MaxIterations: maxIter,
- MaxTokens: maxTokens,
- Temperature: temperature,
- ContextWindow: maxTokens,
- Provider: provider,
- Sessions: sessionsManager,
- ContextBuilder: contextBuilder,
- Tools: toolsRegistry,
- Subagents: subagents,
- SkillsFilter: skillsFilter,
- Candidates: candidates,
+ ID: agentID,
+ Name: agentName,
+ Model: model,
+ Fallbacks: fallbacks,
+ Workspace: workspace,
+ MaxIterations: maxIter,
+ MaxTokens: maxTokens,
+ Temperature: temperature,
+ ContextWindow: maxTokens,
+ SummarizeMessageThreshold: summarizeMessageThreshold,
+ SummarizeTokenPercent: summarizeTokenPercent,
+ Provider: provider,
+ Sessions: sessionsManager,
+ ContextBuilder: contextBuilder,
+ Tools: toolsRegistry,
+ Subagents: subagents,
+ SkillsFilter: skillsFilter,
+ Candidates: candidates,
}
}
@@ -143,6 +211,19 @@ func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentD
return defaults.ModelFallbacks
}
+func compilePatterns(patterns []string) []*regexp.Regexp {
+ compiled := make([]*regexp.Regexp, 0, len(patterns))
+ for _, p := range patterns {
+ re, err := regexp.Compile(p)
+ if err != nil {
+ fmt.Printf("Warning: invalid path pattern %q: %v\n", p, err)
+ continue
+ }
+ compiled = append(compiled, re)
+ }
+ return compiled
+}
+
func expandHome(path string) string {
if path == "" {
return path
diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go
index fcc8e9bea..4f41ecd1c 100644
--- a/pkg/agent/instance_test.go
+++ b/pkg/agent/instance_test.go
@@ -93,3 +93,70 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) {
t.Fatalf("Temperature = %f, want %f", agent.Temperature, 0.7)
}
}
+
+func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
+ tests := []struct {
+ name string
+ aliasName string
+ modelName string
+ apiBase string
+ wantProvider string
+ wantModel string
+ }{
+ {
+ name: "alias with provider prefix",
+ aliasName: "step-3.5-flash",
+ modelName: "openrouter/stepfun/step-3.5-flash:free",
+ apiBase: "https://openrouter.ai/api/v1",
+ wantProvider: "openrouter",
+ wantModel: "stepfun/step-3.5-flash:free",
+ },
+ {
+ name: "alias without provider prefix",
+ aliasName: "glm-5",
+ modelName: "glm-5",
+ apiBase: "https://api.z.ai/api/coding/paas/v4",
+ wantProvider: "openai",
+ wantModel: "glm-5",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-instance-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: tt.aliasName,
+ },
+ },
+ ModelList: []config.ModelConfig{
+ {
+ ModelName: tt.aliasName,
+ Model: tt.modelName,
+ APIBase: tt.apiBase,
+ },
+ },
+ }
+
+ provider := &mockProvider{}
+ agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider)
+
+ if len(agent.Candidates) != 1 {
+ t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates))
+ }
+ if agent.Candidates[0].Provider != tt.wantProvider {
+ t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider)
+ }
+ if agent.Candidates[0].Model != tt.wantModel {
+ t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel)
+ }
+ })
+ }
+}
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 37591fa79..88625bc7b 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -9,7 +9,10 @@ package agent
import (
"context"
"encoding/json"
+ "errors"
"fmt"
+ "path/filepath"
+ "regexp"
"strings"
"sync"
"sync/atomic"
@@ -21,12 +24,15 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/mcp"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
+ "github.com/sipeed/picoclaw/pkg/voice"
)
type AgentLoop struct {
@@ -38,21 +44,30 @@ type AgentLoop struct {
summarizing sync.Map
fallback *providers.FallbackChain
channelManager *channels.Manager
+ mediaStore media.MediaStore
+ transcriber voice.Transcriber
}
// processOptions configures how a message is processed
type processOptions struct {
- SessionKey string // Session identifier for history/context
- Channel string // Target channel for tool execution
- ChatID string // Target chat ID for tool execution
- UserMessage string // User message content (may include prefix)
- DefaultResponse string // Response when LLM returns empty
- EnableSummary bool // Whether to trigger summarization
- SendResponse bool // Whether to send response via bus
- NoHistory bool // If true, don't load session history (for heartbeat)
+ SessionKey string // Session identifier for history/context
+ Channel string // Target channel for tool execution
+ ChatID string // Target chat ID for tool execution
+ UserMessage string // User message content (may include prefix)
+ Media []string // media:// refs from inbound message
+ DefaultResponse string // Response when LLM returns empty
+ EnableSummary bool // Whether to trigger summarization
+ SendResponse bool // Whether to send response via bus
+ NoHistory bool // If true, don't load session history (for heartbeat)
}
-func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
+const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
+
+func NewAgentLoop(
+ cfg *config.Config,
+ msgBus *bus.MessageBus,
+ provider providers.LLMProvider,
+) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
// Register shared tools to all agents
@@ -93,7 +108,7 @@ func registerSharedTools(
}
// Web tools
- if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
+ searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
@@ -109,11 +124,24 @@ func registerSharedTools(
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
+ GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey,
+ GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL,
+ GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine,
+ GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults,
+ GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled,
Proxy: cfg.Tools.Web.Proxy,
- }); searchTool != nil {
+ })
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()})
+ } else if searchTool != nil {
agent.Tools.Register(searchTool)
}
- agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
+ fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes)
+ if err != nil {
+ logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()})
+ } else {
+ agent.Tools.Register(fetchTool)
+ }
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
@@ -122,12 +150,13 @@ func registerSharedTools(
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
- msgBus.PublishOutbound(bus.OutboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
})
- return nil
})
agent.Tools.Register(messageTool)
@@ -158,6 +187,72 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
+ // Initialize MCP servers for all agents
+ if al.cfg.Tools.MCP.Enabled {
+ mcpManager := mcp.NewManager()
+ // Ensure MCP connections are cleaned up on exit, regardless of initialization success
+ // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
+ defer func() {
+ if err := mcpManager.Close(); err != nil {
+ logger.ErrorCF("agent", "Failed to close MCP manager",
+ map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ var workspacePath string
+ if defaultAgent != nil && defaultAgent.Workspace != "" {
+ workspacePath = defaultAgent.Workspace
+ } else {
+ workspacePath = al.cfg.WorkspacePath()
+ }
+
+ if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
+ logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
+ map[string]any{
+ "error": err.Error(),
+ })
+ } else {
+ // Register MCP tools for all agents
+ servers := mcpManager.GetServers()
+ uniqueTools := 0
+ totalRegistrations := 0
+ agentIDs := al.registry.ListAgentIDs()
+ agentCount := len(agentIDs)
+
+ for serverName, conn := range servers {
+ uniqueTools += len(conn.Tools)
+ for _, tool := range conn.Tools {
+ for _, agentID := range agentIDs {
+ agent, ok := al.registry.GetAgent(agentID)
+ if !ok {
+ continue
+ }
+ mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
+ agent.Tools.Register(mcpTool)
+ totalRegistrations++
+ logger.DebugCF("agent", "Registered MCP tool",
+ map[string]any{
+ "agent_id": agentID,
+ "server": serverName,
+ "tool": tool.Name,
+ "name": mcpTool.Name(),
+ })
+ }
+ }
+ }
+ logger.InfoCF("agent", "MCP tools registered successfully",
+ map[string]any{
+ "server_count": len(servers),
+ "unique_tools": uniqueTools,
+ "total_registrations": totalRegistrations,
+ "agent_count": agentCount,
+ })
+ }
+ }
+
for al.running.Load() {
select {
case <-ctx.Done():
@@ -168,33 +263,61 @@ func (al *AgentLoop) Run(ctx context.Context) error {
continue
}
- response, err := al.processMessage(ctx, msg)
- if err != nil {
- response = fmt.Sprintf("Error processing message: %v", err)
- }
+ // Process message
+ func() {
+ // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
+ // Currently disabled because files are deleted before the LLM can access their content.
+ // defer func() {
+ // if al.mediaStore != nil && msg.MediaScope != "" {
+ // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
+ // logger.WarnCF("agent", "Failed to release media", map[string]any{
+ // "scope": msg.MediaScope,
+ // "error": releaseErr.Error(),
+ // })
+ // }
+ // }
+ // }()
- if response != "" {
- // Check if the message tool already sent a response during this round.
- // If so, skip publishing to avoid duplicate messages to the user.
- // Use default agent's tools to check (message tool is shared).
- alreadySent := false
- defaultAgent := al.registry.GetDefaultAgent()
- if defaultAgent != nil {
- if tool, ok := defaultAgent.Tools.Get("message"); ok {
- if mt, ok := tool.(*tools.MessageTool); ok {
- alreadySent = mt.HasSentInRound()
+ response, err := al.processMessage(ctx, msg)
+ if err != nil {
+ response = fmt.Sprintf("Error processing message: %v", err)
+ }
+
+ if response != "" {
+ // Check if the message tool already sent a response during this round.
+ // If so, skip publishing to avoid duplicate messages to the user.
+ // Use default agent's tools to check (message tool is shared).
+ alreadySent := false
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent != nil {
+ if tool, ok := defaultAgent.Tools.Get("message"); ok {
+ if mt, ok := tool.(*tools.MessageTool); ok {
+ alreadySent = mt.HasSentInRound()
+ }
}
}
- }
- if !alreadySent {
- al.bus.PublishOutbound(bus.OutboundMessage{
- Channel: msg.Channel,
- ChatID: msg.ChatID,
- Content: response,
- })
+ if !alreadySent {
+ al.bus.PublishOutbound(ctx, bus.OutboundMessage{
+ Channel: msg.Channel,
+ ChatID: msg.ChatID,
+ Content: response,
+ })
+ logger.InfoCF("agent", "Published outbound response",
+ map[string]any{
+ "channel": msg.Channel,
+ "chat_id": msg.ChatID,
+ "content_len": len(response),
+ })
+ } else {
+ logger.DebugCF(
+ "agent",
+ "Skipped outbound (message tool already sent)",
+ map[string]any{"channel": msg.Channel},
+ )
+ }
}
- }
+ }()
}
}
@@ -217,6 +340,99 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
+// SetMediaStore injects a MediaStore for media lifecycle management.
+func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
+ al.mediaStore = s
+}
+
+// SetTranscriber injects a voice transcriber for agent-level audio transcription.
+func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
+ al.transcriber = t
+}
+
+var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`)
+
+// transcribeAudioInMessage resolves audio media refs, transcribes them, and
+// replaces audio annotations in msg.Content with the transcribed text.
+func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) bus.InboundMessage {
+ if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 {
+ return msg
+ }
+
+ // Transcribe each audio media ref in order.
+ var transcriptions []string
+ for _, ref := range msg.Media {
+ path, meta, err := al.mediaStore.ResolveWithMeta(ref)
+ if err != nil {
+ logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
+ continue
+ }
+ if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
+ continue
+ }
+ result, err := al.transcriber.Transcribe(ctx, path)
+ if err != nil {
+ logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
+ transcriptions = append(transcriptions, "")
+ continue
+ }
+ transcriptions = append(transcriptions, result.Text)
+ }
+
+ if len(transcriptions) == 0 {
+ return msg
+ }
+
+ // Replace audio annotations sequentially with transcriptions.
+ idx := 0
+ newContent := audioAnnotationRe.ReplaceAllStringFunc(msg.Content, func(match string) string {
+ if idx >= len(transcriptions) {
+ return match
+ }
+ text := transcriptions[idx]
+ idx++
+ return "[voice: " + text + "]"
+ })
+
+ // Append any remaining transcriptions not matched by an annotation.
+ for ; idx < len(transcriptions); idx++ {
+ newContent += "\n[voice: " + transcriptions[idx] + "]"
+ }
+
+ msg.Content = newContent
+ return msg
+}
+
+// inferMediaType determines the media type ("image", "audio", "video", "file")
+// from a filename and MIME content type.
+func inferMediaType(filename, contentType string) string {
+ ct := strings.ToLower(contentType)
+ fn := strings.ToLower(filename)
+
+ if strings.HasPrefix(ct, "image/") {
+ return "image"
+ }
+ if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" {
+ return "audio"
+ }
+ if strings.HasPrefix(ct, "video/") {
+ return "video"
+ }
+
+ // Fallback: infer from extension
+ ext := filepath.Ext(fn)
+ switch ext {
+ case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg":
+ return "image"
+ case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus":
+ return "audio"
+ case ".mp4", ".avi", ".mov", ".webm", ".mkv":
+ return "video"
+ }
+
+ return "file"
+}
+
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
@@ -235,7 +451,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
-func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
+func (al *AgentLoop) ProcessDirect(
+ ctx context.Context,
+ content, sessionKey string,
+) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -256,14 +475,20 @@ func (al *AgentLoop) ProcessDirectWithChannel(
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
-func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
+func (al *AgentLoop) ProcessHeartbeat(
+ ctx context.Context,
+ content, channel, chatID string,
+) (string, error) {
agent := al.registry.GetDefaultAgent()
+ if agent == nil {
+ return "", fmt.Errorf("no default agent for heartbeat")
+ }
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
- DefaultResponse: "I've completed processing but have no response to give.",
+ DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
@@ -278,13 +503,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
} else {
logContent = utils.Truncate(msg.Content, 80)
}
- logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
+ logger.InfoCF(
+ "agent",
+ fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"sender_id": msg.SenderID,
"session_key": msg.SessionKey,
- })
+ },
+ )
+
+ msg = al.transcribeAudioInMessage(ctx, msg)
// Route system messages to processSystemMessage
if msg.Channel == "system" {
@@ -310,6 +540,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
if !ok {
agent = al.registry.GetDefaultAgent()
}
+ if agent == nil {
+ return "", fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
+ }
+
+ // Reset message-tool state for this round so we don't skip publishing due to a previous round.
+ if tool, ok := agent.Tools.Get("message"); ok {
+ if mt, ok := tool.(tools.ContextualTool); ok {
+ mt.SetContext(msg.Channel, msg.ChatID)
+ }
+ }
// Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron)
sessionKey := route.SessionKey
@@ -329,15 +569,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
- DefaultResponse: "I've completed processing but have no response to give.",
+ Media: msg.Media,
+ DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
})
}
-func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
+func (al *AgentLoop) processSystemMessage(
+ ctx context.Context,
+ msg bus.InboundMessage,
+) (string, error) {
if msg.Channel != "system" {
- return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
+ return "", fmt.Errorf(
+ "processSystemMessage called with non-system message channel: %s",
+ msg.Channel,
+ )
}
logger.InfoCF("agent", "Processing system message",
@@ -376,6 +623,9 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
// Use default agent for system messages
agent := al.registry.GetDefaultAgent()
+ if agent == nil {
+ return "", fmt.Errorf("no default agent for system message")
+ }
// Use the origin session for context
sessionKey := routing.BuildAgentMainSessionKey(agent.ID)
@@ -392,14 +642,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
// runAgentLoop is the core message processing logic.
-func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
+func (al *AgentLoop) runAgentLoop(
+ ctx context.Context,
+ agent *AgentInstance,
+ opts processOptions,
+) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
- logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
+ logger.WarnCF(
+ "agent",
+ "Failed to record last channel",
+ map[string]any{"error": err.Error()},
+ )
}
}
}
@@ -418,11 +676,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
history,
summary,
opts.UserMessage,
- nil,
+ opts.Media,
opts.Channel,
opts.ChatID,
)
+ // Resolve media:// refs to base64 data URLs (streaming)
+ maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+
// 3. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
@@ -451,7 +713,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
// 8. Optional: send response via bus
if opts.SendResponse {
- al.bus.PublishOutbound(bus.OutboundMessage{
+ al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: finalContent,
@@ -471,6 +733,62 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
return finalContent, nil
}
+func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) {
+ if al.channelManager == nil {
+ return ""
+ }
+ if ch, ok := al.channelManager.GetChannel(channelName); ok {
+ return ch.ReasoningChannelID()
+ }
+ return ""
+}
+
+func (al *AgentLoop) handleReasoning(
+ ctx context.Context,
+ reasoningContent, channelName, channelID string,
+) {
+ if reasoningContent == "" || channelName == "" || channelID == "" {
+ return
+ }
+
+ // Check context cancellation before attempting to publish,
+ // since PublishOutbound's select may race between send and ctx.Done().
+ if ctx.Err() != nil {
+ return
+ }
+
+ // Use a short timeout so the goroutine does not block indefinitely when
+ // the outbound bus is full. Reasoning output is best-effort; dropping it
+ // is acceptable to avoid goroutine accumulation.
+ pubCtx, pubCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer pubCancel()
+
+ if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
+ Channel: channelName,
+ ChatID: channelID,
+ Content: reasoningContent,
+ }); err != nil {
+ // Treat context.DeadlineExceeded / context.Canceled as expected
+ // (bus full under load, or parent canceled). Check the error
+ // itself rather than ctx.Err(), because pubCtx may time out
+ // (5 s) while the parent ctx is still active.
+ // Also treat ErrBusClosed as expected — it occurs during normal
+ // shutdown when the bus is closed before all goroutines finish.
+ if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) ||
+ errors.Is(err, bus.ErrBusClosed) {
+ logger.DebugCF("agent", "Reasoning publish skipped (timeout/cancel)", map[string]any{
+ "channel": channelName,
+ "error": err.Error(),
+ })
+ } else {
+ logger.WarnCF("agent", "Failed to publish reasoning (best-effort)", map[string]any{
+ "channel": channelName,
+ "error": err.Error(),
+ })
+ }
+ }
+}
+
// runLLMIteration executes the LLM call loop with tool handling.
func (al *AgentLoop) runLLMIteration(
ctx context.Context,
@@ -521,22 +839,33 @@ func (al *AgentLoop) runLLMIteration(
callLLM := func() (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
- fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
+ fbResult, fbErr := al.fallback.Execute(
+ ctx,
+ agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
- return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
- "prompt_cache_key": agent.ID,
- })
+ return agent.Provider.Chat(
+ ctx,
+ messages,
+ providerToolDefs,
+ model,
+ map[string]any{
+ "max_tokens": agent.MaxTokens,
+ "temperature": agent.Temperature,
+ "prompt_cache_key": agent.ID,
+ },
+ )
},
)
if fbErr != nil {
return nil, fbErr
}
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
- logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
- fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
- map[string]any{"agent_id": agent.ID, "iteration": iteration})
+ logger.InfoCF(
+ "agent",
+ fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
+ fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
+ map[string]any{"agent_id": agent.ID, "iteration": iteration},
+ )
}
return fbResult.Response, nil
}
@@ -556,19 +885,48 @@ func (al *AgentLoop) runLLMIteration(
}
errMsg := strings.ToLower(err.Error())
- isContextError := strings.Contains(errMsg, "token") ||
- strings.Contains(errMsg, "context") ||
+
+ // Check if this is a network/HTTP timeout — not a context window error.
+ isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
+ strings.Contains(errMsg, "deadline exceeded") ||
+ strings.Contains(errMsg, "client.timeout") ||
+ strings.Contains(errMsg, "timed out") ||
+ strings.Contains(errMsg, "timeout exceeded")
+
+ // Detect real context window / token limit errors, excluding network timeouts.
+ isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
+ strings.Contains(errMsg, "context window") ||
+ strings.Contains(errMsg, "maximum context length") ||
+ strings.Contains(errMsg, "token limit") ||
+ strings.Contains(errMsg, "too many tokens") ||
+ strings.Contains(errMsg, "max_tokens") ||
strings.Contains(errMsg, "invalidparameter") ||
- strings.Contains(errMsg, "length")
+ strings.Contains(errMsg, "prompt is too long") ||
+ strings.Contains(errMsg, "request too large"))
+
+ if isTimeoutError && retry < maxRetries {
+ backoff := time.Duration(retry+1) * 5 * time.Second
+ logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
+ "error": err.Error(),
+ "retry": retry,
+ "backoff": backoff.String(),
+ })
+ time.Sleep(backoff)
+ continue
+ }
if isContextError && retry < maxRetries {
- logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
- "error": err.Error(),
- "retry": retry,
- })
+ logger.WarnCF(
+ "agent",
+ "Context window error detected, attempting compression",
+ map[string]any{
+ "error": err.Error(),
+ "retry": retry,
+ },
+ )
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
- al.bus.PublishOutbound(bus.OutboundMessage{
+ al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "Context window exceeded. Compressing history and retrying...",
@@ -597,6 +955,23 @@ func (al *AgentLoop) runLLMIteration(
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
+ go al.handleReasoning(
+ ctx,
+ response.Reasoning,
+ opts.Channel,
+ al.targetReasoningChannelID(opts.Channel),
+ )
+
+ logger.DebugCF("agent", "LLM response",
+ map[string]any{
+ "agent_id": agent.ID,
+ "iteration": iteration,
+ "content_chars": len(response.Content),
+ "tool_calls": len(response.ToolCalls),
+ "reasoning": response.Reasoning,
+ "target_channel": al.targetReasoningChannelID(opts.Channel),
+ "channel": opts.Channel,
+ })
// Check if no tool calls - we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
@@ -660,66 +1035,102 @@ func (al *AgentLoop) runLLMIteration(
// Save assistant message with tool calls to session
agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
- // Execute tool calls
- for _, tc := range normalizedToolCalls {
- argsJSON, _ := json.Marshal(tc.Arguments)
- argsPreview := utils.Truncate(string(argsJSON), 200)
- logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
- map[string]any{
- "agent_id": agent.ID,
- "tool": tc.Name,
- "iteration": iteration,
- })
+ // Execute tool calls in parallel
+ type indexedAgentResult struct {
+ result *tools.ToolResult
+ tc providers.ToolCall
+ }
- // Create async callback for tools that implement AsyncTool
- // NOTE: Following openclaw's design, async tools do NOT send results directly to users.
- // Instead, they notify the agent via PublishInbound, and the agent decides
- // whether to forward the result to the user (in processSystemMessage).
- asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
- // Log the async completion but don't send directly to user
- // The agent will handle user notification via processSystemMessage
- if !result.Silent && result.ForUser != "" {
- logger.InfoCF("agent", "Async tool completed, agent will handle notification",
- map[string]any{
- "tool": tc.Name,
- "content_len": len(result.ForUser),
- })
+ agentResults := make([]indexedAgentResult, len(normalizedToolCalls))
+ var wg sync.WaitGroup
+
+ for i, tc := range normalizedToolCalls {
+ agentResults[i].tc = tc
+
+ wg.Add(1)
+ go func(idx int, tc providers.ToolCall) {
+ defer wg.Done()
+
+ argsJSON, _ := json.Marshal(tc.Arguments)
+ argsPreview := utils.Truncate(string(argsJSON), 200)
+ logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
+ map[string]any{
+ "agent_id": agent.ID,
+ "tool": tc.Name,
+ "iteration": iteration,
+ })
+
+ // Create async callback for tools that implement AsyncTool
+ asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
+ if !result.Silent && result.ForUser != "" {
+ logger.InfoCF("agent", "Async tool completed, agent will handle notification",
+ map[string]any{
+ "tool": tc.Name,
+ "content_len": len(result.ForUser),
+ })
+ }
}
- }
- toolResult := agent.Tools.ExecuteWithContext(
- ctx,
- tc.Name,
- tc.Arguments,
- opts.Channel,
- opts.ChatID,
- asyncCallback,
- )
+ toolResult := agent.Tools.ExecuteWithContext(
+ ctx,
+ tc.Name,
+ tc.Arguments,
+ opts.Channel,
+ opts.ChatID,
+ asyncCallback,
+ )
+ agentResults[idx].result = toolResult
+ }(i, tc)
+ }
+ wg.Wait()
+ // Process results in original order (send to user, save to session)
+ for _, r := range agentResults {
// Send ForUser content to user immediately if not Silent
- if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
- al.bus.PublishOutbound(bus.OutboundMessage{
+ if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse {
+ al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
- Content: toolResult.ForUser,
+ Content: r.result.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
- "tool": tc.Name,
- "content_len": len(toolResult.ForUser),
+ "tool": r.tc.Name,
+ "content_len": len(r.result.ForUser),
})
}
+ // If tool returned media refs, publish them as outbound media
+ if len(r.result.Media) > 0 && opts.SendResponse {
+ parts := make([]bus.MediaPart, 0, len(r.result.Media))
+ for _, ref := range r.result.Media {
+ part := bus.MediaPart{Ref: ref}
+ if al.mediaStore != nil {
+ if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
+ part.Filename = meta.Filename
+ part.ContentType = meta.ContentType
+ part.Type = inferMediaType(meta.Filename, meta.ContentType)
+ }
+ }
+ parts = append(parts, part)
+ }
+ al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
+ Channel: opts.Channel,
+ ChatID: opts.ChatID,
+ Parts: parts,
+ })
+ }
+
// Determine content for LLM based on tool result
- contentForLLM := toolResult.ForLLM
- if contentForLLM == "" && toolResult.Err != nil {
- contentForLLM = toolResult.Err.Error()
+ contentForLLM := r.result.ForLLM
+ if contentForLLM == "" && r.result.Err != nil {
+ contentForLLM = r.result.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
- ToolCallID: tc.ID,
+ ToolCallID: r.tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -755,9 +1166,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
- threshold := agent.ContextWindow * 75 / 100
+ threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
- if len(newHistory) > 20 || tokenEstimate > threshold {
+ if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
summarizeKey := agent.ID + ":" + sessionKey
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
@@ -865,7 +1276,11 @@ func formatMessagesForLog(messages []providers.Message) string {
for _, tc := range msg.ToolCalls {
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
- fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
+ fmt.Fprintf(
+ &sb,
+ " Arguments: %s\n",
+ utils.Truncate(tc.Function.Arguments, 200),
+ )
}
}
}
@@ -894,7 +1309,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
- fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
+ fmt.Fprintf(
+ &sb,
+ " Parameters: %s\n",
+ utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
+ )
}
}
sb.WriteString("]")
@@ -991,7 +1410,9 @@ func (al *AgentLoop) summarizeBatch(
existingSummary string,
) (string, error) {
var sb strings.Builder
- sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
+ sb.WriteString(
+ "Provide a concise summary of this conversation segment, preserving core context and key points.\n",
+ )
if existingSummary != "" {
sb.WriteString("Existing context: ")
sb.WriteString(existingSummary)
@@ -1122,21 +1543,20 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage)
return "", false
}
-// extractPeer extracts the routing peer from inbound message metadata.
+// extractPeer extracts the routing peer from the inbound message's structured Peer field.
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
- peerKind := msg.Metadata["peer_kind"]
- if peerKind == "" {
+ if msg.Peer.Kind == "" {
return nil
}
- peerID := msg.Metadata["peer_id"]
+ peerID := msg.Peer.ID
if peerID == "" {
- if peerKind == "direct" {
+ if msg.Peer.Kind == "direct" {
peerID = msg.SenderID
} else {
peerID = msg.ChatID
}
}
- return &routing.RoutePeer{Kind: peerKind, ID: peerID}
+ return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
}
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go
new file mode 100644
index 000000000..82547a008
--- /dev/null
+++ b/pkg/agent/loop_media.go
@@ -0,0 +1,122 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package agent
+
+import (
+ "bytes"
+ "encoding/base64"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/h2non/filetype"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
+// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
+// both raw bytes and encoded string in memory simultaneously.
+// Returns a new slice; original messages are not mutated.
+func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
+ if store == nil {
+ return messages
+ }
+
+ result := make([]providers.Message, len(messages))
+ copy(result, messages)
+
+ for i, m := range result {
+ if len(m.Media) == 0 {
+ continue
+ }
+
+ resolved := make([]string, 0, len(m.Media))
+ for _, ref := range m.Media {
+ if !strings.HasPrefix(ref, "media://") {
+ resolved = append(resolved, ref)
+ continue
+ }
+
+ localPath, meta, err := store.ResolveWithMeta(ref)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
+ "ref": ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ info, err := os.Stat(localPath)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to stat media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+ if info.Size() > int64(maxSize) {
+ logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
+ "path": localPath,
+ "size": info.Size(),
+ "max_size": maxSize,
+ })
+ continue
+ }
+
+ // Determine MIME type: prefer metadata, fallback to magic-bytes detection
+ mime := meta.ContentType
+ if mime == "" {
+ kind, ftErr := filetype.MatchFile(localPath)
+ if ftErr != nil || kind == filetype.Unknown {
+ logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
+ "path": localPath,
+ })
+ continue
+ }
+ mime = kind.MIME.Value
+ }
+
+ // Streaming base64: open file → base64 encoder → buffer
+ // Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
+ f, err := os.Open(localPath)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ prefix := "data:" + mime + ";base64,"
+ encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
+ var buf bytes.Buffer
+ buf.Grow(len(prefix) + encodedLen)
+ buf.WriteString(prefix)
+
+ encoder := base64.NewEncoder(base64.StdEncoding, &buf)
+ if _, err := io.Copy(encoder, f); err != nil {
+ f.Close()
+ logger.WarnCF("agent", "Failed to encode media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+ encoder.Close()
+ f.Close()
+
+ resolved = append(resolved, buf.String())
+ }
+
+ result[i].Media = resolved
+ }
+
+ return result
+}
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index 4414398b1..023286f02 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -5,25 +5,39 @@ import (
"fmt"
"os"
"path/filepath"
+ "slices"
+ "strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
-func TestRecordLastChannel(t *testing.T) {
- // Create temp workspace
+type fakeChannel struct{ id string }
+
+func (f *fakeChannel) Name() string { return "fake" }
+func (f *fakeChannel) Start(ctx context.Context) error { return nil }
+func (f *fakeChannel) Stop(ctx context.Context) error { return nil }
+func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil }
+func (f *fakeChannel) IsRunning() bool { return true }
+func (f *fakeChannel) IsAllowed(string) bool { return true }
+func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true }
+func (f *fakeChannel) ReasoningChannelID() string { return f.id }
+
+func newTestAgentLoop(
+ t *testing.T,
+) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) {
+ t.Helper()
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
- defer os.RemoveAll(tmpDir)
-
- // Create test config
- cfg := &config.Config{
+ cfg = &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
@@ -33,74 +47,43 @@ func TestRecordLastChannel(t *testing.T) {
},
},
}
+ msgBus = bus.NewMessageBus()
+ provider = &mockProvider{}
+ al = NewAgentLoop(cfg, msgBus, provider)
+ return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) }
+}
- // Create agent loop
- msgBus := bus.NewMessageBus()
- provider := &mockProvider{}
- al := NewAgentLoop(cfg, msgBus, provider)
+func TestRecordLastChannel(t *testing.T) {
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
- // Test RecordLastChannel
testChannel := "test-channel"
- err = al.RecordLastChannel(testChannel)
- if err != nil {
+ if err := al.RecordLastChannel(testChannel); err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
-
- // Verify channel was saved
- lastChannel := al.state.GetLastChannel()
- if lastChannel != testChannel {
- t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
+ if got := al.state.GetLastChannel(); got != testChannel {
+ t.Errorf("Expected channel '%s', got '%s'", testChannel, got)
}
-
- // Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
- if al2.state.GetLastChannel() != testChannel {
- t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
+ if got := al2.state.GetLastChannel(); got != testChannel {
+ t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got)
}
}
func TestRecordLastChatID(t *testing.T) {
- // Create temp workspace
- tmpDir, err := os.MkdirTemp("", "agent-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
+ al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
+ defer cleanup()
- // Create test config
- cfg := &config.Config{
- Agents: config.AgentsConfig{
- Defaults: config.AgentDefaults{
- Workspace: tmpDir,
- Model: "test-model",
- MaxTokens: 4096,
- MaxToolIterations: 10,
- },
- },
- }
-
- // Create agent loop
- msgBus := bus.NewMessageBus()
- provider := &mockProvider{}
- al := NewAgentLoop(cfg, msgBus, provider)
-
- // Test RecordLastChatID
testChatID := "test-chat-id-123"
- err = al.RecordLastChatID(testChatID)
- if err != nil {
+ if err := al.RecordLastChatID(testChatID); err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
-
- // Verify chat ID was saved
- lastChatID := al.state.GetLastChatID()
- if lastChatID != testChatID {
- t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
+ if got := al.state.GetLastChatID(); got != testChatID {
+ t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got)
}
-
- // Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
- if al2.state.GetLastChatID() != testChatID {
- t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
+ if got := al2.state.GetLastChatID(); got != testChatID {
+ t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got)
}
}
@@ -175,13 +158,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
- found := false
- for _, name := range toolsList {
- if name == "mock_custom" {
- found = true
- break
- }
- }
+ found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -250,13 +227,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
- found := false
- for _, name := range toolsList {
- if name == "mock_custom" {
- found = true
- break
- }
- }
+ found := slices.Contains(toolsList, "mock_custom")
if !found {
t.Error("Expected custom tool to be registered")
}
@@ -631,3 +602,350 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
}
}
+
+func TestTargetReasoningChannelID_AllChannels(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{})
+ chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil)
+ if err != nil {
+ t.Fatalf("Failed to create channel manager: %v", err)
+ }
+ for name, id := range map[string]string{
+ "whatsapp": "rid-whatsapp",
+ "telegram": "rid-telegram",
+ "feishu": "rid-feishu",
+ "discord": "rid-discord",
+ "maixcam": "rid-maixcam",
+ "qq": "rid-qq",
+ "dingtalk": "rid-dingtalk",
+ "slack": "rid-slack",
+ "line": "rid-line",
+ "onebot": "rid-onebot",
+ "wecom": "rid-wecom",
+ "wecom_app": "rid-wecom-app",
+ } {
+ chManager.RegisterChannel(name, &fakeChannel{id: id})
+ }
+ al.SetChannelManager(chManager)
+ tests := []struct {
+ channel string
+ wantID string
+ }{
+ {channel: "whatsapp", wantID: "rid-whatsapp"},
+ {channel: "telegram", wantID: "rid-telegram"},
+ {channel: "feishu", wantID: "rid-feishu"},
+ {channel: "discord", wantID: "rid-discord"},
+ {channel: "maixcam", wantID: "rid-maixcam"},
+ {channel: "qq", wantID: "rid-qq"},
+ {channel: "dingtalk", wantID: "rid-dingtalk"},
+ {channel: "slack", wantID: "rid-slack"},
+ {channel: "line", wantID: "rid-line"},
+ {channel: "onebot", wantID: "rid-onebot"},
+ {channel: "wecom", wantID: "rid-wecom"},
+ {channel: "wecom_app", wantID: "rid-wecom-app"},
+ {channel: "unknown", wantID: ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.channel, func(t *testing.T) {
+ got := al.targetReasoningChannelID(tt.channel)
+ if got != tt.wantID {
+ t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID)
+ }
+ })
+ }
+}
+
+func TestHandleReasoning(t *testing.T) {
+ newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) {
+ t.Helper()
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+ msgBus := bus.NewMessageBus()
+ return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus
+ }
+
+ t.Run("skips when any required field is empty", func(t *testing.T) {
+ al, msgBus := newLoop(t)
+ al.handleReasoning(context.Background(), "reasoning", "telegram", "")
+
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+ if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
+ t.Fatalf("expected no outbound message, got %+v", msg)
+ }
+ })
+
+ t.Run("publishes one message for non telegram", func(t *testing.T) {
+ al, msgBus := newLoop(t)
+ al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
+
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+ msg, ok := msgBus.SubscribeOutbound(ctx)
+ if !ok {
+ t.Fatal("expected an outbound message")
+ }
+ if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" {
+ t.Fatalf("unexpected outbound message: %+v", msg)
+ }
+ })
+
+ t.Run("publishes one message for telegram", func(t *testing.T) {
+ al, msgBus := newLoop(t)
+ reasoning := "hello telegram reasoning"
+ al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
+
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+ msg, ok := msgBus.SubscribeOutbound(ctx)
+ if !ok {
+ t.Fatal("expected outbound message")
+ }
+
+ if msg.Channel != "telegram" {
+ t.Fatalf("expected telegram channel message, got %+v", msg)
+ }
+ if msg.ChatID != "tg-chat" {
+ t.Fatalf("expected chatID tg-chat, got %+v", msg)
+ }
+ if msg.Content != reasoning {
+ t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
+ }
+ })
+ t.Run("expired ctx", func(t *testing.T) {
+ al, msgBus := newLoop(t)
+ reasoning := "hello telegram reasoning"
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
+
+ ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+ msg, ok := msgBus.SubscribeOutbound(ctx)
+ if ok {
+ t.Fatalf("expected no outbound message, got %+v", msg)
+ }
+ })
+
+ t.Run("returns promptly when bus is full", func(t *testing.T) {
+ al, msgBus := newLoop(t)
+
+ // Fill the outbound bus buffer until a publish would block.
+ // Use a short timeout to detect when the buffer is full,
+ // rather than hardcoding the buffer size.
+ for i := 0; ; i++ {
+ fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
+ Channel: "filler",
+ ChatID: "filler",
+ Content: fmt.Sprintf("filler-%d", i),
+ })
+ fillCancel()
+ if err != nil {
+ // Buffer is full (timed out trying to send).
+ break
+ }
+ }
+
+ // Use a short-deadline parent context to bound the test.
+ ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
+ defer cancel()
+
+ start := time.Now()
+ al.handleReasoning(ctx, "should timeout", "slack", "channel-full")
+ elapsed := time.Since(start)
+
+ // handleReasoning uses a 5s internal timeout, but the parent ctx
+ // expires in 500ms. It should return within ~500ms, not 5s.
+ if elapsed > 2*time.Second {
+ t.Fatalf("handleReasoning blocked too long (%v); expected prompt return", elapsed)
+ }
+
+ // Drain the bus and verify the reasoning message was NOT published
+ // (it should have been dropped due to timeout).
+ drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer drainCancel()
+ foundReasoning := false
+ for {
+ msg, ok := msgBus.SubscribeOutbound(drainCtx)
+ if !ok {
+ break
+ }
+ if msg.Content == "should timeout" {
+ foundReasoning = true
+ }
+ }
+ if foundReasoning {
+ t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
+ }
+ })
+}
+
+func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ // Create a minimal valid PNG (8-byte header is enough for filetype detection)
+ pngPath := filepath.Join(dir, "test.png")
+ // PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
+ pngHeader := []byte{
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
+ 0x00, 0x00, 0x00, 0x0D, // IHDR length
+ 0x49, 0x48, 0x44, 0x52, // "IHDR"
+ 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
+ 0x00, 0x00, 0x00, // no interlace
+ 0x90, 0x77, 0x53, 0xDE, // CRC
+ }
+ if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ messages := []providers.Message{
+ {Role: "user", Content: "describe this", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 {
+ t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
+ }
+ if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
+ t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
+ }
+}
+
+func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ bigPath := filepath.Join(dir, "big.png")
+ // Write PNG header + padding to exceed limit
+ data := make([]byte, 1024+1) // 1KB + 1 byte
+ copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
+ if err := os.WriteFile(bigPath, data, 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ // Use a tiny limit (1KB) so the file is oversized
+ result := resolveMediaRefs(messages, store, 1024)
+
+ if len(result[0].Media) != 0 {
+ t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
+ }
+}
+
+func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ txtPath := filepath.Join(dir, "readme.txt")
+ if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 0 {
+ t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
+ }
+}
+
+func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
+ }
+ result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
+ t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
+ }
+}
+
+func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+ pngPath := filepath.Join(dir, "test.png")
+ pngHeader := []byte{
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
+ 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
+ 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
+ 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
+ }
+ os.WriteFile(pngPath, pngHeader, 0o644)
+ ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
+
+ original := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ originalRef := original[0].Media[0]
+
+ resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
+
+ if original[0].Media[0] != originalRef {
+ t.Fatal("resolveMediaRefs mutated original message slice")
+ }
+}
+
+func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
+ store := media.NewFileMediaStore()
+ dir := t.TempDir()
+
+ // File with JPEG content but stored with explicit content type
+ jpegPath := filepath.Join(dir, "photo")
+ jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
+ os.WriteFile(jpegPath, jpegHeader, 0o644)
+ ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
+
+ messages := []providers.Message{
+ {Role: "user", Content: "hi", Media: []string{ref}},
+ }
+ result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
+
+ if len(result[0].Media) != 1 {
+ t.Fatalf("expected 1 media, got %d", len(result[0].Media))
+ }
+ if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
+ t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
+ }
+}
diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go
index dd5f4441c..01e682f3b 100644
--- a/pkg/agent/memory.go
+++ b/pkg/agent/memory.go
@@ -12,6 +12,8 @@ import (
"path/filepath"
"strings"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
// MemoryStore manages persistent memory for the agent.
@@ -58,7 +60,9 @@ func (ms *MemoryStore) ReadLongTerm() string {
// WriteLongTerm writes content to the long-term memory file (MEMORY.md).
func (ms *MemoryStore) WriteLongTerm(content string) error {
- return os.WriteFile(ms.memoryFile, []byte(content), 0o644)
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ // Using 0o600 (owner read/write only) for secure default permissions.
+ return fileutil.WriteFileAtomic(ms.memoryFile, []byte(content), 0o600)
}
// ReadToday reads today's daily note.
@@ -78,7 +82,9 @@ func (ms *MemoryStore) AppendToday(content string) error {
// Ensure month directory exists
monthDir := filepath.Dir(todayFile)
- os.MkdirAll(monthDir, 0o755)
+ if err := os.MkdirAll(monthDir, 0o755); err != nil {
+ return err
+ }
var existingContent string
if data, err := os.ReadFile(todayFile); err == nil {
@@ -95,7 +101,8 @@ func (ms *MemoryStore) AppendToday(content string) error {
newContent = existingContent + "\n" + content
}
- return os.WriteFile(todayFile, []byte(newContent), 0o644)
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ return fileutil.WriteFileAtomic(todayFile, []byte(newContent), 0o600)
}
// GetRecentDailyNotes returns daily notes from the last N days.
@@ -104,7 +111,7 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
var sb strings.Builder
first := true
- for i := 0; i < days; i++ {
+ for i := range days {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go
index ba757ffd4..91c9e25c5 100644
--- a/pkg/auth/oauth.go
+++ b/pkg/auth/oauth.go
@@ -66,7 +66,8 @@ func decodeBase64(s string) string {
return string(data)
}
-func generateState() (string, error) {
+// GenerateState generates a random state string for OAuth CSRF protection.
+func GenerateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
@@ -80,7 +81,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
- state, err := generateState()
+ state, err := GenerateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
@@ -127,7 +128,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
- if err := openBrowser(authURL); err != nil {
+ if err := OpenBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
@@ -153,7 +154,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
if result.err != nil {
return nil, result.err
}
- return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
+ return ExchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case manualInput := <-manualCh:
if manualInput == "" {
return nil, fmt.Errorf("manual input canceled")
@@ -169,7 +170,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
if code == "" {
return nil, fmt.Errorf("could not find authorization code in input")
}
- return exchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
+ return ExchangeCodeForTokens(cfg, code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
@@ -186,6 +187,59 @@ type deviceCodeResponse struct {
Interval int
}
+// DeviceCodeInfo holds the device code information returned by the OAuth provider.
+type DeviceCodeInfo struct {
+ DeviceAuthID string `json:"device_auth_id"`
+ UserCode string `json:"user_code"`
+ VerifyURL string `json:"verify_url"`
+ Interval int `json:"interval"`
+}
+
+// RequestDeviceCode requests a device code from the OAuth provider.
+// Returns the info needed for the user to authenticate in a browser.
+func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) {
+ reqBody, _ := json.Marshal(map[string]string{
+ "client_id": cfg.ClientID,
+ })
+
+ resp, err := http.Post(
+ cfg.Issuer+"/api/accounts/deviceauth/usercode",
+ "application/json",
+ strings.NewReader(string(reqBody)),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("requesting device code: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("device code request failed: %s", string(body))
+ }
+
+ deviceResp, err := parseDeviceCodeResponse(body)
+ if err != nil {
+ return nil, fmt.Errorf("parsing device code response: %w", err)
+ }
+
+ if deviceResp.Interval < 1 {
+ deviceResp.Interval = 5
+ }
+
+ return &DeviceCodeInfo{
+ DeviceAuthID: deviceResp.DeviceAuthID,
+ UserCode: deviceResp.UserCode,
+ VerifyURL: cfg.Issuer + "/codex/device",
+ Interval: deviceResp.Interval,
+ }, nil
+}
+
+// PollDeviceCodeOnce makes a single poll attempt to check if the user has authenticated.
+// Returns (credential, nil) on success, (nil, nil) if still pending, or (nil, err) on failure.
+func PollDeviceCodeOnce(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
+ return pollDeviceCode(cfg, deviceAuthID, userCode)
+}
+
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
var raw struct {
DeviceAuthID string `json:"device_auth_id"`
@@ -318,7 +372,7 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
- return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
+ return ExchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
@@ -410,7 +464,8 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
}
-func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
+// ExchangeCodeForTokens exchanges an authorization code for tokens.
+func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
@@ -552,7 +607,8 @@ func base64URLDecode(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}
-func openBrowser(url string) error {
+// OpenBrowser opens the given URL in the user's default browser.
+func OpenBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go
index 0cb589069..230ac7c2a 100644
--- a/pkg/auth/oauth_test.go
+++ b/pkg/auth/oauth_test.go
@@ -219,9 +219,9 @@ func TestExchangeCodeForTokens(t *testing.T) {
Port: 1455,
}
- cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
+ cred, err := ExchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
- t.Fatalf("exchangeCodeForTokens() error: %v", err)
+ t.Fatalf("ExchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
diff --git a/pkg/auth/store.go b/pkg/auth/store.go
index 64708421b..283dc6977 100644
--- a/pkg/auth/store.go
+++ b/pkg/auth/store.go
@@ -5,6 +5,8 @@ import (
"os"
"path/filepath"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
type AuthCredential struct {
@@ -63,16 +65,13 @@ func LoadStore() (*AuthStore, error) {
func SaveStore(store *AuthStore) error {
path := authFilePath()
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return err
- }
-
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
- return os.WriteFile(path, data, 0o600)
+
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ return fileutil.WriteFileAtomic(path, data, 0o600)
}
func GetCredential(provider string) (*AuthCredential, error) {
diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go
index 58c0a25d5..f5ff9587d 100644
--- a/pkg/bus/bus.go
+++ b/pkg/bus/bus.go
@@ -2,81 +2,156 @@ package bus
import (
"context"
- "sync"
+ "errors"
+ "sync/atomic"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
)
+// ErrBusClosed is returned when publishing to a closed MessageBus.
+var ErrBusClosed = errors.New("message bus closed")
+
+const defaultBusBufferSize = 64
+
type MessageBus struct {
- inbound chan InboundMessage
- outbound chan OutboundMessage
- handlers map[string]MessageHandler
- closed bool
- mu sync.RWMutex
+ inbound chan InboundMessage
+ outbound chan OutboundMessage
+ outboundMedia chan OutboundMediaMessage
+ done chan struct{}
+ closed atomic.Bool
}
func NewMessageBus() *MessageBus {
return &MessageBus{
- inbound: make(chan InboundMessage, 100),
- outbound: make(chan OutboundMessage, 100),
- handlers: make(map[string]MessageHandler),
+ inbound: make(chan InboundMessage, defaultBusBufferSize),
+ outbound: make(chan OutboundMessage, defaultBusBufferSize),
+ outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize),
+ done: make(chan struct{}),
}
}
-func (mb *MessageBus) PublishInbound(msg InboundMessage) {
- mb.mu.RLock()
- defer mb.mu.RUnlock()
- if mb.closed {
- return
+func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
+ if mb.closed.Load() {
+ return ErrBusClosed
+ }
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ select {
+ case mb.inbound <- msg:
+ return nil
+ case <-mb.done:
+ return ErrBusClosed
+ case <-ctx.Done():
+ return ctx.Err()
}
- mb.inbound <- msg
}
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
select {
- case msg := <-mb.inbound:
- return msg, true
+ case msg, ok := <-mb.inbound:
+ return msg, ok
+ case <-mb.done:
+ return InboundMessage{}, false
case <-ctx.Done():
return InboundMessage{}, false
}
}
-func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
- mb.mu.RLock()
- defer mb.mu.RUnlock()
- if mb.closed {
- return
+func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
+ if mb.closed.Load() {
+ return ErrBusClosed
+ }
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ select {
+ case mb.outbound <- msg:
+ return nil
+ case <-mb.done:
+ return ErrBusClosed
+ case <-ctx.Done():
+ return ctx.Err()
}
- mb.outbound <- msg
}
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
select {
- case msg := <-mb.outbound:
- return msg, true
+ case msg, ok := <-mb.outbound:
+ return msg, ok
+ case <-mb.done:
+ return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
}
-func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) {
- mb.mu.Lock()
- defer mb.mu.Unlock()
- mb.handlers[channel] = handler
+func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
+ if mb.closed.Load() {
+ return ErrBusClosed
+ }
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+ select {
+ case mb.outboundMedia <- msg:
+ return nil
+ case <-mb.done:
+ return ErrBusClosed
+ case <-ctx.Done():
+ return ctx.Err()
+ }
}
-func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
- mb.mu.RLock()
- defer mb.mu.RUnlock()
- handler, ok := mb.handlers[channel]
- return handler, ok
+func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) {
+ select {
+ case msg, ok := <-mb.outboundMedia:
+ return msg, ok
+ case <-mb.done:
+ return OutboundMediaMessage{}, false
+ case <-ctx.Done():
+ return OutboundMediaMessage{}, false
+ }
}
func (mb *MessageBus) Close() {
- mb.mu.Lock()
- defer mb.mu.Unlock()
- if mb.closed {
- return
+ if mb.closed.CompareAndSwap(false, true) {
+ close(mb.done)
+
+ // Drain buffered channels so messages aren't silently lost.
+ // Channels are NOT closed to avoid send-on-closed panics from concurrent publishers.
+ drained := 0
+ for {
+ select {
+ case <-mb.inbound:
+ drained++
+ default:
+ goto doneInbound
+ }
+ }
+ doneInbound:
+ for {
+ select {
+ case <-mb.outbound:
+ drained++
+ default:
+ goto doneOutbound
+ }
+ }
+ doneOutbound:
+ for {
+ select {
+ case <-mb.outboundMedia:
+ drained++
+ default:
+ goto doneMedia
+ }
+ }
+ doneMedia:
+ if drained > 0 {
+ logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
+ "count": drained,
+ })
+ }
}
- mb.closed = true
- close(mb.inbound)
- close(mb.outbound)
}
diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go
new file mode 100644
index 000000000..e07b8c7fe
--- /dev/null
+++ b/pkg/bus/bus_test.go
@@ -0,0 +1,229 @@
+package bus
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestPublishConsume(t *testing.T) {
+ mb := NewMessageBus()
+ defer mb.Close()
+
+ ctx := context.Background()
+
+ msg := InboundMessage{
+ Channel: "test",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "hello",
+ }
+
+ if err := mb.PublishInbound(ctx, msg); err != nil {
+ t.Fatalf("PublishInbound failed: %v", err)
+ }
+
+ got, ok := mb.ConsumeInbound(ctx)
+ if !ok {
+ t.Fatal("ConsumeInbound returned ok=false")
+ }
+ if got.Content != "hello" {
+ t.Fatalf("expected content 'hello', got %q", got.Content)
+ }
+ if got.Channel != "test" {
+ t.Fatalf("expected channel 'test', got %q", got.Channel)
+ }
+}
+
+func TestPublishOutboundSubscribe(t *testing.T) {
+ mb := NewMessageBus()
+ defer mb.Close()
+
+ ctx := context.Background()
+
+ msg := OutboundMessage{
+ Channel: "telegram",
+ ChatID: "123",
+ Content: "world",
+ }
+
+ if err := mb.PublishOutbound(ctx, msg); err != nil {
+ t.Fatalf("PublishOutbound failed: %v", err)
+ }
+
+ got, ok := mb.SubscribeOutbound(ctx)
+ if !ok {
+ t.Fatal("SubscribeOutbound returned ok=false")
+ }
+ if got.Content != "world" {
+ t.Fatalf("expected content 'world', got %q", got.Content)
+ }
+}
+
+func TestPublishInbound_ContextCancel(t *testing.T) {
+ mb := NewMessageBus()
+ defer mb.Close()
+
+ // Fill the buffer
+ ctx := context.Background()
+ for i := range defaultBusBufferSize {
+ if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
+ t.Fatalf("fill failed at %d: %v", i, err)
+ }
+ }
+
+ // Now buffer is full; publish with a canceled context
+ cancelCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
+ if err == nil {
+ t.Fatal("expected error from canceled context, got nil")
+ }
+ if err != context.Canceled {
+ t.Fatalf("expected context.Canceled, got %v", err)
+ }
+}
+
+func TestPublishInbound_BusClosed(t *testing.T) {
+ mb := NewMessageBus()
+ mb.Close()
+
+ err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
+ if err != ErrBusClosed {
+ t.Fatalf("expected ErrBusClosed, got %v", err)
+ }
+}
+
+func TestPublishOutbound_BusClosed(t *testing.T) {
+ mb := NewMessageBus()
+ mb.Close()
+
+ err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
+ if err != ErrBusClosed {
+ t.Fatalf("expected ErrBusClosed, got %v", err)
+ }
+}
+
+func TestConsumeInbound_ContextCancel(t *testing.T) {
+ mb := NewMessageBus()
+ defer mb.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ _, ok := mb.ConsumeInbound(ctx)
+ if ok {
+ t.Fatal("expected ok=false when context is canceled")
+ }
+}
+
+func TestConsumeInbound_BusClosed(t *testing.T) {
+ mb := NewMessageBus()
+ mb.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ _, ok := mb.ConsumeInbound(ctx)
+ if ok {
+ t.Fatal("expected ok=false when bus is closed")
+ }
+}
+
+func TestSubscribeOutbound_BusClosed(t *testing.T) {
+ mb := NewMessageBus()
+ mb.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ _, ok := mb.SubscribeOutbound(ctx)
+ if ok {
+ t.Fatal("expected ok=false when bus is closed")
+ }
+}
+
+func TestConcurrentPublishClose(t *testing.T) {
+ mb := NewMessageBus()
+ ctx := context.Background()
+
+ const numGoroutines = 100
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines + 1)
+
+ // Spawn many goroutines trying to publish
+ for range numGoroutines {
+ go func() {
+ defer wg.Done()
+ // Use a short timeout context so we don't block forever after close
+ publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
+ defer cancel()
+ // Errors are expected; we just must not panic or deadlock
+ _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"})
+ }()
+ }
+
+ // Close from another goroutine
+ go func() {
+ defer wg.Done()
+ time.Sleep(5 * time.Millisecond)
+ mb.Close()
+ }()
+
+ // Must complete without deadlock
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // success
+ case <-time.After(5 * time.Second):
+ t.Fatal("test timed out - possible deadlock")
+ }
+}
+
+func TestPublishInbound_FullBuffer(t *testing.T) {
+ mb := NewMessageBus()
+ defer mb.Close()
+
+ ctx := context.Background()
+
+ // Fill the buffer
+ for i := range defaultBusBufferSize {
+ if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
+ t.Fatalf("fill failed at %d: %v", i, err)
+ }
+ }
+
+ // Buffer is full; publish with short timeout
+ timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+ defer cancel()
+
+ err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
+ if err == nil {
+ t.Fatal("expected error when buffer is full and context times out")
+ }
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected context.DeadlineExceeded, got %v", err)
+ }
+}
+
+func TestCloseIdempotent(t *testing.T) {
+ mb := NewMessageBus()
+
+ // Multiple Close calls must not panic
+ mb.Close()
+ mb.Close()
+ mb.Close()
+
+ // After close, publish should return ErrBusClosed
+ err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
+ if err != ErrBusClosed {
+ t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
+ }
+}
diff --git a/pkg/bus/types.go b/pkg/bus/types.go
index 44f9181a5..7ad8f0417 100644
--- a/pkg/bus/types.go
+++ b/pkg/bus/types.go
@@ -1,11 +1,30 @@
package bus
+// Peer identifies the routing peer for a message (direct, group, channel, etc.)
+type Peer struct {
+ Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
+ ID string `json:"id"`
+}
+
+// SenderInfo provides structured sender identity information.
+type SenderInfo struct {
+ Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
+ PlatformID string `json:"platform_id,omitempty"` // raw platform ID, e.g. "123456"
+ CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" format
+ Username string `json:"username,omitempty"` // username (e.g. @alice)
+ DisplayName string `json:"display_name,omitempty"` // display name
+}
+
type InboundMessage struct {
Channel string `json:"channel"`
SenderID string `json:"sender_id"`
+ Sender SenderInfo `json:"sender"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
+ Peer Peer `json:"peer"` // routing peer
+ MessageID string `json:"message_id,omitempty"` // platform message ID
+ MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
SessionKey string `json:"session_key"`
Metadata map[string]string `json:"metadata,omitempty"`
}
@@ -16,4 +35,18 @@ type OutboundMessage struct {
Content string `json:"content"`
}
-type MessageHandler func(InboundMessage) error
+// MediaPart describes a single media attachment to send.
+type MediaPart struct {
+ Type string `json:"type"` // "image" | "audio" | "video" | "file"
+ Ref string `json:"ref"` // media store ref, e.g. "media://abc123"
+ Caption string `json:"caption,omitempty"` // optional caption text
+ Filename string `json:"filename,omitempty"` // original filename hint
+ ContentType string `json:"content_type,omitempty"` // MIME type hint
+}
+
+// OutboundMediaMessage carries media attachments from Agent to channels via the bus.
+type OutboundMediaMessage struct {
+ Channel string `json:"channel"`
+ ChatID string `json:"chat_id"`
+ Parts []MediaPart `json:"parts"`
+}
diff --git a/pkg/channels/README.md b/pkg/channels/README.md
new file mode 100644
index 000000000..b7c56660b
--- /dev/null
+++ b/pkg/channels/README.md
@@ -0,0 +1,1384 @@
+# PicoClaw Channel System: Complete Development Guide
+
+> **Scope**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
+
+---
+
+## Table of Contents
+
+- [Part 1: Architecture Overview](#part-1-architecture-overview)
+- [Part 2: Migration Guide — From main Branch to Refactored Branch](#part-2-migration-guide--from-main-branch-to-refactored-branch)
+- [Part 3: New Channel Development Guide — Implementing a Channel from Scratch](#part-3-new-channel-development-guide--implementing-a-channel-from-scratch)
+- [Part 4: Core Subsystem Details](#part-4-core-subsystem-details)
+- [Part 5: Key Design Decisions and Conventions](#part-5-key-design-decisions-and-conventions)
+- [Appendix: Complete File Listing and Interface Quick Reference](#appendix-complete-file-listing-and-interface-quick-reference)
+
+---
+
+## Part 1: Architecture Overview
+
+### 1.1 Before and After Comparison
+
+**Before Refactor (main branch)**:
+
+```
+pkg/channels/
+├── telegram.go # Each channel directly in the channels package
+├── discord.go
+├── slack.go
+├── manager.go # Manager directly references each channel type
+├── ...
+```
+
+- All channel implementations lived at the top level of `pkg/channels/`
+- Manager constructed each channel via `switch` or `if-else` chains
+- Routing info like Peer and MessageID was buried in `Metadata map[string]string`
+- No rate limiting or retry on message sending
+- No unified media file lifecycle management
+- Each channel ran its own HTTP server
+- Group chat trigger filtering logic was scattered across channels
+
+**After Refactor (refactor/channel-system branch)**:
+
+```
+pkg/channels/
+├── base.go # BaseChannel shared abstraction layer
+├── interfaces.go # Optional capability interfaces (TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
+├── README.md # English documentation
+├── README.zh.md # Chinese documentation
+├── media.go # MediaSender optional interface
+├── webhook.go # WebhookHandler, HealthChecker optional interfaces
+├── errors.go # Sentinel errors (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
+├── errutil.go # Error classification helpers
+├── registry.go # Factory registry (RegisterFactory / getFactory)
+├── manager.go # Unified orchestration: Worker queues, rate limiting, retries, Typing/Placeholder, shared HTTP
+├── split.go # Smart long-message splitting (preserves code block integrity)
+├── telegram/ # Each channel in its own sub-package
+│ ├── init.go # Factory registration
+│ ├── telegram.go # Implementation
+│ └── telegram_commands.go
+├── discord/
+│ ├── init.go
+│ └── discord.go
+├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
+│ └── ...
+
+pkg/bus/
+├── bus.go # MessageBus (buffer 64, safe close + drain)
+├── types.go # Structured message types (Peer, SenderInfo, MediaPart, InboundMessage, OutboundMessage, OutboundMediaMessage)
+
+pkg/media/
+├── store.go # MediaStore interface + FileMediaStore implementation (two-phase release, TTL cleanup)
+
+pkg/identity/
+├── identity.go # Unified user identity: canonical "platform:id" format + backward-compatible matching
+```
+
+### 1.2 Message Flow Overview
+
+```
+┌────────────┐ InboundMessage ┌───────────┐ LLM + Tools ┌────────────┐
+│ Telegram │──┐ │ │ │ │
+│ Discord │──┤ PublishInbound() │ │ PublishOutbound() │ │
+│ Slack │──┼──────────────────────▶ │ MessageBus │ ◀─────────────────── │ AgentLoop │
+│ LINE │──┤ (buffered chan, 64) │ │ (buffered chan, 64) │ │
+│ ... │──┘ │ │ │ │
+└────────────┘ └─────┬─────┘ └────────────┘
+ │
+ SubscribeOutbound() │ SubscribeOutboundMedia()
+ ▼
+ ┌───────────────────┐
+ │ Manager │
+ │ ├── dispatchOutbound() Route to Worker queues
+ │ ├── dispatchOutboundMedia()
+ │ ├── runWorker() Message split + sendWithRetry()
+ │ ├── runMediaWorker() sendMediaWithRetry()
+ │ ├── preSend() Stop Typing + Undo Reaction + Edit Placeholder
+ │ └── runTTLJanitor() Clean up expired Typing/Placeholder
+ └────────┬──────────┘
+ │
+ channel.Send() / SendMedia()
+ │
+ ▼
+ ┌────────────────┐
+ │ Platform APIs │
+ └────────────────┘
+```
+
+### 1.3 Key Design Principles
+
+| Principle | Description |
+|-----------|-------------|
+| **Sub-package Isolation** | Each channel is a standalone Go sub-package, depending on `BaseChannel` and interfaces from the `channels` parent package |
+| **Factory Registration** | Sub-packages self-register via `init()`, Manager looks up factories by name, eliminating import coupling |
+| **Capability Discovery** | Optional capabilities are declared via interfaces (`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`), discovered by Manager via runtime type assertions |
+| **Structured Messages** | Peer, MessageID, and SenderInfo promoted from Metadata to first-class fields on InboundMessage |
+| **Error Classification** | Channels return sentinel errors (`ErrRateLimit`, `ErrTemporary`, etc.), Manager uses these to determine retry strategy |
+| **Centralized Orchestration** | Rate limiting, message splitting, retries, and Typing/Reaction/Placeholder management are all handled by Manager and BaseChannel; channels only need to implement Send |
+
+---
+
+## Part 2: Migration Guide — From main Branch to Refactored Branch
+
+### 2.1 If You Have Unmerged Channel Changes
+
+#### Step 1: Identify which files you modified
+
+On the main branch, channel files were directly in `pkg/channels/` top level, e.g.:
+- `pkg/channels/telegram.go`
+- `pkg/channels/discord.go`
+
+After refactoring, these files have been removed and code moved to corresponding sub-packages:
+- `pkg/channels/telegram/telegram.go`
+- `pkg/channels/discord/discord.go`
+
+#### Step 2: Understand the structural change mapping
+
+| main branch file | Refactored branch location | Changes |
+|---|---|---|
+| `pkg/channels/telegram.go` | `pkg/channels/telegram/telegram.go` + `init.go` | Package name changed from `channels` to `telegram` |
+| `pkg/channels/discord.go` | `pkg/channels/discord/discord.go` + `init.go` | Same as above |
+| `pkg/channels/manager.go` | `pkg/channels/manager.go` | Extensively rewritten |
+| _(did not exist)_ | `pkg/channels/base.go` | New shared abstraction layer |
+| _(did not exist)_ | `pkg/channels/registry.go` | New factory registry |
+| _(did not exist)_ | `pkg/channels/errors.go` + `errutil.go` | New error classification system |
+| _(did not exist)_ | `pkg/channels/interfaces.go` | New optional capability interfaces |
+| _(did not exist)_ | `pkg/channels/media.go` | New MediaSender interface |
+| _(did not exist)_ | `pkg/channels/webhook.go` | New WebhookHandler/HealthChecker |
+| _(did not exist)_ | `pkg/channels/whatsapp_native/` | New WhatsApp native mode (whatsmeow) |
+| _(did not exist)_ | `pkg/channels/split.go` | New message splitting (migrated from utils) |
+| _(did not exist)_ | `pkg/bus/types.go` | New structured message types |
+| _(did not exist)_ | `pkg/media/store.go` | New media file lifecycle management |
+| _(did not exist)_ | `pkg/identity/identity.go` | New unified user identity |
+
+#### Step 3: Migrate your channel code
+
+Using Telegram as an example, the main changes are:
+
+**3a. Package declaration and imports**
+
+```go
+// Old code (main branch)
+package channels
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// New code (refactored branch)
+package telegram
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels" // Reference parent package
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity" // New
+ "github.com/sipeed/picoclaw/pkg/media" // New (if media support needed)
+)
+```
+
+**3b. Struct embeds BaseChannel**
+
+```go
+// Old code: directly held bus, config, etc. fields
+type TelegramChannel struct {
+ bus *bus.MessageBus
+ config *config.Config
+ running bool
+ allowList []string
+ // ...
+}
+
+// New code: embed BaseChannel, which provides bus, running, allowList, etc.
+type TelegramChannel struct {
+ *channels.BaseChannel // Embed shared abstraction
+ bot *telego.Bot
+ config *config.Config
+ // ... only channel-specific fields
+}
+```
+
+**3c. Constructor**
+
+```go
+// Old code: direct assignment
+func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
+ return &TelegramChannel{
+ bus: bus,
+ config: cfg,
+ allowList: cfg.Channels.Telegram.AllowFrom,
+ // ...
+ }, nil
+}
+
+// New code: use NewBaseChannel + functional options
+func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
+ base := channels.NewBaseChannel(
+ "telegram", // Name
+ cfg.Channels.Telegram, // Raw config (any type)
+ bus, // Message bus
+ cfg.Channels.Telegram.AllowFrom, // Allow list
+ channels.WithMaxMessageLength(4096), // Platform message length limit
+ channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // Group trigger config
+ channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // Reasoning chain routing
+ )
+ return &TelegramChannel{
+ BaseChannel: base,
+ bot: bot,
+ config: cfg,
+ }, nil
+}
+```
+
+**3d. Start/Stop lifecycle**
+
+```go
+// New code: use SetRunning atomic operation
+func (c *TelegramChannel) Start(ctx context.Context) error {
+ // ... initialize bot, webhook, etc.
+ c.SetRunning(true) // Must be called after ready
+ go bh.Start()
+ return nil
+}
+
+func (c *TelegramChannel) Stop(ctx context.Context) error {
+ c.SetRunning(false) // Must be called before cleanup
+ // ... stop bot handler, cancel context
+ return nil
+}
+```
+
+**3e. Send method error returns**
+
+```go
+// Old code: returns plain error
+func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.running { return fmt.Errorf("not running") }
+ // ...
+ if err != nil { return err }
+}
+
+// New code: must return sentinel errors for Manager to determine retry strategy
+func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning // ← Manager will not retry
+ }
+ // ...
+ if err != nil {
+ // Use ClassifySendError to wrap error based on HTTP status code
+ return channels.ClassifySendError(statusCode, err)
+ // Or manually wrap:
+ // return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
+ // return fmt.Errorf("%w: %v", channels.ErrRateLimit, err)
+ // return fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
+ }
+ return nil
+}
+```
+
+**3f. Message reception (Inbound)**
+
+```go
+// Old code: directly construct InboundMessage and publish
+msg := bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: senderID,
+ ChatID: chatID,
+ Content: content,
+ Metadata: map[string]string{
+ "peer_kind": "group", // Routing info buried in metadata
+ "peer_id": chatID,
+ "message_id": msgID,
+ },
+}
+c.bus.PublishInbound(ctx, msg)
+
+// New code: use BaseChannel.HandleMessage with structured fields
+sender := bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: strconv.FormatInt(from.ID, 10),
+ CanonicalID: identity.BuildCanonicalID("telegram", strconv.FormatInt(from.ID, 10)),
+ Username: from.Username,
+ DisplayName: from.FirstName,
+}
+
+peer := bus.Peer{
+ Kind: "group", // or "direct"
+ ID: chatID,
+}
+
+// HandleMessage internally calls IsAllowedSender for permission checks, builds MediaScope, and publishes to bus
+c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, sender)
+```
+
+**3g. Add factory registration (required)**
+
+Create `init.go` for your channel:
+
+```go
+// pkg/channels/telegram/init.go
+package telegram
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewTelegramChannel(cfg, b)
+ })
+}
+```
+
+**3h. Import sub-package in Gateway**
+
+```go
+// cmd/picoclaw/internal/gateway/helpers.go
+import (
+ _ "github.com/sipeed/picoclaw/pkg/channels/telegram" // Triggers init() registration
+ _ "github.com/sipeed/picoclaw/pkg/channels/discord"
+ _ "github.com/sipeed/picoclaw/pkg/channels/your_new_channel" // New addition
+)
+```
+
+#### Step 4: Migrate bus message usage
+
+If your code directly reads routing fields from `InboundMessage.Metadata`:
+
+```go
+// Old code
+peerKind := msg.Metadata["peer_kind"]
+peerID := msg.Metadata["peer_id"]
+msgID := msg.Metadata["message_id"]
+
+// New code
+peerKind := msg.Peer.Kind // First-class field
+peerID := msg.Peer.ID // First-class field
+msgID := msg.MessageID // First-class field
+sender := msg.Sender // bus.SenderInfo struct
+scope := msg.MediaScope // Media lifecycle scope
+```
+
+#### Step 5: Migrate allow-list checks
+
+```go
+// Old code
+if !c.isAllowed(senderID) { return }
+
+// New code: prefer structured check
+if !c.IsAllowedSender(sender) { return }
+// Or fall back to string check:
+if !c.IsAllowed(senderID) { return }
+```
+
+`BaseChannel.HandleMessage` already handles this logic internally — no need to duplicate the check in your channel.
+
+### 2.2 If You Have Manager Modifications
+
+The Manager has been completely rewritten. Your modifications will need to account for the new architecture:
+
+| Old Manager Responsibility | New Manager Responsibility |
+|---|---|
+| Directly construct channels (switch/if-else) | Look up and construct via factory registry |
+| Directly call channel.Send | Per-channel Worker queues + rate limiting + retries |
+| No message splitting | Automatic splitting based on MaxMessageLength |
+| Each channel runs its own HTTP server | Unified shared HTTP server |
+| No Typing/Placeholder management | Unified preSend handles Typing stop + Reaction undo + Placeholder edit; inbound-side BaseChannel.HandleMessage auto-orchestrates Typing/Reaction/Placeholder |
+| No TTL cleanup | runTTLJanitor periodically cleans up expired Typing/Reaction/Placeholder entries |
+
+### 2.3 If You Have Agent Loop Modifications
+
+Main changes to the Agent Loop:
+
+1. **MediaStore injection**: `agentLoop.SetMediaStore(mediaStore)` — Agent resolves media references produced by tools via MediaStore
+2. **ChannelManager injection**: `agentLoop.SetChannelManager(channelManager)` — Agent can query channel state
+3. **OutboundMediaMessage**: Agent now sends media messages via `bus.PublishOutboundMedia()` instead of embedding them in text replies
+4. **extractPeer**: Routing uses `msg.Peer` structured fields instead of Metadata lookups
+
+---
+
+## Part 3: New Channel Development Guide — Implementing a Channel from Scratch
+
+### 3.1 Minimum Implementation Checklist
+
+To add a new chat platform (e.g., `matrix`), you need to:
+
+1. ✅ Create sub-package directory `pkg/channels/matrix/`
+2. ✅ Create `init.go` — factory registration
+3. ✅ Create `matrix.go` — channel implementation
+4. ✅ Add blank import in Gateway helpers
+5. ✅ Add config check in Manager.initChannels()
+6. ✅ Add config struct in `pkg/config/`
+
+### 3.2 Complete Template
+
+#### `pkg/channels/matrix/init.go`
+
+```go
+package matrix
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewMatrixChannel(cfg, b)
+ })
+}
+```
+
+#### `pkg/channels/matrix/matrix.go`
+
+```go
+package matrix
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// MatrixChannel implements channels.Channel for the Matrix protocol.
+type MatrixChannel struct {
+ *channels.BaseChannel // Must embed
+ config *config.Config
+ ctx context.Context
+ cancel context.CancelFunc
+ // ... Matrix SDK client, etc.
+}
+
+func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChannel, error) {
+ matrixCfg := cfg.Channels.Matrix // Assumes this field exists in config
+
+ base := channels.NewBaseChannel(
+ "matrix", // Channel name (globally unique)
+ matrixCfg, // Raw config
+ msgBus, // Message bus
+ matrixCfg.AllowFrom, // Allow list
+ channels.WithMaxMessageLength(65536), // Matrix message length limit
+ channels.WithGroupTrigger(matrixCfg.GroupTrigger),
+ channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // Reasoning chain routing (optional)
+ )
+
+ return &MatrixChannel{
+ BaseChannel: base,
+ config: cfg,
+ }, nil
+}
+
+// ========== Required Channel Interface Methods ==========
+
+func (c *MatrixChannel) Start(ctx context.Context) error {
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // 1. Initialize Matrix client
+ // 2. Start listening for messages
+ // 3. Mark as running
+ c.SetRunning(true)
+
+ logger.InfoC("matrix", "Matrix channel started")
+ return nil
+}
+
+func (c *MatrixChannel) Stop(ctx context.Context) error {
+ c.SetRunning(false)
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ logger.InfoC("matrix", "Matrix channel stopped")
+ return nil
+}
+
+func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ // 1. Check running state
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ // 2. Send message to Matrix
+ err := c.sendToMatrix(ctx, msg.ChatID, msg.Content)
+ if err != nil {
+ // 3. Must use error classification wrapping
+ // If you have an HTTP status code:
+ // return channels.ClassifySendError(statusCode, err)
+ // If it's a network error:
+ // return channels.ClassifyNetError(err)
+ // If manual classification is needed:
+ return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
+ }
+
+ return nil
+}
+
+// ========== Incoming Message Handling ==========
+
+func (c *MatrixChannel) handleIncoming(roomID, senderID, displayName, content string, msgID string) {
+ // 1. Construct structured sender identity
+ sender := bus.SenderInfo{
+ Platform: "matrix",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("matrix", senderID),
+ Username: senderID,
+ DisplayName: displayName,
+ }
+
+ // 2. Determine Peer type (direct vs group)
+ peer := bus.Peer{
+ Kind: "group", // or "direct"
+ ID: roomID,
+ }
+
+ // 3. Group chat filtering (if applicable)
+ isGroup := peer.Kind == "group"
+ if isGroup {
+ isMentioned := false // Detect @mentions based on platform specifics
+ shouldRespond, cleanContent := c.ShouldRespondInGroup(isMentioned, content)
+ if !shouldRespond {
+ return
+ }
+ content = cleanContent
+ }
+
+ // 4. Handle media attachments (if any)
+ var mediaRefs []string
+ store := c.GetMediaStore()
+ if store != nil {
+ // Download attachment locally → store.Store() → get ref
+ // mediaRefs = append(mediaRefs, ref)
+ }
+
+ // 5. Call HandleMessage to publish to bus
+ // HandleMessage internally will:
+ // - Check IsAllowedSender/IsAllowed
+ // - Build MediaScope
+ // - Publish InboundMessage
+ c.HandleMessage(
+ c.ctx,
+ peer,
+ msgID, // Platform message ID
+ senderID, // Raw sender ID
+ roomID, // Chat/room ID
+ content, // Message content
+ mediaRefs, // Media reference list
+ nil, // Extra metadata (usually nil)
+ sender, // SenderInfo (variadic parameter)
+ )
+}
+
+// ========== Internal Methods ==========
+
+func (c *MatrixChannel) sendToMatrix(ctx context.Context, roomID, content string) error {
+ // Actual Matrix SDK call
+ return nil
+}
+```
+
+### 3.3 Optional Capability Interfaces
+
+Depending on platform capabilities, your channel can optionally implement the following interfaces:
+
+#### MediaSender — Send Media Attachments
+
+```go
+// If the platform supports sending images/files/audio/video
+func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("matrix", "Failed to resolve media", map[string]any{
+ "ref": part.Ref, "error": err.Error(),
+ })
+ continue
+ }
+
+ // Call the appropriate API based on part.Type ("image"|"audio"|"video"|"file")
+ switch part.Type {
+ case "image":
+ // Upload image to Matrix
+ default:
+ // Upload file to Matrix
+ }
+ }
+ return nil
+}
+```
+
+#### TypingCapable — Typing Indicator
+
+```go
+// If the platform supports "typing..." indicators
+func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (stop func(), err error) {
+ // Call Matrix API to send typing indicator
+ // The returned stop function must be idempotent
+ stopped := false
+ return func() {
+ if !stopped {
+ stopped = true
+ // Call Matrix API to stop typing
+ }
+ }, nil
+}
+```
+
+#### ReactionCapable — Message Reaction Indicator
+
+```go
+// If the platform supports adding emoji reactions to inbound messages (e.g., Slack's 👀, OneBot's emoji 289)
+func (c *MatrixChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) {
+ // Call Matrix API to add reaction to message
+ // The returned undo function removes the reaction, must be idempotent
+ err = c.addReaction(chatID, messageID, "eyes")
+ if err != nil {
+ return func() {}, err
+ }
+ return func() {
+ c.removeReaction(chatID, messageID, "eyes")
+ }, nil
+}
+```
+
+#### MessageEditor — Message Editing
+
+```go
+// If the platform supports editing sent messages (used for Placeholder replacement)
+func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ // Call Matrix API to edit message
+ return nil
+}
+```
+
+#### PlaceholderCapable — Placeholder Messages
+
+```go
+// If the platform supports sending placeholder messages (e.g. "Thinking... 💭"),
+// and the channel also implements MessageEditor, then Manager's preSend will
+// automatically edit the placeholder into the final response on outbound.
+// SendPlaceholder checks PlaceholderConfig.Enabled internally;
+// returning ("", nil) means skip.
+func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ cfg := c.config.Channels.Matrix.Placeholder
+ if !cfg.Enabled {
+ return "", nil
+ }
+ text := cfg.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+ // Call Matrix API to send placeholder message
+ msg, err := c.sendText(ctx, chatID, text)
+ if err != nil {
+ return "", err
+ }
+ return msg.ID, nil
+}
+```
+
+#### WebhookHandler — HTTP Webhook Reception
+
+```go
+// If the channel receives messages via webhook (rather than long-polling/WebSocket)
+func (c *MatrixChannel) WebhookPath() string {
+ return "/webhook/matrix" // Path will be registered on the shared HTTP server
+}
+
+func (c *MatrixChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Handle webhook request
+}
+```
+
+#### HealthChecker — Health Check Endpoint
+
+```go
+func (c *MatrixChannel) HealthPath() string {
+ return "/health/matrix"
+}
+
+func (c *MatrixChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ if c.IsRunning() {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("OK"))
+ } else {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }
+}
+```
+
+### 3.4 Inbound-side Typing/Reaction/Placeholder Auto-orchestration
+
+`BaseChannel.HandleMessage` automatically detects whether the channel implements `TypingCapable`, `ReactionCapable`, and/or `PlaceholderCapable` **before** publishing the inbound message, and triggers the corresponding indicators. The three pipelines are completely independent and do not interfere with each other:
+
+```go
+// Automatically executed inside BaseChannel.HandleMessage (no manual calls needed):
+if c.owner != nil && c.placeholderRecorder != nil {
+ // Typing — independent pipeline
+ if tc, ok := c.owner.(TypingCapable); ok {
+ if stop, err := tc.StartTyping(ctx, chatID); err == nil {
+ c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
+ }
+ }
+ // Reaction — independent pipeline
+ if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
+ if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
+ c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
+ }
+ }
+ // Placeholder — independent pipeline
+ if pc, ok := c.owner.(PlaceholderCapable); ok {
+ if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
+ c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
+ }
+ }
+}
+```
+
+**This means**:
+- Channels implementing `TypingCapable` (Telegram, Discord, LINE, Pico) do not need to manually call `StartTyping` + `RecordTypingStop` in `handleMessage`
+- Channels implementing `ReactionCapable` (Slack, OneBot) do not need to manually call `AddReaction` + `RecordTypingStop` in `handleMessage`
+- Channels implementing `PlaceholderCapable` (Telegram, Discord, Pico) do not need to manually send placeholder messages and call `RecordPlaceholder` in `handleMessage`
+- Channels only need to implement the corresponding interface; `HandleMessage` handles orchestration automatically
+- Channels that don't implement these interfaces are unaffected (type assertions will fail and be skipped)
+- `PlaceholderCapable`'s `SendPlaceholder` method internally decides whether to send based on the configured `PlaceholderConfig.Enabled`; returning `("", nil)` skips registration
+
+**Owner Injection**: Manager automatically calls `SetOwner(ch)` in `initChannel` to inject the concrete channel into BaseChannel — no manual setup required from developers.
+
+When the Agent finishes processing a message, Manager's `preSend` automatically:
+1. Calls the recorded `stop()` to stop Typing
+2. Calls the recorded `undo()` to undo Reaction
+3. If there is a Placeholder and the channel implements `MessageEditor`, attempts to edit the Placeholder with the final reply (skipping Send)
+
+### 3.5 Register Configuration and Gateway Integration
+
+#### Add configuration in `pkg/config/config.go`
+
+```go
+type ChannelsConfig struct {
+ // ... existing channels
+ Matrix MatrixChannelConfig `json:"matrix"`
+}
+
+type MatrixChannelConfig struct {
+ Enabled bool `json:"enabled"`
+ HomeServer string `json:"home_server"`
+ Token string `json:"token"`
+ AllowFrom []string `json:"allow_from"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger"`
+ Placeholder PlaceholderConfig `json:"placeholder"`
+ ReasoningChannelID string `json:"reasoning_channel_id"`
+}
+```
+
+#### Add entry in Manager.initChannels()
+
+```go
+// In the initChannels() method of pkg/channels/manager.go
+if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
+ m.initChannel("matrix", "Matrix")
+}
+```
+
+> **Note**: If your channel has multiple modes (like WhatsApp Bridge vs Native), branch in initChannels based on config:
+> ```go
+> if cfg.UseNative {
+> m.initChannel("whatsapp_native", "WhatsApp Native")
+> } else {
+> m.initChannel("whatsapp", "WhatsApp")
+> }
+> ```
+
+#### Add blank import in Gateway
+
+```go
+// cmd/picoclaw/internal/gateway/helpers.go
+import (
+ _ "github.com/sipeed/picoclaw/pkg/channels/matrix"
+)
+```
+
+---
+
+## Part 4: Core Subsystem Details
+
+### 4.1 MessageBus
+
+**Files**: `pkg/bus/bus.go`, `pkg/bus/types.go`
+
+```go
+type MessageBus struct {
+ inbound chan InboundMessage // buffer = 64
+ outbound chan OutboundMessage // buffer = 64
+ outboundMedia chan OutboundMediaMessage // buffer = 64
+ done chan struct{} // Close signal
+ closed atomic.Bool // Prevents double-close
+}
+```
+
+**Key Behaviors**:
+
+| Method | Behavior |
+|--------|----------|
+| `PublishInbound(ctx, msg)` | Check closed → send to inbound channel → block/timeout/close |
+| `ConsumeInbound(ctx)` | Read from inbound → block/close/cancel |
+| `PublishOutbound(ctx, msg)` | Send to outbound channel |
+| `SubscribeOutbound(ctx)` | Read from outbound (called by Manager dispatcher) |
+| `PublishOutboundMedia(ctx, msg)` | Send to outboundMedia channel |
+| `SubscribeOutboundMedia(ctx)` | Read from outboundMedia (called by Manager media dispatcher) |
+| `Close()` | CAS close → close(done) → drain all channels (**does not close the channels themselves** to avoid concurrent send-on-closed panic) |
+
+**Design Notes**:
+- Buffer size increased from 16 to 64 to reduce blocking under burst load
+- `Close()` does not close the underlying channels (only closes the `done` signal channel), because there may be concurrent `Publish` goroutines
+- Drain loop ensures buffered messages are not silently dropped
+
+### 4.2 Structured Message Types
+
+**File**: `pkg/bus/types.go`
+
+```go
+// Routing peer
+type Peer struct {
+ Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
+ ID string `json:"id"`
+}
+
+// Sender identity information
+type SenderInfo struct {
+ Platform string `json:"platform,omitempty"` // "telegram", "discord", ...
+ PlatformID string `json:"platform_id,omitempty"` // Platform-native ID
+ CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" canonical format
+ Username string `json:"username,omitempty"`
+ DisplayName string `json:"display_name,omitempty"`
+}
+
+// Inbound message
+type InboundMessage struct {
+ Channel string // Source channel name
+ SenderID string // Sender ID (prefer CanonicalID)
+ Sender SenderInfo // Structured sender info
+ ChatID string // Chat/room ID
+ Content string // Message text
+ Media []string // Media reference list (media://...)
+ Peer Peer // Routing peer (first-class field)
+ MessageID string // Platform message ID (first-class field)
+ MediaScope string // Media lifecycle scope
+ SessionKey string // Session key
+ Metadata map[string]string // Only for channel-specific extensions
+}
+
+// Outbound text message
+type OutboundMessage struct {
+ Channel string
+ ChatID string
+ Content string
+}
+
+// Outbound media message
+type OutboundMediaMessage struct {
+ Channel string
+ ChatID string
+ Parts []MediaPart
+}
+
+// Media part
+type MediaPart struct {
+ Type string // "image" | "audio" | "video" | "file"
+ Ref string // "media://uuid"
+ Caption string
+ Filename string
+ ContentType string
+}
+```
+
+### 4.3 BaseChannel
+
+**File**: `pkg/channels/base.go`
+
+BaseChannel is the shared abstraction layer for all channels, providing the following capabilities:
+
+| Method/Feature | Description |
+|---|---|
+| `Name() string` | Channel name |
+| `IsRunning() bool` | Atomically read running state |
+| `SetRunning(bool)` | Atomically set running state |
+| `MaxMessageLength() int` | Message length limit (rune count), 0 = unlimited |
+| `ReasoningChannelID() string` | Reasoning chain routing target channel ID (empty = no routing) |
+| `IsAllowed(senderID string) bool` | Legacy allow-list check (supports `"id\|username"` and `"@username"` formats) |
+| `IsAllowedSender(sender SenderInfo) bool` | New allow-list check (delegates to `identity.MatchAllowed`) |
+| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | Unified group chat trigger filtering logic |
+| `HandleMessage(...)` | Unified inbound message handling: permission check → build MediaScope → auto-trigger Typing/Reaction/Placeholder → publish to Bus |
+| `SetMediaStore(s) / GetMediaStore()` | MediaStore injected by Manager |
+| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | PlaceholderRecorder injected by Manager |
+| `SetOwner(ch)` | Concrete channel reference injected by Manager (used for Typing/Reaction/Placeholder type assertions in HandleMessage) |
+
+**Functional Options**:
+
+```go
+channels.WithMaxMessageLength(4096) // Set platform message length limit
+channels.WithGroupTrigger(groupTriggerCfg) // Set group trigger configuration
+channels.WithReasoningChannelID(id) // Set reasoning chain routing target channel
+```
+
+### 4.4 Factory Registry
+
+**File**: `pkg/channels/registry.go`
+
+```go
+type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
+
+func RegisterFactory(name string, f ChannelFactory) // Called in sub-package init()
+func getFactory(name string) (ChannelFactory, bool) // Called internally by Manager
+```
+
+The factory registry is protected by `sync.RWMutex` and registrations occur during `init()` phase (completed at process startup). Manager looks up factories by name in `initChannel()` and calls them.
+
+### 4.5 Error Classification and Retries
+
+**Files**: `pkg/channels/errors.go`, `pkg/channels/errutil.go`
+
+#### Sentinel Errors
+
+```go
+var (
+ ErrNotRunning = errors.New("channel not running") // Permanent: do not retry
+ ErrRateLimit = errors.New("rate limited") // Fixed delay: retry after 1s
+ ErrTemporary = errors.New("temporary failure") // Exponential backoff: 500ms * 2^attempt, max 8s
+ ErrSendFailed = errors.New("send failed") // Permanent: do not retry
+)
+```
+
+#### Error Classification Helpers
+
+```go
+// Automatically classify based on HTTP status code
+func ClassifySendError(statusCode int, rawErr error) error {
+ // 429 → ErrRateLimit
+ // 5xx → ErrTemporary
+ // 4xx → ErrSendFailed
+}
+
+// Wrap network errors as temporary
+func ClassifyNetError(err error) error {
+ // → ErrTemporary
+}
+```
+
+#### Manager Retry Strategy (`sendWithRetry`)
+
+```
+Max retries: 3
+Rate limit delay: 1 second
+Base backoff: 500 milliseconds
+Max backoff: 8 seconds
+
+Retry logic:
+ ErrNotRunning → Fail immediately, no retry
+ ErrSendFailed → Fail immediately, no retry
+ ErrRateLimit → Wait 1s → retry
+ ErrTemporary → Wait 500ms * 2^attempt (max 8s) → retry
+ Other unknown → Wait 500ms * 2^attempt (max 8s) → retry
+```
+
+### 4.6 Manager Orchestration
+
+**File**: `pkg/channels/manager.go`
+
+#### Per-channel Worker Architecture
+
+```go
+type channelWorker struct {
+ ch Channel // Channel instance
+ queue chan bus.OutboundMessage // Outbound text queue (buffered 16)
+ mediaQueue chan bus.OutboundMediaMessage // Outbound media queue (buffered 16)
+ done chan struct{} // Text worker completion signal
+ mediaDone chan struct{} // Media worker completion signal
+ limiter *rate.Limiter // Per-channel rate limiter
+}
+```
+
+#### Per-channel Rate Limit Configuration
+
+```go
+var channelRateConfig = map[string]float64{
+ "telegram": 20, // 20 msg/s
+ "discord": 1, // 1 msg/s
+ "slack": 1, // 1 msg/s
+ "line": 10, // 10 msg/s
+}
+// Default: 10 msg/s
+// burst = max(1, ceil(rate/2))
+```
+
+#### Lifecycle Management
+
+```
+StartAll:
+ 1. Iterate registered channels → channel.Start(ctx)
+ 2. Create channelWorker for each successfully started channel
+ 3. Start goroutines:
+ - runWorker (per-channel outbound text)
+ - runMediaWorker (per-channel outbound media)
+ - dispatchOutbound (route from bus to worker queues)
+ - dispatchOutboundMedia (route from bus to media worker queues)
+ - runTTLJanitor (every 10s clean up expired typing/reaction/placeholder)
+ 4. Start shared HTTP server (if configured)
+
+StopAll:
+ 1. Shut down shared HTTP server (5s timeout)
+ 2. Cancel dispatcher context
+ 3. Close text worker queues → wait for drain to complete
+ 4. Close media worker queues → wait for drain to complete
+ 5. Stop each channel (channel.Stop)
+```
+
+#### Typing/Reaction/Placeholder Management
+
+```go
+// Manager implements PlaceholderRecorder interface
+func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string)
+func (m *Manager) RecordTypingStop(channel, chatID string, stop func())
+func (m *Manager) RecordReactionUndo(channel, chatID string, undo func())
+
+// Inbound side: BaseChannel.HandleMessage auto-orchestrates
+// BaseChannel.HandleMessage, before PublishInbound, auto-triggers via owner type assertions:
+// - TypingCapable.StartTyping → RecordTypingStop
+// - ReactionCapable.ReactToMessage → RecordReactionUndo
+// - PlaceholderCapable.SendPlaceholder → RecordPlaceholder
+// All three are independent and do not interfere with each other. Channels don't need to call these manually.
+
+// Outbound side: pre-send processing
+func (m *Manager) preSend(ctx, name, msg, ch) bool {
+ key := name + ":" + msg.ChatID
+ // 1. Stop Typing (call stored stop function)
+ // 2. Undo Reaction (call stored undo function)
+ // 3. Attempt to edit Placeholder (if channel implements MessageEditor)
+ // Success → return true (skip Send)
+ // Failure → return false (proceed with Send)
+}
+```
+
+Manager storage is fully separated; three pipelines do not interfere:
+
+```go
+Manager {
+ typingStops sync.Map // "channel:chatID" → typingEntry ← manages TypingCapable
+ reactionUndos sync.Map // "channel:chatID" → reactionEntry ← manages ReactionCapable
+ placeholders sync.Map // "channel:chatID" → placeholderEntry
+}
+```
+
+TTL Cleanup:
+- Typing stop functions: 5-minute TTL (auto-calls stop and deletes on expiry)
+- Reaction undo functions: 5-minute TTL (auto-calls undo and deletes on expiry)
+- Placeholder IDs: 10-minute TTL (deletes on expiry)
+- Cleanup interval: 10 seconds
+
+### 4.7 Message Splitting
+
+**File**: `pkg/channels/split.go`
+
+`SplitMessage(content string, maxLen int) []string`
+
+Smart splitting strategy:
+1. Calculate effective split point = maxLen - 10% buffer (to reserve space for code block closure)
+2. Prefer splitting at newlines
+3. Otherwise split at spaces/tabs
+4. Detect unclosed code blocks (` ``` `)
+5. If a code block is unclosed:
+ - Attempt to extend to maxLen to include the closing fence
+ - If the code block is too long, inject close/reopen fences (`\n```\n` + header)
+ - Last resort: split before the code block starts
+
+### 4.8 MediaStore
+
+**File**: `pkg/media/store.go`
+
+```go
+type MediaStore interface {
+ Store(localPath string, meta MediaMeta, scope string) (ref string, err error)
+ Resolve(ref string) (localPath string, err error)
+ ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error)
+ ReleaseAll(scope string) error
+}
+```
+
+**FileMediaStore Implementation**:
+- Pure in-memory mapping, no file copy/move
+- Reference format: `media://`
+- Scope format: `channel:chatID:messageID` (generated by `BuildMediaScope`)
+- **Two-phase operation**:
+ - Phase 1 (holding lock): collect and delete entries from map
+ - Phase 2 (no lock): delete files from disk
+ - Purpose: minimize lock contention
+- **TTL Cleanup**: `NewFileMediaStoreWithCleanup` → `Start()` launches background cleanup goroutine
+- Cleanup interval and max TTL are controlled by configuration
+
+### 4.9 Identity
+
+**File**: `pkg/identity/identity.go`
+
+```go
+// Build canonical ID
+func BuildCanonicalID(platform, platformID string) string
+// → "telegram:123456"
+
+// Parse canonical ID
+func ParseCanonicalID(canonical string) (platform, id string, ok bool)
+
+// Match against allow list (backward-compatible)
+func MatchAllowed(sender bus.SenderInfo, allowed string) bool
+```
+
+`MatchAllowed` supported allow-list formats:
+| Format | Matching |
+|--------|----------|
+| `"123456"` | Matches `sender.PlatformID` |
+| `"@alice"` | Matches `sender.Username` |
+| `"123456\|alice"` | Matches PlatformID or Username (legacy format compatibility) |
+| `"telegram:123456"` | Exact match on `sender.CanonicalID` (new format) |
+
+### 4.10 Shared HTTP Server
+
+**File**: `pkg/channels/manager.go`'s `SetupHTTPServer`
+
+Manager creates a single `http.Server` and auto-discovers and registers:
+- Channels implementing `WebhookHandler` → mounted at `wh.WebhookPath()`
+- Channels implementing `HealthChecker` → mounted at `hc.HealthPath()`
+- Global health endpoint registered by `health.Server.RegisterOnMux`
+
+Timeout configuration: ReadTimeout = 30s, WriteTimeout = 30s
+
+---
+
+## Part 5: Key Design Decisions and Conventions
+
+### 5.1 Mandatory Conventions
+
+1. **Error classification is a contract**: A channel's `Send` method **must** return sentinel errors (or wrap them). Manager's retry strategy relies entirely on `errors.Is` checks. Returning unclassified errors will cause Manager to treat them as "unknown errors" (exponential backoff retry).
+
+2. **SetRunning is a lifecycle signal**: **Must** call `c.SetRunning(true)` after successful `Start`, and **must** call `c.SetRunning(false)` at the beginning of `Stop`. **Must** check `c.IsRunning()` in `Send` and return `ErrNotRunning`.
+
+3. **HandleMessage includes permission checks**: Do not perform your own permission checks before calling `HandleMessage` (unless you need platform-specific preprocessing before the check). `HandleMessage` already calls `IsAllowedSender`/`IsAllowed` internally.
+
+4. **Message splitting is handled by Manager**: A channel's `Send` method does not need to handle long message splitting. Manager automatically splits based on `MaxMessageLength()` before calling `Send`. Channels only need to declare the limit via `WithMaxMessageLength`.
+
+5. **Typing/Reaction/Placeholder is handled by BaseChannel + Manager automatically**: A channel's `Send` method does not need to manage Typing stop, Reaction undo, or Placeholder editing. `BaseChannel.HandleMessage` auto-triggers `TypingCapable`, `ReactionCapable`, and `PlaceholderCapable` on the inbound side (via `owner` type assertions); Manager's `preSend` auto-stops Typing, undoes Reaction, and edits Placeholder on the outbound side. Channels only need to implement the corresponding interfaces.
+
+6. **Factory registration belongs in init()**: Each sub-package must have an `init.go` file calling `channels.RegisterFactory`. Gateway must trigger registration via blank imports (`_ "pkg/channels/xxx"`).
+
+### 5.2 Metadata Field Usage Conventions
+
+**Do NOT put the following information in Metadata anymore**:
+- `peer_kind` / `peer_id` → Use `InboundMessage.Peer`
+- `message_id` → Use `InboundMessage.MessageID`
+- `sender_platform` / `sender_username` → Use `InboundMessage.Sender`
+
+**Metadata should only be used for**:
+- Channel-specific extension information (e.g., Telegram's `reply_to_message_id`)
+- Temporary information that doesn't fit into structured fields
+
+### 5.3 Concurrency Safety Conventions
+
+- `BaseChannel.running`: Uses `atomic.Bool`, thread-safe
+- `Manager.channels` / `Manager.workers`: Protected by `sync.RWMutex`
+- `Manager.placeholders` / `Manager.typingStops` / `Manager.reactionUndos`: Uses `sync.Map`
+- `MessageBus.closed`: Uses `atomic.Bool`
+- `FileMediaStore`: Uses `sync.RWMutex`, two-phase operation to minimize lock-hold time
+- Channel Worker queue: Go channel, inherently concurrent-safe
+
+### 5.4 Testing Conventions
+
+Existing test files:
+- `pkg/channels/base_test.go` — BaseChannel unit tests
+- `pkg/channels/manager_test.go` — Manager unit tests
+- `pkg/channels/split_test.go` — Message splitting tests
+- `pkg/channels/errors_test.go` — Error type tests
+- `pkg/channels/errutil_test.go` — Error classification tests
+
+To add tests for a new channel:
+```bash
+go test ./pkg/channels/matrix/ -v # Sub-package tests
+go test ./pkg/channels/ -run TestSpecific -v # Framework tests
+make test # Full test suite
+```
+
+---
+
+## Appendix: Complete File Listing and Interface Quick Reference
+
+### A.1 Framework Layer Files
+
+| File | Responsibility |
+|------|---------------|
+| `pkg/channels/base.go` | BaseChannel struct, Channel interface, MessageLengthProvider, BaseChannelOption, HandleMessage |
+| `pkg/channels/interfaces.go` | TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder interfaces |
+| `pkg/channels/media.go` | MediaSender interface |
+| `pkg/channels/webhook.go` | WebhookHandler, HealthChecker interfaces |
+| `pkg/channels/errors.go` | ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed sentinels |
+| `pkg/channels/errutil.go` | ClassifySendError, ClassifyNetError helpers |
+| `pkg/channels/registry.go` | RegisterFactory, getFactory factory registry |
+| `pkg/channels/manager.go` | Manager: Worker queues, rate limiting, retries, preSend, shared HTTP, TTL janitor |
+| `pkg/channels/split.go` | SplitMessage long-message splitting |
+| `pkg/bus/bus.go` | MessageBus implementation |
+| `pkg/bus/types.go` | Peer, SenderInfo, InboundMessage, OutboundMessage, OutboundMediaMessage, MediaPart |
+| `pkg/media/store.go` | MediaStore interface, FileMediaStore implementation |
+| `pkg/identity/identity.go` | BuildCanonicalID, ParseCanonicalID, MatchAllowed |
+
+### A.2 Channel Sub-packages
+
+| Sub-package | Registered Name | Optional Interfaces |
+|-------------|----------------|-------------------|
+| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
+| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
+| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
+| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
+| `pkg/channels/feishu/` | `"feishu"` | — (architecture-specific build tags: `feishu_32.go` / `feishu_64.go`) |
+| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
+| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
+| `pkg/channels/qq/` | `"qq"` | — |
+| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge mode) |
+| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (Native whatsmeow mode) |
+| `pkg/channels/maixcam/` | `"maixcam"` | — |
+| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
+
+### A.3 Interface Quick Reference
+
+```go
+// ===== Required =====
+type Channel interface {
+ Name() string
+ Start(ctx context.Context) error
+ Stop(ctx context.Context) error
+ Send(ctx context.Context, msg bus.OutboundMessage) error
+ IsRunning() bool
+ IsAllowed(senderID string) bool
+ IsAllowedSender(sender bus.SenderInfo) bool
+ ReasoningChannelID() string
+}
+
+// ===== Optional =====
+type MediaSender interface {
+ SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
+}
+
+type TypingCapable interface {
+ StartTyping(ctx context.Context, chatID string) (stop func(), err error)
+}
+
+type ReactionCapable interface {
+ ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error)
+}
+
+type PlaceholderCapable interface {
+ SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error)
+}
+
+type MessageEditor interface {
+ EditMessage(ctx context.Context, chatID, messageID, content string) error
+}
+
+type WebhookHandler interface {
+ WebhookPath() string
+ http.Handler
+}
+
+type HealthChecker interface {
+ HealthPath() string
+ HealthHandler(w http.ResponseWriter, r *http.Request)
+}
+
+type MessageLengthProvider interface {
+ MaxMessageLength() int
+}
+
+// ===== Injected by Manager =====
+type PlaceholderRecorder interface {
+ RecordPlaceholder(channel, chatID, placeholderID string)
+ RecordTypingStop(channel, chatID string, stop func())
+ RecordReactionUndo(channel, chatID string, undo func())
+}
+```
+
+### A.4 Gateway Startup Sequence (Complete Bootstrap Flow)
+
+```go
+// 1. Create core components
+msgBus := bus.NewMessageBus()
+provider := providers.CreateProvider(cfg)
+agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
+
+// 2. Create media store (with TTL cleanup)
+mediaStore := media.NewFileMediaStoreWithCleanup(cleanerConfig)
+mediaStore.Start()
+
+// 3. Create Channel Manager (triggers initChannels → factory lookup → construct → inject MediaStore/PlaceholderRecorder/Owner)
+channelManager := channels.NewManager(cfg, msgBus, mediaStore)
+
+// 4. Inject references
+agentLoop.SetChannelManager(channelManager)
+agentLoop.SetMediaStore(mediaStore)
+
+// 5. Configure shared HTTP server
+channelManager.SetupHTTPServer(addr, healthServer)
+
+// 6. Start
+channelManager.StartAll(ctx) // Start channels + workers + dispatchers + HTTP server
+go agentLoop.Run(ctx) // Start Agent message loop
+
+// 7. Shutdown (signal-triggered)
+cancel() // Cancel context
+msgBus.Close() // Signal close + drain
+channelManager.StopAll(shutdownCtx) // Stop HTTP + workers + channels
+mediaStore.Stop() // Stop TTL cleanup
+agentLoop.Stop() // Stop Agent
+```
+
+### A.5 Per-channel Rate Limit Reference
+
+| Channel | Rate (msg/s) | Burst |
+|---------|-------------|-------|
+| telegram | 20 | 10 |
+| discord | 1 | 1 |
+| slack | 1 | 1 |
+| line | 10 | 5 |
+| _others_ | 10 (default) | 5 |
+
+### A.6 Known Limitations and Caveats
+
+1. **Media cleanup temporarily disabled**: The `ReleaseAll` call in the Agent loop is commented out (`refactor(loop): disable media cleanup to prevent premature file deletion`) because session boundaries are not yet clearly defined. TTL cleanup remains active.
+
+2. **Feishu architecture-specific compilation**: The Feishu channel uses build tags to distinguish 32-bit and 64-bit architectures (`feishu_32.go` / `feishu_64.go`). Feishu uses the SDK's WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
+
+3. **WeCom has two factories**: `"wecom"` (Bot mode, webhook only) and `"wecom_app"` (App mode, supports MediaSender) are registered separately. Both implement `WebhookHandler` and `HealthChecker`.
+
+4. **Pico Protocol**: `pkg/channels/pico/` implements a custom PicoClaw native protocol channel that receives messages via WebSocket webhook (`/pico/ws`).
+
+5. **WhatsApp has two modes**: `"whatsapp"` (Bridge mode, communicates via external bridge URL) and `"whatsapp_native"` (native whatsmeow mode, connects directly to WhatsApp). Manager selects which to initialize based on `WhatsAppConfig.UseNative`.
+
+6. **DingTalk uses Stream mode**: DingTalk uses the SDK's Stream/WebSocket mode (not HTTP webhook), so it does not implement `WebhookHandler`.
+
+7. **PlaceholderConfig vs implementation**: `PlaceholderConfig` appears in 6 channel configs (Telegram, Discord, Slack, LINE, OneBot, Pico), but only channels that implement both `PlaceholderCapable` + `MessageEditor` (Telegram, Discord, Pico) can actually use placeholder message editing. The rest are reserved fields.
+
+8. **ReasoningChannelID**: Most channel configs include a `reasoning_channel_id` field to route LLM reasoning/thinking output to a designated channel (WhatsApp, Telegram, Feishu, Discord, MaixCam, QQ, DingTalk, Slack, LINE, OneBot, WeCom, WeComApp). Note: `PicoConfig` does not currently expose this field. `BaseChannel` exposes this via the `WithReasoningChannelID` option and `ReasoningChannelID()` method.
\ No newline at end of file
diff --git a/pkg/channels/README.zh.md b/pkg/channels/README.zh.md
new file mode 100644
index 000000000..2c5e7356e
--- /dev/null
+++ b/pkg/channels/README.zh.md
@@ -0,0 +1,1383 @@
+# PicoClaw Channel System:完整开发指南
+
+> **影响范围**: `pkg/channels/`, `pkg/bus/`, `pkg/media/`, `pkg/identity/`, `cmd/picoclaw/internal/gateway/`
+
+---
+
+## 目录
+
+- [第一部分:架构总览](#第一部分架构总览)
+- [第二部分:迁移指南——从 main 分支迁移到重构分支](#第二部分迁移指南从-main-分支迁移到重构分支)
+- [第三部分:新 Channel 开发指南——从零实现一个新 Channel](#第三部分新-channel-开发指南从零实现一个新-channel)
+- [第四部分:核心子系统详解](#第四部分核心子系统详解)
+- [第五部分:关键设计决策与约定](#第五部分关键设计决策与约定)
+- [附录:完整文件清单与接口速查表](#附录完整文件清单与接口速查表)
+
+---
+
+## 第一部分:架构总览
+
+### 1.1 重构前后对比
+
+**重构前(main 分支)**:
+
+```
+pkg/channels/
+├── telegram.go # 每个 channel 直接放在 channels 包内
+├── discord.go
+├── slack.go
+├── manager.go # Manager 直接引用各 channel 类型
+├── ...
+```
+
+- Channel 实现全部在 `pkg/channels/` 包的顶层
+- Manager 通过 `switch` 或 `if-else` 链条直接构造各 channel
+- Peer、MessageID 等路由信息埋在 `Metadata map[string]string` 中
+- 消息发送没有速率限制和重试
+- 没有统一的媒体文件生命周期管理
+- 各 channel 各自启动 HTTP 服务器
+- 群聊触发过滤逻辑分散在各 channel 中
+
+**重构后(refactor/channel-system 分支)**:
+
+```
+pkg/channels/
+├── base.go # BaseChannel 共享抽象层
+├── interfaces.go # 可选能力接口(TypingCapable, MessageEditor, ReactionCapable, PlaceholderCapable, PlaceholderRecorder)
+├── README.md # 英文文档
+├── README.zh.md # 中文文档
+├── media.go # MediaSender 可选接口
+├── webhook.go # WebhookHandler, HealthChecker 可选接口
+├── errors.go # 错误哨兵值(ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed)
+├── errutil.go # 错误分类帮助函数
+├── registry.go # 工厂注册表(RegisterFactory / getFactory)
+├── manager.go # 统一编排:Worker 队列、速率限制、重试、Typing/Placeholder、共享 HTTP
+├── split.go # 长消息智能分割(保留代码块完整性)
+├── telegram/ # 每个 channel 独立子包
+│ ├── init.go # 工厂注册
+│ ├── telegram.go # 实现
+│ └── telegram_commands.go
+├── discord/
+│ ├── init.go
+│ └── discord.go
+├── slack/ line/ onebot/ dingtalk/ feishu/ wecom/ qq/ whatsapp/ whatsapp_native/ maixcam/ pico/
+│ └── ...
+
+pkg/bus/
+├── bus.go # MessageBus(缓冲区 64,安全关闭+排水)
+├── types.go # 结构化消息类型(Peer, SenderInfo, MediaPart, InboundMessage, OutboundMessage, OutboundMediaMessage)
+
+pkg/media/
+├── store.go # MediaStore 接口 + FileMediaStore 实现(两阶段释放,TTL 清理)
+
+pkg/identity/
+├── identity.go # 统一用户身份:规范 "platform:id" 格式 + 向后兼容匹配
+```
+
+### 1.2 消息流转全景图
+
+```
+┌────────────┐ InboundMessage ┌───────────┐ LLM + Tools ┌────────────┐
+│ Telegram │──┐ │ │ │ │
+│ Discord │──┤ PublishInbound() │ │ PublishOutbound() │ │
+│ Slack │──┼──────────────────────▶ │ MessageBus │ ◀─────────────────── │ AgentLoop │
+│ LINE │──┤ (buffered chan, 64) │ │ (buffered chan, 64) │ │
+│ ... │──┘ │ │ │ │
+└────────────┘ └─────┬─────┘ └────────────┘
+ │
+ SubscribeOutbound() │ SubscribeOutboundMedia()
+ ▼
+ ┌───────────────────┐
+ │ Manager │
+ │ ├── dispatchOutbound() 路由到 Worker 队列
+ │ ├── dispatchOutboundMedia()
+ │ ├── runWorker() 消息分割 + sendWithRetry()
+ │ ├── runMediaWorker() sendMediaWithRetry()
+ │ ├── preSend() 停止 Typing + 撤销 Reaction + 编辑 Placeholder
+ │ └── runTTLJanitor() 清理过期 Typing/Placeholder
+ └────────┬──────────┘
+ │
+ channel.Send() / SendMedia()
+ │
+ ▼
+ ┌────────────────┐
+ │ 各平台 API/SDK │
+ └────────────────┘
+```
+
+### 1.3 关键设计原则
+
+| 原则 | 说明 |
+|------|------|
+| **子包隔离** | 每个 channel 一个独立 Go 子包,依赖 `channels` 父包提供的 `BaseChannel` 和接口 |
+| **工厂注册** | 各子包通过 `init()` 自注册,Manager 通过名字查找工厂,消除 import 耦合 |
+| **能力发现** | 可选能力通过接口(`MediaSender`, `TypingCapable`, `ReactionCapable`, `PlaceholderCapable`, `MessageEditor`, `WebhookHandler`, `HealthChecker`)声明,Manager 运行时类型断言发现 |
+| **结构化消息** | Peer、MessageID、SenderInfo 从 Metadata 提升为 InboundMessage 的一等字段 |
+| **错误分类** | Channel 返回哨兵错误(`ErrRateLimit`, `ErrTemporary` 等),Manager 据此决定重试策略 |
+| **集中编排** | 速率限制、消息分割、重试、Typing/Reaction/Placeholder 全部由 Manager 和 BaseChannel 统一处理,Channel 只负责 Send |
+
+---
+
+## 第二部分:迁移指南——从 main 分支迁移到重构分支
+
+### 2.1 如果你有未合并的 Channel 修改
+
+#### 步骤 1:确认你修改了哪些文件
+
+在 main 分支上,Channel 文件直接位于 `pkg/channels/` 顶层,例如:
+- `pkg/channels/telegram.go`
+- `pkg/channels/discord.go`
+
+重构后,这些文件已被删除,代码移动到了对应子包:
+- `pkg/channels/telegram/telegram.go`
+- `pkg/channels/discord/discord.go`
+
+#### 步骤 2:理解结构变化映射
+
+| main 分支文件 | 重构分支位置 | 变化 |
+|---|---|---|
+| `pkg/channels/telegram.go` | `pkg/channels/telegram/telegram.go` + `init.go` | 包名从 `channels` 变为 `telegram` |
+| `pkg/channels/discord.go` | `pkg/channels/discord/discord.go` + `init.go` | 同上 |
+| `pkg/channels/manager.go` | `pkg/channels/manager.go` | 大幅重写 |
+| _(不存在)_ | `pkg/channels/base.go` | 新增共享抽象层 |
+| _(不存在)_ | `pkg/channels/registry.go` | 新增工厂注册表 |
+| _(不存在)_ | `pkg/channels/errors.go` + `errutil.go` | 新增错误分类体系 |
+| _(不存在)_ | `pkg/channels/interfaces.go` | 新增可选能力接口 |
+| _(不存在)_ | `pkg/channels/media.go` | 新增 MediaSender 接口 |
+| _(不存在)_ | `pkg/channels/webhook.go` | 新增 WebhookHandler/HealthChecker |
+| _(不存在)_ | `pkg/channels/whatsapp_native/` | 新增 WhatsApp 原生模式(whatsmeow) |
+| _(不存在)_ | `pkg/channels/split.go` | 新增消息分割(从 utils 迁入) |
+| _(不存在)_ | `pkg/bus/types.go` | 新增结构化消息类型 |
+| _(不存在)_ | `pkg/media/store.go` | 新增媒体文件生命周期管理 |
+| _(不存在)_ | `pkg/identity/identity.go` | 新增统一用户身份 |
+
+#### 步骤 3:迁移你的 Channel 代码
+
+以 Telegram 为例,主要改动项:
+
+**3a. 包声明和导入**
+
+```go
+// 旧代码(main 分支)
+package channels
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// 新代码(重构分支)
+package telegram
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels" // 引用父包
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity" // 新增
+ "github.com/sipeed/picoclaw/pkg/media" // 新增(如需媒体)
+)
+```
+
+**3b. 结构体嵌入 BaseChannel**
+
+```go
+// 旧代码:直接持有 bus、config 等字段
+type TelegramChannel struct {
+ bus *bus.MessageBus
+ config *config.Config
+ running bool
+ allowList []string
+ // ...
+}
+
+// 新代码:嵌入 BaseChannel,它提供 bus、running、allowList 等
+type TelegramChannel struct {
+ *channels.BaseChannel // 嵌入共享抽象
+ bot *telego.Bot
+ config *config.Config
+ // ... 只保留 channel 特有字段
+}
+```
+
+**3c. 构造函数**
+
+```go
+// 旧代码:直接赋值
+func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
+ return &TelegramChannel{
+ bus: bus,
+ config: cfg,
+ allowList: cfg.Channels.Telegram.AllowFrom,
+ // ...
+ }, nil
+}
+
+// 新代码:使用 NewBaseChannel + 功能选项
+func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
+ base := channels.NewBaseChannel(
+ "telegram", // 名称
+ cfg.Channels.Telegram, // 原始配置(any 类型)
+ bus, // 消息总线
+ cfg.Channels.Telegram.AllowFrom, // 允许列表
+ channels.WithMaxMessageLength(4096), // 平台消息长度上限
+ channels.WithGroupTrigger(cfg.Channels.Telegram.GroupTrigger), // 群聊触发配置
+ channels.WithReasoningChannelID(cfg.Channels.Telegram.ReasoningChannelID), // 思维链路由
+ )
+ return &TelegramChannel{
+ BaseChannel: base,
+ bot: bot,
+ config: cfg,
+ }, nil
+}
+```
+
+**3d. Start/Stop 生命周期**
+
+```go
+// 新代码:使用 SetRunning 原子操作
+func (c *TelegramChannel) Start(ctx context.Context) error {
+ // ... 初始化 bot、webhook 等
+ c.SetRunning(true) // 必须在就绪后调用
+ go bh.Start()
+ return nil
+}
+
+func (c *TelegramChannel) Stop(ctx context.Context) error {
+ c.SetRunning(false) // 必须在清理前调用
+ // ... 停止 bot handler、取消 context
+ return nil
+}
+```
+
+**3e. Send 方法的错误返回**
+
+```go
+// 旧代码:返回普通 error
+func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.running { return fmt.Errorf("not running") }
+ // ...
+ if err != nil { return err }
+}
+
+// 新代码:必须返回哨兵错误,供 Manager 判断重试策略
+func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning // ← Manager 不会重试
+ }
+ // ...
+ if err != nil {
+ // 使用 ClassifySendError 根据 HTTP 状态码包装错误
+ return channels.ClassifySendError(statusCode, err)
+ // 或手动包装:
+ // return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
+ // return fmt.Errorf("%w: %v", channels.ErrRateLimit, err)
+ // return fmt.Errorf("%w: %v", channels.ErrSendFailed, err)
+ }
+ return nil
+}
+```
+
+**3f. 消息接收(Inbound)**
+
+```go
+// 旧代码:直接构造 InboundMessage 并发布
+msg := bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: senderID,
+ ChatID: chatID,
+ Content: content,
+ Metadata: map[string]string{
+ "peer_kind": "group", // 路由信息埋在 metadata
+ "peer_id": chatID,
+ "message_id": msgID,
+ },
+}
+c.bus.PublishInbound(ctx, msg)
+
+// 新代码:使用 BaseChannel.HandleMessage,传入结构化字段
+sender := bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: strconv.FormatInt(from.ID, 10),
+ CanonicalID: identity.BuildCanonicalID("telegram", strconv.FormatInt(from.ID, 10)),
+ Username: from.Username,
+ DisplayName: from.FirstName,
+}
+
+peer := bus.Peer{
+ Kind: "group", // 或 "direct"
+ ID: chatID,
+}
+
+// HandleMessage 内部调用 IsAllowedSender 检查权限,构建 MediaScope,发布到 bus
+c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, sender)
+```
+
+**3g. 添加工厂注册(必需)**
+
+为你的 channel 创建 `init.go`:
+
+```go
+// pkg/channels/telegram/init.go
+package telegram
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewTelegramChannel(cfg, b)
+ })
+}
+```
+
+**3h. 在 Gateway 中导入子包**
+
+```go
+// cmd/picoclaw/internal/gateway/helpers.go
+import (
+ _ "github.com/sipeed/picoclaw/pkg/channels/telegram" // 触发 init() 注册
+ _ "github.com/sipeed/picoclaw/pkg/channels/discord"
+ _ "github.com/sipeed/picoclaw/pkg/channels/your_new_channel" // 新增
+)
+```
+
+#### 步骤 4:迁移 Bus 消息使用方式
+
+如果你的代码直接读取 `InboundMessage.Metadata` 中的路由字段:
+
+```go
+// 旧代码
+peerKind := msg.Metadata["peer_kind"]
+peerID := msg.Metadata["peer_id"]
+msgID := msg.Metadata["message_id"]
+
+// 新代码
+peerKind := msg.Peer.Kind // 一等字段
+peerID := msg.Peer.ID // 一等字段
+msgID := msg.MessageID // 一等字段
+sender := msg.Sender // bus.SenderInfo 结构体
+scope := msg.MediaScope // 媒体生命周期作用域
+```
+
+#### 步骤 5:迁移允许列表检查
+
+```go
+// 旧代码
+if !c.isAllowed(senderID) { return }
+
+// 新代码:优先使用结构化检查
+if !c.IsAllowedSender(sender) { return }
+// 或回退到字符串检查:
+if !c.IsAllowed(senderID) { return }
+```
+
+`BaseChannel.HandleMessage` 方法内部已经处理了这个逻辑,无需在 channel 中重复检查。
+
+### 2.2 如果你有 Manager 的修改
+
+Manager 已被完全重写。你的修改需要理解新架构:
+
+| 旧 Manager 职责 | 新 Manager 职责 |
+|---|---|
+| 直接构造 channel(switch/if-else) | 通过工厂注册表查找并构造 |
+| 直接调用 channel.Send | 通过 per-channel Worker 队列 + 速率限制 + 重试 |
+| 无消息分割 | 自动根据 MaxMessageLength 分割长消息 |
+| 各 channel 自建 HTTP 服务器 | 统一共享 HTTP 服务器 |
+| 无 Typing/Placeholder 管理 | 统一 preSend 处理 Typing 停止 + Reaction 撤销 + Placeholder 编辑;入站侧 BaseChannel.HandleMessage 自动编排 Typing/Reaction/Placeholder |
+| 无 TTL 清理 | runTTLJanitor 定期清理过期 Typing/Reaction/Placeholder 条目 |
+
+### 2.3 如果你有 Agent Loop 的修改
+
+Agent Loop 的主要变化:
+
+1. **MediaStore 注入**:`agentLoop.SetMediaStore(mediaStore)` — Agent 通过 MediaStore 解析工具产生的媒体引用
+2. **ChannelManager 注入**:`agentLoop.SetChannelManager(channelManager)` — Agent 可查询 channel 状态
+3. **OutboundMediaMessage**:Agent 现在通过 `bus.PublishOutboundMedia()` 发送媒体消息,而非嵌入文本回复
+4. **extractPeer**:路由使用 `msg.Peer` 结构化字段而非 Metadata 查找
+
+---
+
+## 第三部分:新 Channel 开发指南——从零实现一个新 Channel
+
+### 3.1 最小实现清单
+
+要添加一个新的聊天平台(例如 `matrix`),你需要:
+
+1. ✅ 创建子包目录 `pkg/channels/matrix/`
+2. ✅ 创建 `init.go` — 工厂注册
+3. ✅ 创建 `matrix.go` — Channel 实现
+4. ✅ 在 Gateway helpers 中添加 blank import
+5. ✅ 在 Manager.initChannels() 中添加配置检查
+6. ✅ 在 `pkg/config/` 中添加配置结构体
+
+### 3.2 完整模板
+
+#### `pkg/channels/matrix/init.go`
+
+```go
+package matrix
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("matrix", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewMatrixChannel(cfg, b)
+ })
+}
+```
+
+#### `pkg/channels/matrix/matrix.go`
+
+```go
+package matrix
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// MatrixChannel implements channels.Channel for the Matrix protocol.
+type MatrixChannel struct {
+ *channels.BaseChannel // 必须嵌入
+ config *config.Config
+ ctx context.Context
+ cancel context.CancelFunc
+ // ... Matrix SDK 客户端等
+}
+
+func NewMatrixChannel(cfg *config.Config, msgBus *bus.MessageBus) (*MatrixChannel, error) {
+ matrixCfg := cfg.Channels.Matrix // 假设配置中有此字段
+
+ base := channels.NewBaseChannel(
+ "matrix", // channel 名称(全局唯一)
+ matrixCfg, // 原始配置
+ msgBus, // 消息总线
+ matrixCfg.AllowFrom, // 允许列表
+ channels.WithMaxMessageLength(65536), // Matrix 消息长度限制
+ channels.WithGroupTrigger(matrixCfg.GroupTrigger),
+ channels.WithReasoningChannelID(matrixCfg.ReasoningChannelID), // 思维链路由(可选)
+ )
+
+ return &MatrixChannel{
+ BaseChannel: base,
+ config: cfg,
+ }, nil
+}
+
+// ========== 必须实现的 Channel 接口方法 ==========
+
+func (c *MatrixChannel) Start(ctx context.Context) error {
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // 1. 初始化 Matrix 客户端
+ // 2. 开始监听消息
+ // 3. 标记为运行中
+ c.SetRunning(true)
+
+ logger.InfoC("matrix", "Matrix channel started")
+ return nil
+}
+
+func (c *MatrixChannel) Stop(ctx context.Context) error {
+ c.SetRunning(false)
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ logger.InfoC("matrix", "Matrix channel stopped")
+ return nil
+}
+
+func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ // 1. 检查运行状态
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ // 2. 发送消息到 Matrix
+ err := c.sendToMatrix(ctx, msg.ChatID, msg.Content)
+ if err != nil {
+ // 3. 必须使用错误分类包装
+ // 如果你有 HTTP 状态码:
+ // return channels.ClassifySendError(statusCode, err)
+ // 如果是网络错误:
+ // return channels.ClassifyNetError(err)
+ // 如果需要手动分类:
+ return fmt.Errorf("%w: %v", channels.ErrTemporary, err)
+ }
+
+ return nil
+}
+
+// ========== 消息接收处理 ==========
+
+func (c *MatrixChannel) handleIncoming(roomID, senderID, displayName, content string, msgID string) {
+ // 1. 构造结构化发送者身份
+ sender := bus.SenderInfo{
+ Platform: "matrix",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("matrix", senderID),
+ Username: senderID,
+ DisplayName: displayName,
+ }
+
+ // 2. 确定 Peer 类型(直聊 vs 群聊)
+ peer := bus.Peer{
+ Kind: "group", // 或 "direct"
+ ID: roomID,
+ }
+
+ // 3. 群聊过滤(如适用)
+ isGroup := peer.Kind == "group"
+ if isGroup {
+ isMentioned := false // 根据平台特性检测 @提及
+ shouldRespond, cleanContent := c.ShouldRespondInGroup(isMentioned, content)
+ if !shouldRespond {
+ return
+ }
+ content = cleanContent
+ }
+
+ // 4. 处理媒体附件(如有)
+ var mediaRefs []string
+ store := c.GetMediaStore()
+ if store != nil {
+ // 下载附件到本地 → store.Store() → 获取 ref
+ // mediaRefs = append(mediaRefs, ref)
+ }
+
+ // 5. 调用 HandleMessage 发布到 bus
+ // HandleMessage 内部会:
+ // - 检查 IsAllowedSender/IsAllowed
+ // - 构建 MediaScope
+ // - 发布 InboundMessage
+ c.HandleMessage(
+ c.ctx,
+ peer,
+ msgID, // 平台消息 ID
+ senderID, // 原始发送者 ID
+ roomID, // 聊天/房间 ID
+ content, // 消息内容
+ mediaRefs, // 媒体引用列表
+ nil, // 额外 metadata(通常 nil)
+ sender, // SenderInfo(variadic 参数)
+ )
+}
+
+// ========== 内部方法 ==========
+
+func (c *MatrixChannel) sendToMatrix(ctx context.Context, roomID, content string) error {
+ // 实际的 Matrix SDK 调用
+ return nil
+}
+```
+
+### 3.3 可选能力接口
+
+根据平台能力,你的 Channel 可以选择性实现以下接口:
+
+#### MediaSender — 发送媒体附件
+
+```go
+// 如果平台支持发送图片/文件/音频/视频
+func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("matrix", "Failed to resolve media", map[string]any{
+ "ref": part.Ref, "error": err.Error(),
+ })
+ continue
+ }
+
+ // 根据 part.Type ("image"|"audio"|"video"|"file") 调用对应 API
+ switch part.Type {
+ case "image":
+ // 上传图片到 Matrix
+ default:
+ // 上传文件到 Matrix
+ }
+ }
+ return nil
+}
+```
+
+#### TypingCapable — Typing 指示器
+
+```go
+// 如果平台支持 "正在输入..." 提示
+func (c *MatrixChannel) StartTyping(ctx context.Context, chatID string) (stop func(), err error) {
+ // 调用 Matrix API 发送 typing 指示器
+ // 返回的 stop 函数必须是幂等的
+ stopped := false
+ return func() {
+ if !stopped {
+ stopped = true
+ // 调用 Matrix API 停止 typing
+ }
+ }, nil
+}
+```
+
+#### ReactionCapable — 消息反应指示器
+
+```go
+// 如果平台支持对入站消息添加 emoji 反应(如 Slack 的 👀、OneBot 的表情 289)
+func (c *MatrixChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error) {
+ // 调用 Matrix API 添加反应到消息
+ // 返回的 undo 函数移除反应,必须是幂等的
+ err = c.addReaction(chatID, messageID, "eyes")
+ if err != nil {
+ return func() {}, err
+ }
+ return func() {
+ c.removeReaction(chatID, messageID, "eyes")
+ }, nil
+}
+```
+
+#### MessageEditor — 消息编辑
+
+```go
+// 如果平台支持编辑已发送的消息(用于 Placeholder 替换)
+func (c *MatrixChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ // 调用 Matrix API 编辑消息
+ return nil
+}
+```
+
+#### PlaceholderCapable — 占位消息
+
+```go
+// 如果平台支持发送占位消息(如 "Thinking... 💭"),并且实现了 MessageEditor,
+// 则 Manager 的 preSend 会在出站时自动将占位消息编辑为最终回复。
+// SendPlaceholder 内部根据 PlaceholderConfig.Enabled 决定是否发送;
+// 返回 ("", nil) 表示跳过。
+func (c *MatrixChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ cfg := c.config.Channels.Matrix.Placeholder
+ if !cfg.Enabled {
+ return "", nil
+ }
+ text := cfg.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+ // 调用 Matrix API 发送占位消息
+ msg, err := c.sendText(ctx, chatID, text)
+ if err != nil {
+ return "", err
+ }
+ return msg.ID, nil
+}
+```
+
+#### WebhookHandler — HTTP Webhook 接收
+
+```go
+// 如果 channel 通过 webhook 接收消息(而非长轮询/WebSocket)
+func (c *MatrixChannel) WebhookPath() string {
+ return "/webhook/matrix" // 路径会被注册到共享 HTTP 服务器
+}
+
+func (c *MatrixChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // 处理 webhook 请求
+}
+```
+
+#### HealthChecker — 健康检查端点
+
+```go
+func (c *MatrixChannel) HealthPath() string {
+ return "/health/matrix"
+}
+
+func (c *MatrixChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ if c.IsRunning() {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("OK"))
+ } else {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }
+}
+```
+
+### 3.4 入站侧 Typing/Reaction/Placeholder 自动编排
+
+`BaseChannel.HandleMessage` 在发布入站消息**之前**,自动检测 channel 是否实现了 `TypingCapable`、`ReactionCapable` 和/或 `PlaceholderCapable`,并触发相应的指示器。三条管道完全独立,互不干扰:
+
+```go
+// BaseChannel.HandleMessage 内部自动执行(无需 channel 手动调用):
+if c.owner != nil && c.placeholderRecorder != nil {
+ // Typing — 独立管道
+ if tc, ok := c.owner.(TypingCapable); ok {
+ if stop, err := tc.StartTyping(ctx, chatID); err == nil {
+ c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
+ }
+ }
+ // Reaction — 独立管道
+ if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
+ if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
+ c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
+ }
+ }
+ // Placeholder — 独立管道
+ if pc, ok := c.owner.(PlaceholderCapable); ok {
+ if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
+ c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
+ }
+ }
+}
+```
+
+**这意味着**:
+- 实现 `TypingCapable` 的 channel(Telegram、Discord、LINE、Pico)无需在 `handleMessage` 中手动调用 `StartTyping` + `RecordTypingStop`
+- 实现 `ReactionCapable` 的 channel(Slack、OneBot)无需在 `handleMessage` 中手动调用 `AddReaction` + `RecordTypingStop`
+- 实现 `PlaceholderCapable` 的 channel(Telegram、Discord、Pico)无需在 `handleMessage` 中手动发送占位消息并调用 `RecordPlaceholder`
+- Channel 只需实现对应接口,`HandleMessage` 会自动完成编排
+- 不实现这些接口的 channel 不受影响(类型断言会失败,跳过)
+- `PlaceholderCapable` 的 `SendPlaceholder` 方法内部根据配置的 `PlaceholderConfig.Enabled` 决定是否发送;返回 `("", nil)` 时跳过注册
+
+**Owner 注入**:Manager 在 `initChannel` 中自动调用 `SetOwner(ch)` 将具体 channel 注入 BaseChannel,无需开发者手动设置。
+
+当 Agent 处理完消息后,Manager 的 `preSend` 会自动:
+1. 调用已记录的 `stop()` 停止 Typing
+2. 调用已记录的 `undo()` 撤销 Reaction
+3. 如果有 Placeholder,且 channel 实现了 `MessageEditor`,尝试编辑 Placeholder 为最终回复(跳过 Send)
+
+### 3.5 注册配置和 Gateway 接入
+
+#### 在 `pkg/config/config.go` 中添加配置
+
+```go
+type ChannelsConfig struct {
+ // ... 现有 channels
+ Matrix MatrixChannelConfig `json:"matrix"`
+}
+
+type MatrixChannelConfig struct {
+ Enabled bool `json:"enabled"`
+ HomeServer string `json:"home_server"`
+ Token string `json:"token"`
+ AllowFrom []string `json:"allow_from"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger"`
+ Placeholder PlaceholderConfig `json:"placeholder"`
+ ReasoningChannelID string `json:"reasoning_channel_id"`
+}
+```
+
+#### 在 Manager.initChannels() 中添加入口
+
+```go
+// pkg/channels/manager.go 的 initChannels() 方法中
+if m.config.Channels.Matrix.Enabled && m.config.Channels.Matrix.Token != "" {
+ m.initChannel("matrix", "Matrix")
+}
+```
+
+> **注意**:如果你的 channel 有多种模式(如 WhatsApp Bridge vs Native),需要在 initChannels 中根据配置分支:
+> ```go
+> if cfg.UseNative {
+> m.initChannel("whatsapp_native", "WhatsApp Native")
+> } else {
+> m.initChannel("whatsapp", "WhatsApp")
+> }
+> ```
+
+#### 在 Gateway 中添加 blank import
+
+```go
+// cmd/picoclaw/internal/gateway/helpers.go
+import (
+ _ "github.com/sipeed/picoclaw/pkg/channels/matrix"
+)
+```
+
+---
+
+## 第四部分:核心子系统详解
+
+### 4.1 MessageBus
+
+**文件**:`pkg/bus/bus.go`、`pkg/bus/types.go`
+
+```go
+type MessageBus struct {
+ inbound chan InboundMessage // 缓冲区 = 64
+ outbound chan OutboundMessage // 缓冲区 = 64
+ outboundMedia chan OutboundMediaMessage // 缓冲区 = 64
+ done chan struct{} // 关闭信号
+ closed atomic.Bool // 防止重复关闭
+}
+```
+
+**关键行为**:
+
+| 方法 | 行为 |
+|------|------|
+| `PublishInbound(ctx, msg)` | 检查 closed → 发送到 inbound channel → 阻塞/超时/关闭 |
+| `ConsumeInbound(ctx)` | 从 inbound 读取 → 阻塞/关闭/取消 |
+| `PublishOutbound(ctx, msg)` | 发送到 outbound channel |
+| `SubscribeOutbound(ctx)` | 从 outbound 读取(Manager dispatcher 调用) |
+| `PublishOutboundMedia(ctx, msg)` | 发送到 outboundMedia channel |
+| `SubscribeOutboundMedia(ctx)` | 从 outboundMedia 读取(Manager media dispatcher 调用) |
+| `Close()` | CAS 关闭 → close(done) → 排水所有 channel(**不关闭 channel 本身**,避免并发 send-on-closed panic) |
+
+**设计要点**:
+- 缓冲区从 16 增至 64,减少突发负载下的阻塞
+- `Close()` 不关闭底层 channel(只关闭 `done` 信号通道),因为可能有正在并发 `Publish` 的 goroutine
+- 排水循环确保 buffered 消息不被静默丢弃
+
+### 4.2 结构化消息类型
+
+**文件**:`pkg/bus/types.go`
+
+```go
+// 路由对等体
+type Peer struct {
+ Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
+ ID string `json:"id"`
+}
+
+// 发送者身份信息
+type SenderInfo struct {
+ Platform string `json:"platform,omitempty"` // "telegram", "discord", ...
+ PlatformID string `json:"platform_id,omitempty"` // 平台原始 ID
+ CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" 规范格式
+ Username string `json:"username,omitempty"`
+ DisplayName string `json:"display_name,omitempty"`
+}
+
+// 入站消息
+type InboundMessage struct {
+ Channel string // 来源 channel 名称
+ SenderID string // 发送者 ID(优先使用 CanonicalID)
+ Sender SenderInfo // 结构化发送者信息
+ ChatID string // 聊天/房间 ID
+ Content string // 消息文本
+ Media []string // 媒体引用列表(media://...)
+ Peer Peer // 路由对等体(一等字段)
+ MessageID string // 平台消息 ID(一等字段)
+ MediaScope string // 媒体生命周期作用域
+ SessionKey string // 会话键
+ Metadata map[string]string // 仅用于 channel 特有扩展
+}
+
+// 出站文本消息
+type OutboundMessage struct {
+ Channel string
+ ChatID string
+ Content string
+}
+
+// 出站媒体消息
+type OutboundMediaMessage struct {
+ Channel string
+ ChatID string
+ Parts []MediaPart
+}
+
+// 媒体片段
+type MediaPart struct {
+ Type string // "image" | "audio" | "video" | "file"
+ Ref string // "media://uuid"
+ Caption string
+ Filename string
+ ContentType string
+}
+```
+
+### 4.3 BaseChannel
+
+**文件**:`pkg/channels/base.go`
+
+BaseChannel 是所有 channel 的共享抽象层,提供以下能力:
+
+| 方法/特性 | 说明 |
+|---|---|
+| `Name() string` | Channel 名称 |
+| `IsRunning() bool` | 原子读取运行状态 |
+| `SetRunning(bool)` | 原子设置运行状态 |
+| `MaxMessageLength() int` | 消息长度限制(rune 计数),0 = 无限制 |
+| `ReasoningChannelID() string` | 思维链路由目标 channel ID(空 = 不路由) |
+| `IsAllowed(senderID string) bool` | 旧格式允许列表检查(支持 `"id\|username"` 和 `"@username"` 格式) |
+| `IsAllowedSender(sender SenderInfo) bool` | 新格式允许列表检查(委托给 `identity.MatchAllowed`) |
+| `ShouldRespondInGroup(isMentioned, content) (bool, string)` | 统一群聊触发过滤逻辑 |
+| `HandleMessage(...)` | 统一入站消息处理:权限检查 → 构建 MediaScope → 自动触发 Typing/Reaction/Placeholder → 发布到 Bus |
+| `SetMediaStore(s) / GetMediaStore()` | Manager 注入的媒体存储 |
+| `SetPlaceholderRecorder(r) / GetPlaceholderRecorder()` | Manager 注入的占位符记录器 |
+| `SetOwner(ch) ` | Manager 注入的具体 channel 引用(用于 HandleMessage 内部的 Typing/Reaction/Placeholder 类型断言) |
+
+**功能选项**:
+
+```go
+channels.WithMaxMessageLength(4096) // 设置平台消息长度限制
+channels.WithGroupTrigger(groupTriggerCfg) // 设置群聊触发配置
+channels.WithReasoningChannelID(id) // 设置思维链路由目标 channel
+```
+
+### 4.4 工厂注册表
+
+**文件**:`pkg/channels/registry.go`
+
+```go
+type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
+
+func RegisterFactory(name string, f ChannelFactory) // 子包 init() 中调用
+func getFactory(name string) (ChannelFactory, bool) // Manager 内部调用
+```
+
+工厂注册表使用 `sync.RWMutex` 保护,在 `init()` 阶段注册(进程启动时完成)。Manager 在 `initChannel()` 中通过名字查找工厂并调用它。
+
+### 4.5 错误分类与重试
+
+**文件**:`pkg/channels/errors.go`、`pkg/channels/errutil.go`
+
+#### 哨兵错误
+
+```go
+var (
+ ErrNotRunning = errors.New("channel not running") // 永久:不重试
+ ErrRateLimit = errors.New("rate limited") // 固定延迟:1s 后重试
+ ErrTemporary = errors.New("temporary failure") // 指数退避:500ms * 2^attempt,最大 8s
+ ErrSendFailed = errors.New("send failed") // 永久:不重试
+)
+```
+
+#### 错误分类帮助函数
+
+```go
+// 根据 HTTP 状态码自动分类
+func ClassifySendError(statusCode int, rawErr error) error {
+ // 429 → ErrRateLimit
+ // 5xx → ErrTemporary
+ // 4xx → ErrSendFailed
+}
+
+// 网络错误统一包装为临时错误
+func ClassifyNetError(err error) error {
+ // → ErrTemporary
+}
+```
+
+#### Manager 重试策略(`sendWithRetry`)
+
+```
+最大重试次数: 3
+速率限制延迟: 1 秒
+基础退避: 500 毫秒
+最大退避: 8 秒
+
+重试逻辑:
+ ErrNotRunning → 立即失败,不重试
+ ErrSendFailed → 立即失败,不重试
+ ErrRateLimit → 等待 1s → 重试
+ ErrTemporary → 等待 500ms * 2^attempt(最大 8s) → 重试
+ 其他未知错误 → 等待 500ms * 2^attempt(最大 8s) → 重试
+```
+
+### 4.6 Manager 编排
+
+**文件**:`pkg/channels/manager.go`
+
+#### Per-channel Worker 架构
+
+```go
+type channelWorker struct {
+ ch Channel // channel 实例
+ queue chan bus.OutboundMessage // 出站文本队列(缓冲 16)
+ mediaQueue chan bus.OutboundMediaMessage // 出站媒体队列(缓冲 16)
+ done chan struct{} // 文本 worker 完成信号
+ mediaDone chan struct{} // 媒体 worker 完成信号
+ limiter *rate.Limiter // per-channel 速率限制器
+}
+```
+
+#### Per-channel 速率限制配置
+
+```go
+var channelRateConfig = map[string]float64{
+ "telegram": 20, // 20 msg/s
+ "discord": 1, // 1 msg/s
+ "slack": 1, // 1 msg/s
+ "line": 10, // 10 msg/s
+}
+// 默认: 10 msg/s
+// burst = max(1, ceil(rate/2))
+```
+
+#### 生命周期管理
+
+```
+StartAll:
+ 1. 遍历已注册 channels → channel.Start(ctx)
+ 2. 为每个启动成功的 channel 创建 channelWorker
+ 3. 启动 goroutines:
+ - runWorker (per-channel 出站文本)
+ - runMediaWorker (per-channel 出站媒体)
+ - dispatchOutbound (从 bus 路由到 worker 队列)
+ - dispatchOutboundMedia (从 bus 路由到 media worker 队列)
+ - runTTLJanitor (每 10s 清理过期 typing/reaction/placeholder)
+ 4. 启动共享 HTTP 服务器(如已配置)
+
+StopAll:
+ 1. 关闭共享 HTTP 服务器(5s 超时)
+ 2. 取消 dispatcher context
+ 3. 关闭 text worker 队列 → 等待排水完成
+ 4. 关闭 media worker 队列 → 等待排水完成
+ 5. 停止每个 channel(channel.Stop)
+```
+
+#### Typing/Reaction/Placeholder 管理
+
+```go
+// Manager 实现 PlaceholderRecorder 接口
+func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string)
+func (m *Manager) RecordTypingStop(channel, chatID string, stop func())
+func (m *Manager) RecordReactionUndo(channel, chatID string, undo func())
+
+// 入站侧:BaseChannel.HandleMessage 自动编排
+// BaseChannel.HandleMessage 在 PublishInbound 之前,通过 owner 类型断言自动触发:
+// - TypingCapable.StartTyping → RecordTypingStop
+// - ReactionCapable.ReactToMessage → RecordReactionUndo
+// - PlaceholderCapable.SendPlaceholder → RecordPlaceholder
+// 三者独立,互不干扰。Channel 无需手动调用。
+
+// 出站侧:发送前处理
+func (m *Manager) preSend(ctx, name, msg, ch) bool {
+ key := name + ":" + msg.ChatID
+ // 1. 停止 Typing(调用存储的 stop 函数)
+ // 2. 撤销 Reaction(调用存储的 undo 函数)
+ // 3. 尝试编辑 Placeholder(如果 channel 实现了 MessageEditor)
+ // 成功 → return true(跳过 Send)
+ // 失败 → return false(继续 Send)
+}
+```
+
+Manager 存储完全分离,三条管道互不干扰:
+
+```go
+Manager {
+ typingStops sync.Map // "channel:chatID" → typingEntry ← 管 TypingCapable
+ reactionUndos sync.Map // "channel:chatID" → reactionEntry ← 管 ReactionCapable
+ placeholders sync.Map // "channel:chatID" → placeholderEntry
+}
+```
+
+TTL 清理:
+- Typing 停止函数:5 分钟 TTL(到期后自动调用 stop 并删除)
+- Reaction 撤销函数:5 分钟 TTL(到期后自动调用 undo 并删除)
+- Placeholder ID:10 分钟 TTL(到期后删除)
+- 清理间隔:10 秒
+
+### 4.7 消息分割
+
+**文件**:`pkg/channels/split.go`
+
+`SplitMessage(content string, maxLen int) []string`
+
+智能分割策略:
+1. 计算有效分割点 = maxLen - 10% 缓冲区(为代码块闭合留空间)
+2. 优先在换行符处分割
+3. 其次在空格/制表符处分割
+4. 检测未闭合的代码块(` ``` `)
+5. 如果代码块未闭合:
+ - 尝试扩展到 maxLen 以包含闭合围栏
+ - 如果代码块太长,注入闭合/重开围栏(`\n```\n` + header)
+ - 最后手段:在代码块开始前分割
+
+### 4.8 MediaStore
+
+**文件**:`pkg/media/store.go`
+
+```go
+type MediaStore interface {
+ Store(localPath string, meta MediaMeta, scope string) (ref string, err error)
+ Resolve(ref string) (localPath string, err error)
+ ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error)
+ ReleaseAll(scope string) error
+}
+```
+
+**FileMediaStore 实现**:
+- 纯内存映射,不复制/移动文件
+- 引用格式:`media://`
+- Scope 格式:`channel:chatID:messageID`(由 `BuildMediaScope` 生成)
+- **两阶段操作**:
+ - Phase 1(持锁):从 map 中收集并删除条目
+ - Phase 2(无锁):从磁盘删除文件
+ - 目的:最小化锁争用
+- **TTL 清理**:`NewFileMediaStoreWithCleanup` → `Start()` 启动后台清理协程
+- 清理间隔和最大存活时间由配置控制
+
+### 4.9 Identity
+
+**文件**:`pkg/identity/identity.go`
+
+```go
+// 构建规范 ID
+func BuildCanonicalID(platform, platformID string) string
+// → "telegram:123456"
+
+// 解析规范 ID
+func ParseCanonicalID(canonical string) (platform, id string, ok bool)
+
+// 匹配允许列表(向后兼容)
+func MatchAllowed(sender bus.SenderInfo, allowed string) bool
+```
+
+`MatchAllowed` 支持的允许列表格式:
+| 格式 | 匹配方式 |
+|------|----------|
+| `"123456"` | 匹配 `sender.PlatformID` |
+| `"@alice"` | 匹配 `sender.Username` |
+| `"123456\|alice"` | 匹配 PlatformID 或 Username(旧格式兼容) |
+| `"telegram:123456"` | 精确匹配 `sender.CanonicalID`(新格式) |
+
+### 4.10 共享 HTTP 服务器
+
+**文件**:`pkg/channels/manager.go` 的 `SetupHTTPServer`
+
+Manager 创建单一 `http.Server`,自动发现和注册:
+- 实现 `WebhookHandler` 的 channel → 挂载到 `wh.WebhookPath()`
+- 实现 `HealthChecker` 的 channel → 挂载到 `hc.HealthPath()`
+- Health 全局端点由 `health.Server.RegisterOnMux` 注册
+
+超时配置:ReadTimeout = 30s, WriteTimeout = 30s
+
+---
+
+## 第五部分:关键设计决策与约定
+
+### 5.1 必须遵守的约定
+
+1. **错误分类是合约**:Channel 的 `Send` 方法**必须**返回哨兵错误(或包装它们)。Manager 的重试策略完全依赖 `errors.Is` 检查。如果返回未分类的错误,Manager 会按"未知错误"处理(指数退避重试)。
+
+2. **SetRunning 是生命周期信号**:`Start` 成功后**必须**调用 `c.SetRunning(true)`,`Stop` 开始时**必须**调用 `c.SetRunning(false)`。`Send` 中**必须**检查 `c.IsRunning()` 并返回 `ErrNotRunning`。
+
+3. **HandleMessage 包含权限检查**:不要在调用 `HandleMessage` 之前自行进行权限检查(除非你需要在检查前做平台特定的预处理)。`HandleMessage` 内部已经调用 `IsAllowedSender`/`IsAllowed`。
+
+4. **消息分割由 Manager 处理**:Channel 的 `Send` 方法不需要处理长消息分割。Manager 会在调用 `Send` 之前根据 `MaxMessageLength()` 自动分割。Channel 只需通过 `WithMaxMessageLength` 声明限制。
+
+5. **Typing/Reaction/Placeholder 由 BaseChannel + Manager 自动处理**:Channel 的 `Send` 方法不需要管理 Typing 停止、Reaction 撤销或 Placeholder 编辑。`BaseChannel.HandleMessage` 在入站侧自动触发 `TypingCapable`、`ReactionCapable` 和 `PlaceholderCapable`(通过 `owner` 类型断言);Manager 的 `preSend` 在出站侧自动停止 Typing、撤销 Reaction、编辑 Placeholder。Channel 只需实现对应接口即可。
+
+6. **工厂注册在 init() 中**:每个子包必须有 `init.go` 文件调用 `channels.RegisterFactory`。Gateway 必须通过 blank import(`_ "pkg/channels/xxx"`)触发注册。
+
+### 5.2 Metadata 字段使用约定
+
+**不要再把以下信息放入 Metadata**:
+- `peer_kind` / `peer_id` → 使用 `InboundMessage.Peer`
+- `message_id` → 使用 `InboundMessage.MessageID`
+- `sender_platform` / `sender_username` → 使用 `InboundMessage.Sender`
+
+**Metadata 仅用于**:
+- Channel 特有的扩展信息(如 Telegram 的 `reply_to_message_id`)
+- 不适合放入结构化字段的临时信息
+
+### 5.3 并发安全约定
+
+- `BaseChannel.running`:使用 `atomic.Bool`,线程安全
+- `Manager.channels` / `Manager.workers`:使用 `sync.RWMutex` 保护
+- `Manager.placeholders` / `Manager.typingStops` / `Manager.reactionUndos`:使用 `sync.Map`
+- `MessageBus.closed`:使用 `atomic.Bool`
+- `FileMediaStore`:使用 `sync.RWMutex`,两阶段操作减少持锁时间
+- Channel Worker queue:Go channel,天然并发安全
+
+### 5.4 测试约定
+
+已有测试文件:
+- `pkg/channels/base_test.go` — BaseChannel 单元测试
+- `pkg/channels/manager_test.go` — Manager 单元测试
+- `pkg/channels/split_test.go` — 消息分割测试
+- `pkg/channels/errors_test.go` — 错误类型测试
+- `pkg/channels/errutil_test.go` — 错误分类测试
+
+为新 channel 添加测试时:
+```bash
+go test ./pkg/channels/matrix/ -v # 子包测试
+go test ./pkg/channels/ -run TestSpecific -v # 框架测试
+make test # 全量测试
+```
+
+---
+
+## 附录:完整文件清单与接口速查表
+
+### A.1 框架层文件
+
+| 文件 | 职责 |
+|------|------|
+| `pkg/channels/base.go` | BaseChannel 结构体、Channel 接口、MessageLengthProvider、BaseChannelOption、HandleMessage |
+| `pkg/channels/interfaces.go` | TypingCapable、MessageEditor、ReactionCapable、PlaceholderCapable、PlaceholderRecorder 接口 |
+| `pkg/channels/media.go` | MediaSender 接口 |
+| `pkg/channels/webhook.go` | WebhookHandler、HealthChecker 接口 |
+| `pkg/channels/errors.go` | ErrNotRunning、ErrRateLimit、ErrTemporary、ErrSendFailed 哨兵 |
+| `pkg/channels/errutil.go` | ClassifySendError、ClassifyNetError 帮助函数 |
+| `pkg/channels/registry.go` | RegisterFactory、getFactory 工厂注册表 |
+| `pkg/channels/manager.go` | Manager:Worker 队列、速率限制、重试、preSend、共享 HTTP、TTL janitor |
+| `pkg/channels/split.go` | SplitMessage 长消息分割 |
+| `pkg/bus/bus.go` | MessageBus 实现 |
+| `pkg/bus/types.go` | Peer、SenderInfo、InboundMessage、OutboundMessage、OutboundMediaMessage、MediaPart |
+| `pkg/media/store.go` | MediaStore 接口、FileMediaStore 实现 |
+| `pkg/identity/identity.go` | BuildCanonicalID、ParseCanonicalID、MatchAllowed |
+
+### A.2 Channel 子包
+
+| 子包 | 注册名 | 可选接口 |
+|------|--------|----------|
+| `pkg/channels/telegram/` | `"telegram"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/discord/` | `"discord"` | TypingCapable, PlaceholderCapable, MessageEditor, MediaSender |
+| `pkg/channels/slack/` | `"slack"` | ReactionCapable, MediaSender |
+| `pkg/channels/line/` | `"line"` | TypingCapable, MediaSender, WebhookHandler |
+| `pkg/channels/onebot/` | `"onebot"` | ReactionCapable, MediaSender |
+| `pkg/channels/dingtalk/` | `"dingtalk"` | — |
+| `pkg/channels/feishu/` | `"feishu"` | — (架构特定 build tags: `feishu_32.go` / `feishu_64.go`) |
+| `pkg/channels/wecom/` | `"wecom"` | WebhookHandler, HealthChecker |
+| `pkg/channels/wecom/` | `"wecom_app"` | MediaSender, WebhookHandler, HealthChecker |
+| `pkg/channels/qq/` | `"qq"` | — |
+| `pkg/channels/whatsapp/` | `"whatsapp"` | — (Bridge 模式) |
+| `pkg/channels/whatsapp_native/` | `"whatsapp_native"` | — (原生 whatsmeow 模式) |
+| `pkg/channels/maixcam/` | `"maixcam"` | — |
+| `pkg/channels/pico/` | `"pico"` | TypingCapable, PlaceholderCapable, MessageEditor, WebhookHandler |
+
+### A.3 接口速查表
+
+```go
+// ===== 必须实现 =====
+type Channel interface {
+ Name() string
+ Start(ctx context.Context) error
+ Stop(ctx context.Context) error
+ Send(ctx context.Context, msg bus.OutboundMessage) error
+ IsRunning() bool
+ IsAllowed(senderID string) bool
+ IsAllowedSender(sender bus.SenderInfo) bool
+ ReasoningChannelID() string
+}
+
+// ===== 可选实现 =====
+type MediaSender interface {
+ SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
+}
+
+type TypingCapable interface {
+ StartTyping(ctx context.Context, chatID string) (stop func(), err error)
+}
+
+type ReactionCapable interface {
+ ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error)
+}
+
+type PlaceholderCapable interface {
+ SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error)
+}
+
+type MessageEditor interface {
+ EditMessage(ctx context.Context, chatID, messageID, content string) error
+}
+
+type WebhookHandler interface {
+ WebhookPath() string
+ http.Handler
+}
+
+type HealthChecker interface {
+ HealthPath() string
+ HealthHandler(w http.ResponseWriter, r *http.Request)
+}
+
+type MessageLengthProvider interface {
+ MaxMessageLength() int
+}
+
+// ===== 由 Manager 注入 =====
+type PlaceholderRecorder interface {
+ RecordPlaceholder(channel, chatID, placeholderID string)
+ RecordTypingStop(channel, chatID string, stop func())
+ RecordReactionUndo(channel, chatID string, undo func())
+}
+```
+
+### A.4 Gateway 启动序列(完整引导流程)
+
+```go
+// 1. 创建核心组件
+msgBus := bus.NewMessageBus()
+provider := providers.CreateProvider(cfg)
+agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
+
+// 2. 创建媒体存储(带 TTL 清理)
+mediaStore := media.NewFileMediaStoreWithCleanup(cleanerConfig)
+mediaStore.Start()
+
+// 3. 创建 Channel Manager(触发 initChannels → 工厂查找 → 构造 → 注入 MediaStore/PlaceholderRecorder/Owner)
+channelManager := channels.NewManager(cfg, msgBus, mediaStore)
+
+// 4. 注入引用
+agentLoop.SetChannelManager(channelManager)
+agentLoop.SetMediaStore(mediaStore)
+
+// 5. 配置共享 HTTP 服务器
+channelManager.SetupHTTPServer(addr, healthServer)
+
+// 6. 启动
+channelManager.StartAll(ctx) // 启动 channels + workers + dispatchers + HTTP server
+go agentLoop.Run(ctx) // 启动 Agent 消息循环
+
+// 7. 关闭(信号触发)
+cancel() // 取消 context
+msgBus.Close() // 信号关闭 + 排水
+channelManager.StopAll(shutdownCtx) // 停止 HTTP + workers + channels
+mediaStore.Stop() // 停止 TTL 清理
+agentLoop.Stop() // 停止 Agent
+```
+
+### A.5 Per-channel 速率限制参考
+
+| Channel | 速率 (msg/s) | Burst |
+|---------|-------------|-------|
+| telegram | 20 | 10 |
+| discord | 1 | 1 |
+| slack | 1 | 1 |
+| line | 10 | 5 |
+| _其他_ | 10 (默认) | 5 |
+
+### A.6 已知限制和注意事项
+
+1. **媒体清理暂时禁用**:Agent loop 中的 `ReleaseAll` 调用被注释掉了(`refactor(loop): disable media cleanup to prevent premature file deletion`),因为会话边界尚未明确定义。TTL 清理仍然有效。
+
+2. **Feishu 架构特定编译**:Feishu channel 使用 build tags 区分 32 位和 64 位架构(`feishu_32.go` / `feishu_64.go`)。Feishu 使用 SDK 的 WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
+
+3. **WeCom 有两个工厂**:`"wecom"`(Bot 模式,纯 webhook)和 `"wecom_app"`(应用模式,支持 MediaSender)分别注册。两者都实现了 `WebhookHandler` 和 `HealthChecker`。
+
+4. **Pico Protocol**:`pkg/channels/pico/` 实现了一个自定义的 PicoClaw 原生协议 channel,通过 WebSocket webhook (`/pico/ws`) 接收消息。
+
+5. **WhatsApp 有两种模式**:`"whatsapp"`(Bridge 模式,通过外部 bridge URL 通信)和 `"whatsapp_native"`(原生 whatsmeow 模式,直接连接 WhatsApp)。Manager 根据 `WhatsAppConfig.UseNative` 决定初始化哪个。
+
+6. **DingTalk 使用 Stream 模式**:DingTalk 使用 SDK 的 Stream/WebSocket 模式(非 HTTP webhook),因此不实现 `WebhookHandler`。
+
+7. **PlaceholderConfig 的配置与实现**:`PlaceholderConfig` 出现在 6 个 channel config 中(Telegram、Discord、Slack、LINE、OneBot、Pico),但只有实现了 `PlaceholderCapable` + `MessageEditor` 的 channel(Telegram、Discord、Pico)能真正使用占位消息编辑功能。其余 channel 的 `PlaceholderConfig` 为预留字段。
+
+8. **ReasoningChannelID**:大多数 channel config 都包含 `reasoning_channel_id` 字段,用于将 LLM 的思维链(reasoning/thinking)路由到指定 channel(WhatsApp、Telegram、Feishu、Discord、MaixCam、QQ、DingTalk、Slack、LINE、OneBot、WeCom、WeComApp)。注意:`PicoConfig` 目前不包含该字段。`BaseChannel` 通过 `WithReasoningChannelID` 选项和 `ReasoningChannelID()` 方法暴露此配置。
\ No newline at end of file
diff --git a/pkg/channels/base.go b/pkg/channels/base.go
index cd6419ebb..063a66523 100644
--- a/pkg/channels/base.go
+++ b/pkg/channels/base.go
@@ -2,11 +2,44 @@ package channels
import (
"context"
+ "crypto/rand"
+ "encoding/binary"
+ "encoding/hex"
+ "strconv"
"strings"
+ "sync/atomic"
+ "time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
)
+var (
+ uniqueIDCounter uint64
+ uniqueIDPrefix string
+)
+
+func init() {
+ // One-time read from crypto/rand for a unique prefix (single syscall).
+ var b [8]byte
+ if _, err := rand.Read(b[:]); err != nil {
+ // fallback to time-based prefix
+ binary.BigEndian.PutUint64(b[:], uint64(time.Now().UnixNano()))
+ }
+ uniqueIDPrefix = hex.EncodeToString(b[:])
+}
+
+// uniqueID generates a process-unique ID using a random prefix and an atomic counter.
+// This ID is intended for internal correlation (e.g. media scope keys) and is NOT
+// cryptographically secure — it must not be used in contexts where unpredictability matters.
+func uniqueID() string {
+ n := atomic.AddUint64(&uniqueIDCounter, 1)
+ return uniqueIDPrefix + strconv.FormatUint(n, 16)
+}
+
type Channel interface {
Name() string
Start(ctx context.Context) error
@@ -14,32 +47,126 @@ type Channel interface {
Send(ctx context.Context, msg bus.OutboundMessage) error
IsRunning() bool
IsAllowed(senderID string) bool
+ IsAllowedSender(sender bus.SenderInfo) bool
+ ReasoningChannelID() string
+}
+
+// BaseChannelOption is a functional option for configuring a BaseChannel.
+type BaseChannelOption func(*BaseChannel)
+
+// WithMaxMessageLength sets the maximum message length (in runes) for a channel.
+// Messages exceeding this limit will be automatically split by the Manager.
+// A value of 0 means no limit.
+func WithMaxMessageLength(n int) BaseChannelOption {
+ return func(c *BaseChannel) { c.maxMessageLength = n }
+}
+
+// WithGroupTrigger sets the group trigger configuration for a channel.
+func WithGroupTrigger(gt config.GroupTriggerConfig) BaseChannelOption {
+ return func(c *BaseChannel) { c.groupTrigger = gt }
+}
+
+// WithReasoningChannelID sets the reasoning channel ID where thoughts should be sent.
+func WithReasoningChannelID(id string) BaseChannelOption {
+ return func(c *BaseChannel) { c.reasoningChannelID = id }
+}
+
+// MessageLengthProvider is an opt-in interface that channels implement
+// to advertise their maximum message length. The Manager uses this via
+// type assertion to decide whether to split outbound messages.
+type MessageLengthProvider interface {
+ MaxMessageLength() int
}
type BaseChannel struct {
- config any
- bus *bus.MessageBus
- running bool
- name string
- allowList []string
+ config any
+ bus *bus.MessageBus
+ running atomic.Bool
+ name string
+ allowList []string
+ maxMessageLength int
+ groupTrigger config.GroupTriggerConfig
+ mediaStore media.MediaStore
+ placeholderRecorder PlaceholderRecorder
+ owner Channel // the concrete channel that embeds this BaseChannel
+ reasoningChannelID string
}
-func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel {
- return &BaseChannel{
+func NewBaseChannel(
+ name string,
+ config any,
+ bus *bus.MessageBus,
+ allowList []string,
+ opts ...BaseChannelOption,
+) *BaseChannel {
+ bc := &BaseChannel{
config: config,
bus: bus,
name: name,
allowList: allowList,
- running: false,
}
+ for _, opt := range opts {
+ opt(bc)
+ }
+ return bc
+}
+
+// MaxMessageLength returns the maximum message length (in runes) for this channel.
+// A value of 0 means no limit.
+func (c *BaseChannel) MaxMessageLength() int {
+ return c.maxMessageLength
+}
+
+// ShouldRespondInGroup determines whether the bot should respond in a group chat.
+// Each channel is responsible for:
+// 1. Detecting isMentioned (platform-specific)
+// 2. Stripping bot mention from content (platform-specific)
+// 3. Calling this method to get the group response decision
+//
+// Logic:
+// - If isMentioned → always respond
+// - If mention_only configured and not mentioned → ignore
+// - If prefixes configured → respond if content starts with any prefix (strip it)
+// - If prefixes configured but no match and not mentioned → ignore
+// - Otherwise (no group_trigger configured) → respond to all (permissive default)
+func (c *BaseChannel) ShouldRespondInGroup(isMentioned bool, content string) (bool, string) {
+ gt := c.groupTrigger
+
+ // Mentioned → always respond
+ if isMentioned {
+ return true, strings.TrimSpace(content)
+ }
+
+ // mention_only → require mention
+ if gt.MentionOnly {
+ return false, content
+ }
+
+ // Prefix matching
+ if len(gt.Prefixes) > 0 {
+ for _, prefix := range gt.Prefixes {
+ if prefix != "" && strings.HasPrefix(content, prefix) {
+ return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
+ }
+ }
+ // Prefixes configured but none matched and not mentioned → ignore
+ return false, content
+ }
+
+ // No group_trigger configured → permissive (respond to all)
+ return true, strings.TrimSpace(content)
}
func (c *BaseChannel) Name() string {
return c.name
}
+func (c *BaseChannel) ReasoningChannelID() string {
+ return c.reasoningChannelID
+}
+
func (c *BaseChannel) IsRunning() bool {
- return c.running
+ return c.running.Load()
}
func (c *BaseChannel) IsAllowed(senderID string) bool {
@@ -81,23 +208,130 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
return false
}
-func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []string, metadata map[string]string) {
- if !c.IsAllowed(senderID) {
- return
+// IsAllowedSender checks whether a structured SenderInfo is permitted by the allow-list.
+// It delegates to identity.MatchAllowed for each entry, providing unified matching
+// across all legacy formats and the new canonical "platform:id" format.
+func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
+ if len(c.allowList) == 0 {
+ return true
}
+ for _, allowed := range c.allowList {
+ if identity.MatchAllowed(sender, allowed) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (c *BaseChannel) HandleMessage(
+ ctx context.Context,
+ peer bus.Peer,
+ messageID, senderID, chatID, content string,
+ media []string,
+ metadata map[string]string,
+ senderOpts ...bus.SenderInfo,
+) {
+ // Use SenderInfo-based allow check when available, else fall back to string
+ var sender bus.SenderInfo
+ if len(senderOpts) > 0 {
+ sender = senderOpts[0]
+ }
+ if sender.CanonicalID != "" || sender.PlatformID != "" {
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+ } else {
+ if !c.IsAllowed(senderID) {
+ return
+ }
+ }
+
+ // Set SenderID to canonical if available, otherwise keep the raw senderID
+ resolvedSenderID := senderID
+ if sender.CanonicalID != "" {
+ resolvedSenderID = sender.CanonicalID
+ }
+
+ scope := BuildMediaScope(c.name, chatID, messageID)
+
msg := bus.InboundMessage{
- Channel: c.name,
- SenderID: senderID,
- ChatID: chatID,
- Content: content,
- Media: media,
- Metadata: metadata,
+ Channel: c.name,
+ SenderID: resolvedSenderID,
+ Sender: sender,
+ ChatID: chatID,
+ Content: content,
+ Media: media,
+ Peer: peer,
+ MessageID: messageID,
+ MediaScope: scope,
+ Metadata: metadata,
}
- c.bus.PublishInbound(msg)
+ // Auto-trigger typing indicator, message reaction, and placeholder before publishing.
+ // Each capability is independent — all three may fire for the same message.
+ if c.owner != nil && c.placeholderRecorder != nil {
+ // Typing — independent pipeline
+ if tc, ok := c.owner.(TypingCapable); ok {
+ if stop, err := tc.StartTyping(ctx, chatID); err == nil {
+ c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop)
+ }
+ }
+ // Reaction — independent pipeline
+ if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" {
+ if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil {
+ c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo)
+ }
+ }
+ // Placeholder — independent pipeline
+ if pc, ok := c.owner.(PlaceholderCapable); ok {
+ if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" {
+ c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID)
+ }
+ }
+ }
+
+ if err := c.bus.PublishInbound(ctx, msg); err != nil {
+ logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{
+ "channel": c.name,
+ "chat_id": chatID,
+ "error": err.Error(),
+ })
+ }
}
-func (c *BaseChannel) setRunning(running bool) {
- c.running = running
+func (c *BaseChannel) SetRunning(running bool) {
+ c.running.Store(running)
+}
+
+// SetMediaStore injects a MediaStore into the channel.
+func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s }
+
+// GetMediaStore returns the injected MediaStore (may be nil).
+func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore }
+
+// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel.
+func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) {
+ c.placeholderRecorder = r
+}
+
+// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil).
+func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder {
+ return c.placeholderRecorder
+}
+
+// SetOwner injects the concrete channel that embeds this BaseChannel.
+// This allows HandleMessage to auto-trigger TypingCapable / ReactionCapable / PlaceholderCapable.
+func (c *BaseChannel) SetOwner(ch Channel) {
+ c.owner = ch
+}
+
+// BuildMediaScope constructs a scope key for media lifecycle tracking.
+func BuildMediaScope(channel, chatID, messageID string) string {
+ id := messageID
+ if id == "" {
+ id = uniqueID()
+ }
+ return channel + ":" + chatID + ":" + id
}
diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go
index 78c6d1d66..6132b8bf9 100644
--- a/pkg/channels/base_test.go
+++ b/pkg/channels/base_test.go
@@ -1,6 +1,11 @@
package channels
-import "testing"
+import (
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
func TestBaseChannelIsAllowed(t *testing.T) {
tests := []struct {
@@ -50,3 +55,211 @@ func TestBaseChannelIsAllowed(t *testing.T) {
})
}
}
+
+func TestShouldRespondInGroup(t *testing.T) {
+ tests := []struct {
+ name string
+ gt config.GroupTriggerConfig
+ isMentioned bool
+ content string
+ wantRespond bool
+ wantContent string
+ }{
+ {
+ name: "no config - permissive default",
+ gt: config.GroupTriggerConfig{},
+ isMentioned: false,
+ content: "hello world",
+ wantRespond: true,
+ wantContent: "hello world",
+ },
+ {
+ name: "no config - mentioned",
+ gt: config.GroupTriggerConfig{},
+ isMentioned: true,
+ content: "hello world",
+ wantRespond: true,
+ wantContent: "hello world",
+ },
+ {
+ name: "mention_only - not mentioned",
+ gt: config.GroupTriggerConfig{MentionOnly: true},
+ isMentioned: false,
+ content: "hello world",
+ wantRespond: false,
+ wantContent: "hello world",
+ },
+ {
+ name: "mention_only - mentioned",
+ gt: config.GroupTriggerConfig{MentionOnly: true},
+ isMentioned: true,
+ content: "hello world",
+ wantRespond: true,
+ wantContent: "hello world",
+ },
+ {
+ name: "prefix match",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}},
+ isMentioned: false,
+ content: "/ask hello",
+ wantRespond: true,
+ wantContent: "hello",
+ },
+ {
+ name: "prefix no match - not mentioned",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}},
+ isMentioned: false,
+ content: "hello world",
+ wantRespond: false,
+ wantContent: "hello world",
+ },
+ {
+ name: "prefix no match - but mentioned",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}},
+ isMentioned: true,
+ content: "hello world",
+ wantRespond: true,
+ wantContent: "hello world",
+ },
+ {
+ name: "multiple prefixes - second matches",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"/ask", "/bot"}},
+ isMentioned: false,
+ content: "/bot help me",
+ wantRespond: true,
+ wantContent: "help me",
+ },
+ {
+ name: "mention_only with prefixes - mentioned overrides",
+ gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}},
+ isMentioned: true,
+ content: "hello",
+ wantRespond: true,
+ wantContent: "hello",
+ },
+ {
+ name: "mention_only with prefixes - not mentioned, no prefix",
+ gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}},
+ isMentioned: false,
+ content: "hello",
+ wantRespond: false,
+ wantContent: "hello",
+ },
+ {
+ name: "empty prefix in list is skipped",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"", "/ask"}},
+ isMentioned: false,
+ content: "/ask test",
+ wantRespond: true,
+ wantContent: "test",
+ },
+ {
+ name: "prefix strips leading whitespace after prefix",
+ gt: config.GroupTriggerConfig{Prefixes: []string{"/ask "}},
+ isMentioned: false,
+ content: "/ask hello",
+ wantRespond: true,
+ wantContent: "hello",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ch := NewBaseChannel("test", nil, nil, nil, WithGroupTrigger(tt.gt))
+ gotRespond, gotContent := ch.ShouldRespondInGroup(tt.isMentioned, tt.content)
+ if gotRespond != tt.wantRespond {
+ t.Errorf("ShouldRespondInGroup() respond = %v, want %v", gotRespond, tt.wantRespond)
+ }
+ if gotContent != tt.wantContent {
+ t.Errorf("ShouldRespondInGroup() content = %q, want %q", gotContent, tt.wantContent)
+ }
+ })
+ }
+}
+
+func TestIsAllowedSender(t *testing.T) {
+ tests := []struct {
+ name string
+ allowList []string
+ sender bus.SenderInfo
+ want bool
+ }{
+ {
+ name: "empty allowlist allows all",
+ allowList: nil,
+ sender: bus.SenderInfo{PlatformID: "anyone"},
+ want: true,
+ },
+ {
+ name: "numeric ID matches PlatformID",
+ allowList: []string{"123456"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ },
+ want: true,
+ },
+ {
+ name: "canonical format matches",
+ allowList: []string{"telegram:123456"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ },
+ want: true,
+ },
+ {
+ name: "canonical format wrong platform",
+ allowList: []string{"discord:123456"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ },
+ want: false,
+ },
+ {
+ name: "@username matches",
+ allowList: []string{"@alice"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ Username: "alice",
+ },
+ want: true,
+ },
+ {
+ name: "compound id|username matches by ID",
+ allowList: []string{"123456|alice"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ Username: "alice",
+ },
+ want: true,
+ },
+ {
+ name: "non matching sender denied",
+ allowList: []string{"654321"},
+ sender: bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ch := NewBaseChannel("test", nil, nil, tt.allowList)
+ if got := ch.IsAllowedSender(tt.sender); got != tt.want {
+ t.Fatalf("IsAllowedSender(%+v) = %v, want %v", tt.sender, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go
similarity index 83%
rename from pkg/channels/dingtalk.go
rename to pkg/channels/dingtalk/dingtalk.go
index 662fba3b7..8642ad362 100644
--- a/pkg/channels/dingtalk.go
+++ b/pkg/channels/dingtalk/dingtalk.go
@@ -1,7 +1,7 @@
// PicoClaw - Ultra-lightweight personal AI agent
// DingTalk channel implementation using Stream Mode
-package channels
+package dingtalk
import (
"context"
@@ -12,7 +12,9 @@ import (
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -20,7 +22,7 @@ import (
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
// It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.DingTalkConfig
clientID string
clientSecret string
@@ -37,7 +39,11 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
return nil, fmt.Errorf("dingtalk client_id and client_secret are required")
}
- base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(20000),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &DingTalkChannel{
BaseChannel: base,
@@ -70,7 +76,7 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start stream client: %w", err)
}
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil
}
@@ -87,7 +93,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
c.streamClient.Close()
}
- c.setRunning(false)
+ c.SetRunning(false)
logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil
}
@@ -95,7 +101,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
// Send sends a message to DingTalk via the chatbot reply API
func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("dingtalk channel not running")
+ return channels.ErrNotRunning
}
// Get session webhook from storage
@@ -159,12 +165,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
"session_webhook": data.SessionWebhook,
}
+ var peer bus.Peer
if data.ConversationType == "1" {
- metadata["peer_kind"] = "direct"
- metadata["peer_id"] = senderID
+ peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
- metadata["peer_kind"] = "group"
- metadata["peer_id"] = data.ConversationId
+ peer = bus.Peer{Kind: "group", ID: data.ConversationId}
+ // In group chats, apply unified group trigger filtering
+ respond, cleaned := c.ShouldRespondInGroup(false, content)
+ if !respond {
+ return nil, nil
+ }
+ content = cleaned
}
logger.DebugCF("dingtalk", "Received message", map[string]any{
@@ -173,8 +184,20 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
"preview": utils.Truncate(content, 50),
})
+ // Build sender info
+ sender := bus.SenderInfo{
+ Platform: "dingtalk",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("dingtalk", senderID),
+ DisplayName: senderNick,
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return nil, nil
+ }
+
// Handle the message through the base channel
- c.HandleMessage(senderID, chatID, content, nil, metadata)
+ c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender)
// Return nil to indicate we've handled the message asynchronously
// The response will be sent through the message bus
@@ -197,7 +220,7 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c
contentBytes,
)
if err != nil {
- return fmt.Errorf("failed to send reply: %w", err)
+ return fmt.Errorf("dingtalk send: %w", channels.ErrTemporary)
}
return nil
diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go
new file mode 100644
index 000000000..5f49bce8c
--- /dev/null
+++ b/pkg/channels/dingtalk/init.go
@@ -0,0 +1,13 @@
+package dingtalk
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewDingTalkChannel(cfg.Channels.DingTalk, b)
+ })
+}
diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go
deleted file mode 100644
index f6faa3373..000000000
--- a/pkg/channels/discord.go
+++ /dev/null
@@ -1,373 +0,0 @@
-package channels
-
-import (
- "context"
- "fmt"
- "os"
- "strings"
- "sync"
- "time"
-
- "github.com/bwmarrin/discordgo"
-
- "github.com/sipeed/picoclaw/pkg/bus"
- "github.com/sipeed/picoclaw/pkg/config"
- "github.com/sipeed/picoclaw/pkg/logger"
- "github.com/sipeed/picoclaw/pkg/utils"
- "github.com/sipeed/picoclaw/pkg/voice"
-)
-
-const (
- transcriptionTimeout = 30 * time.Second
- sendTimeout = 10 * time.Second
-)
-
-type DiscordChannel struct {
- *BaseChannel
- session *discordgo.Session
- config config.DiscordConfig
- transcriber *voice.GroqTranscriber
- ctx context.Context
- typingMu sync.Mutex
- typingStop map[string]chan struct{} // chatID → stop signal
- botUserID string // stored for mention checking
-}
-
-func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
- session, err := discordgo.New("Bot " + cfg.Token)
- if err != nil {
- return nil, fmt.Errorf("failed to create discord session: %w", err)
- }
-
- base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom)
-
- return &DiscordChannel{
- BaseChannel: base,
- session: session,
- config: cfg,
- transcriber: nil,
- ctx: context.Background(),
- typingStop: make(map[string]chan struct{}),
- }, nil
-}
-
-func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
- c.transcriber = transcriber
-}
-
-func (c *DiscordChannel) getContext() context.Context {
- if c.ctx == nil {
- return context.Background()
- }
- return c.ctx
-}
-
-func (c *DiscordChannel) Start(ctx context.Context) error {
- logger.InfoC("discord", "Starting Discord bot")
-
- c.ctx = ctx
-
- // Get bot user ID before opening session to avoid race condition
- botUser, err := c.session.User("@me")
- if err != nil {
- return fmt.Errorf("failed to get bot user: %w", err)
- }
- c.botUserID = botUser.ID
-
- c.session.AddHandler(c.handleMessage)
-
- if err := c.session.Open(); err != nil {
- return fmt.Errorf("failed to open discord session: %w", err)
- }
-
- c.setRunning(true)
-
- logger.InfoCF("discord", "Discord bot connected", map[string]any{
- "username": botUser.Username,
- "user_id": botUser.ID,
- })
-
- return nil
-}
-
-func (c *DiscordChannel) Stop(ctx context.Context) error {
- logger.InfoC("discord", "Stopping Discord bot")
- c.setRunning(false)
-
- // Stop all typing goroutines before closing session
- c.typingMu.Lock()
- for chatID, stop := range c.typingStop {
- close(stop)
- delete(c.typingStop, chatID)
- }
- c.typingMu.Unlock()
-
- if err := c.session.Close(); err != nil {
- return fmt.Errorf("failed to close discord session: %w", err)
- }
-
- return nil
-}
-
-func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
- c.stopTyping(msg.ChatID)
-
- if !c.IsRunning() {
- return fmt.Errorf("discord bot not running")
- }
-
- channelID := msg.ChatID
- if channelID == "" {
- return fmt.Errorf("channel ID is empty")
- }
-
- runes := []rune(msg.Content)
- if len(runes) == 0 {
- return nil
- }
-
- chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars
-
- for _, chunk := range chunks {
- if err := c.sendChunk(ctx, channelID, chunk); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
- // Use the passed ctx for timeout control
- sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
- defer cancel()
-
- done := make(chan error, 1)
- go func() {
- _, err := c.session.ChannelMessageSend(channelID, content)
- done <- err
- }()
-
- select {
- case err := <-done:
- if err != nil {
- return fmt.Errorf("failed to send discord message: %w", err)
- }
- return nil
- case <-sendCtx.Done():
- return fmt.Errorf("send message timeout: %w", sendCtx.Err())
- }
-}
-
-// appendContent safely appends content to existing text
-func appendContent(content, suffix string) string {
- if content == "" {
- return suffix
- }
- return content + "\n" + suffix
-}
-
-func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
- if m == nil || m.Author == nil {
- return
- }
-
- if m.Author.ID == s.State.User.ID {
- return
- }
-
- // Check allowlist first to avoid downloading attachments and transcribing for rejected users
- if !c.IsAllowed(m.Author.ID) {
- logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
- "user_id": m.Author.ID,
- })
- return
- }
-
- // If configured to only respond to mentions, check if bot is mentioned
- // Skip this check for DMs (GuildID is empty) - DMs should always be responded to
- if c.config.MentionOnly && m.GuildID != "" {
- isMentioned := false
- for _, mention := range m.Mentions {
- if mention.ID == c.botUserID {
- isMentioned = true
- break
- }
- }
- if !isMentioned {
- logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{
- "user_id": m.Author.ID,
- })
- return
- }
- }
-
- senderID := m.Author.ID
- senderName := m.Author.Username
- if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
- senderName += "#" + m.Author.Discriminator
- }
-
- content := m.Content
- content = c.stripBotMention(content)
- mediaPaths := make([]string, 0, len(m.Attachments))
- localFiles := make([]string, 0, len(m.Attachments))
-
- // Ensure temp files are cleaned up when function returns
- defer func() {
- for _, file := range localFiles {
- if err := os.Remove(file); err != nil {
- logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
- "file": file,
- "error": err.Error(),
- })
- }
- }
- }()
-
- for _, attachment := range m.Attachments {
- isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
-
- if isAudio {
- localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
- if localPath != "" {
- localFiles = append(localFiles, localPath)
-
- var transcribedText string
- if c.transcriber != nil && c.transcriber.IsAvailable() {
- ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
- result, err := c.transcriber.Transcribe(ctx, localPath)
- cancel() // Release context resources immediately to avoid leaks in for loop
-
- if err != nil {
- logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
- "error": err.Error(),
- })
- transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
- } else {
- transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
- logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
- "text": result.Text,
- })
- }
- } else {
- transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
- }
-
- content = appendContent(content, transcribedText)
- } else {
- logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
- "url": attachment.URL,
- "filename": attachment.Filename,
- })
- mediaPaths = append(mediaPaths, attachment.URL)
- content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
- }
- } else {
- mediaPaths = append(mediaPaths, attachment.URL)
- content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
- }
- }
-
- if content == "" && len(mediaPaths) == 0 {
- return
- }
-
- if content == "" {
- content = "[media only]"
- }
-
- // Start typing after all early returns — guaranteed to have a matching Send()
- c.startTyping(m.ChannelID)
-
- logger.DebugCF("discord", "Received message", map[string]any{
- "sender_name": senderName,
- "sender_id": senderID,
- "preview": utils.Truncate(content, 50),
- })
-
- peerKind := "channel"
- peerID := m.ChannelID
- if m.GuildID == "" {
- peerKind = "direct"
- peerID = senderID
- }
-
- metadata := map[string]string{
- "message_id": m.ID,
- "user_id": senderID,
- "username": m.Author.Username,
- "display_name": senderName,
- "guild_id": m.GuildID,
- "channel_id": m.ChannelID,
- "is_dm": fmt.Sprintf("%t", m.GuildID == ""),
- "peer_kind": peerKind,
- "peer_id": peerID,
- }
-
- c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
-}
-
-// startTyping starts a continuous typing indicator loop for the given chatID.
-// It stops any existing typing loop for that chatID before starting a new one.
-func (c *DiscordChannel) startTyping(chatID string) {
- c.typingMu.Lock()
- // Stop existing loop for this chatID if any
- if stop, ok := c.typingStop[chatID]; ok {
- close(stop)
- }
- stop := make(chan struct{})
- c.typingStop[chatID] = stop
- c.typingMu.Unlock()
-
- go func() {
- if err := c.session.ChannelTyping(chatID); err != nil {
- logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
- }
- ticker := time.NewTicker(8 * time.Second)
- defer ticker.Stop()
- timeout := time.After(5 * time.Minute)
- for {
- select {
- case <-stop:
- return
- case <-timeout:
- return
- case <-c.ctx.Done():
- return
- case <-ticker.C:
- if err := c.session.ChannelTyping(chatID); err != nil {
- logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
- }
- }
- }
- }()
-}
-
-// stopTyping stops the typing indicator loop for the given chatID.
-func (c *DiscordChannel) stopTyping(chatID string) {
- c.typingMu.Lock()
- defer c.typingMu.Unlock()
- if stop, ok := c.typingStop[chatID]; ok {
- close(stop)
- delete(c.typingStop, chatID)
- }
-}
-
-func (c *DiscordChannel) downloadAttachment(url, filename string) string {
- return utils.DownloadFile(url, filename, utils.DownloadOptions{
- LoggerPrefix: "discord",
- })
-}
-
-// stripBotMention removes the bot mention from the message content.
-// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
-func (c *DiscordChannel) stripBotMention(text string) string {
- if c.botUserID == "" {
- return text
- }
- // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID>
- text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "")
- text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
- return strings.TrimSpace(text)
-}
diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go
new file mode 100644
index 000000000..1de910c83
--- /dev/null
+++ b/pkg/channels/discord/discord.go
@@ -0,0 +1,521 @@
+package discord
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/bwmarrin/discordgo"
+ "github.com/gorilla/websocket"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+const (
+ sendTimeout = 10 * time.Second
+)
+
+type DiscordChannel struct {
+ *channels.BaseChannel
+ session *discordgo.Session
+ config config.DiscordConfig
+ ctx context.Context
+ cancel context.CancelFunc
+ typingMu sync.Mutex
+ typingStop map[string]chan struct{} // chatID → stop signal
+ botUserID string // stored for mention checking
+}
+
+func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
+ session, err := discordgo.New("Bot " + cfg.Token)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create discord session: %w", err)
+ }
+
+ if err := applyDiscordProxy(session, cfg.Proxy); err != nil {
+ return nil, err
+ }
+ base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(2000),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+
+ return &DiscordChannel{
+ BaseChannel: base,
+ session: session,
+ config: cfg,
+ ctx: context.Background(),
+ typingStop: make(map[string]chan struct{}),
+ }, nil
+}
+
+func (c *DiscordChannel) Start(ctx context.Context) error {
+ logger.InfoC("discord", "Starting Discord bot")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // Get bot user ID before opening session to avoid race condition
+ botUser, err := c.session.User("@me")
+ if err != nil {
+ return fmt.Errorf("failed to get bot user: %w", err)
+ }
+ c.botUserID = botUser.ID
+
+ c.session.AddHandler(c.handleMessage)
+
+ if err := c.session.Open(); err != nil {
+ return fmt.Errorf("failed to open discord session: %w", err)
+ }
+
+ c.SetRunning(true)
+
+ logger.InfoCF("discord", "Discord bot connected", map[string]any{
+ "username": botUser.Username,
+ "user_id": botUser.ID,
+ })
+
+ return nil
+}
+
+func (c *DiscordChannel) Stop(ctx context.Context) error {
+ logger.InfoC("discord", "Stopping Discord bot")
+ c.SetRunning(false)
+
+ // Stop all typing goroutines before closing session
+ c.typingMu.Lock()
+ for chatID, stop := range c.typingStop {
+ close(stop)
+ delete(c.typingStop, chatID)
+ }
+ c.typingMu.Unlock()
+
+ // Cancel our context so typing goroutines using c.ctx.Done() exit
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ if err := c.session.Close(); err != nil {
+ return fmt.Errorf("failed to close discord session: %w", err)
+ }
+
+ return nil
+}
+
+func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ channelID := msg.ChatID
+ if channelID == "" {
+ return fmt.Errorf("channel ID is empty")
+ }
+
+ if len([]rune(msg.Content)) == 0 {
+ return nil
+ }
+
+ return c.sendChunk(ctx, channelID, msg.Content)
+}
+
+// SendMedia implements the channels.MediaSender interface.
+func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ channelID := msg.ChatID
+ if channelID == "" {
+ return fmt.Errorf("channel ID is empty")
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ // Collect all files into a single ChannelMessageSendComplex call
+ files := make([]*discordgo.File, 0, len(msg.Parts))
+ var caption string
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("discord", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ file, err := os.Open(localPath)
+ if err != nil {
+ logger.ErrorCF("discord", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+ // Note: discordgo reads from the Reader and we can't close it before send
+
+ filename := part.Filename
+ if filename == "" {
+ filename = "file"
+ }
+
+ files = append(files, &discordgo.File{
+ Name: filename,
+ ContentType: part.ContentType,
+ Reader: file,
+ })
+
+ if part.Caption != "" && caption == "" {
+ caption = part.Caption
+ }
+ }
+
+ if len(files) == 0 {
+ return nil
+ }
+
+ sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ _, err := c.session.ChannelMessageSendComplex(channelID, &discordgo.MessageSend{
+ Content: caption,
+ Files: files,
+ })
+ done <- err
+ }()
+
+ select {
+ case err := <-done:
+ // Close all file readers
+ for _, f := range files {
+ if closer, ok := f.Reader.(*os.File); ok {
+ closer.Close()
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("discord send media: %w", channels.ErrTemporary)
+ }
+ return nil
+ case <-sendCtx.Done():
+ // Close all file readers
+ for _, f := range files {
+ if closer, ok := f.Reader.(*os.File); ok {
+ closer.Close()
+ }
+ }
+ return sendCtx.Err()
+ }
+}
+
+// EditMessage implements channels.MessageEditor.
+func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
+ _, err := c.session.ChannelMessageEdit(chatID, messageID, content)
+ return err
+}
+
+// SendPlaceholder implements channels.PlaceholderCapable.
+// It sends a placeholder message that will later be edited to the actual
+// response via EditMessage (channels.MessageEditor).
+func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ if !c.config.Placeholder.Enabled {
+ return "", nil
+ }
+
+ text := c.config.Placeholder.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+
+ msg, err := c.session.ChannelMessageSend(chatID, text)
+ if err != nil {
+ return "", err
+ }
+
+ return msg.ID, nil
+}
+
+func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
+ // Use the passed ctx for timeout control
+ sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ _, err := c.session.ChannelMessageSend(channelID, content)
+ done <- err
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ return fmt.Errorf("discord send: %w", channels.ErrTemporary)
+ }
+ return nil
+ case <-sendCtx.Done():
+ return sendCtx.Err()
+ }
+}
+
+// appendContent safely appends content to existing text
+func appendContent(content, suffix string) string {
+ if content == "" {
+ return suffix
+ }
+ return content + "\n" + suffix
+}
+
+func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
+ if m == nil || m.Author == nil {
+ return
+ }
+
+ if m.Author.ID == s.State.User.ID {
+ return
+ }
+
+ // Check allowlist first to avoid downloading attachments for rejected users
+ sender := bus.SenderInfo{
+ Platform: "discord",
+ PlatformID: m.Author.ID,
+ CanonicalID: identity.BuildCanonicalID("discord", m.Author.ID),
+ Username: m.Author.Username,
+ }
+ // Build display name
+ displayName := m.Author.Username
+ if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
+ displayName += "#" + m.Author.Discriminator
+ }
+ sender.DisplayName = displayName
+
+ if !c.IsAllowedSender(sender) {
+ logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
+ "user_id": m.Author.ID,
+ })
+ return
+ }
+
+ content := m.Content
+
+ // In guild (group) channels, apply unified group trigger filtering
+ // DMs (GuildID is empty) always get a response
+ if m.GuildID != "" {
+ isMentioned := false
+ for _, mention := range m.Mentions {
+ if mention.ID == c.botUserID {
+ isMentioned = true
+ break
+ }
+ }
+ content = c.stripBotMention(content)
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
+ if !respond {
+ logger.DebugCF("discord", "Group message ignored by group trigger", map[string]any{
+ "user_id": m.Author.ID,
+ })
+ return
+ }
+ content = cleaned
+ } else {
+ // DMs: just strip bot mention without filtering
+ content = c.stripBotMention(content)
+ }
+
+ senderID := m.Author.ID
+
+ mediaPaths := make([]string, 0, len(m.Attachments))
+
+ scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID)
+
+ // Helper to register a local file with the media store
+ storeMedia := func(localPath, filename string) string {
+ if store := c.GetMediaStore(); store != nil {
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "discord",
+ }, scope)
+ if err == nil {
+ return ref
+ }
+ }
+ return localPath // fallback
+ }
+
+ for _, attachment := range m.Attachments {
+ isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
+
+ if isAudio {
+ localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
+ if localPath != "" {
+ mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename))
+ content = appendContent(content, fmt.Sprintf("[audio: %s]", attachment.Filename))
+ } else {
+ logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
+ "url": attachment.URL,
+ "filename": attachment.Filename,
+ })
+ mediaPaths = append(mediaPaths, attachment.URL)
+ content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
+ }
+ } else {
+ mediaPaths = append(mediaPaths, attachment.URL)
+ content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
+ }
+ }
+
+ if content == "" && len(mediaPaths) == 0 {
+ return
+ }
+
+ if content == "" {
+ content = "[media only]"
+ }
+
+ logger.DebugCF("discord", "Received message", map[string]any{
+ "sender_name": sender.DisplayName,
+ "sender_id": senderID,
+ "preview": utils.Truncate(content, 50),
+ })
+
+ peerKind := "channel"
+ peerID := m.ChannelID
+ if m.GuildID == "" {
+ peerKind = "direct"
+ peerID = senderID
+ }
+
+ peer := bus.Peer{Kind: peerKind, ID: peerID}
+
+ metadata := map[string]string{
+ "user_id": senderID,
+ "username": m.Author.Username,
+ "display_name": sender.DisplayName,
+ "guild_id": m.GuildID,
+ "channel_id": m.ChannelID,
+ "is_dm": fmt.Sprintf("%t", m.GuildID == ""),
+ }
+
+ c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender)
+}
+
+// startTyping starts a continuous typing indicator loop for the given chatID.
+// It stops any existing typing loop for that chatID before starting a new one.
+func (c *DiscordChannel) startTyping(chatID string) {
+ c.typingMu.Lock()
+ // Stop existing loop for this chatID if any
+ if stop, ok := c.typingStop[chatID]; ok {
+ close(stop)
+ }
+ stop := make(chan struct{})
+ c.typingStop[chatID] = stop
+ c.typingMu.Unlock()
+
+ go func() {
+ if err := c.session.ChannelTyping(chatID); err != nil {
+ logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
+ }
+ ticker := time.NewTicker(8 * time.Second)
+ defer ticker.Stop()
+ timeout := time.After(5 * time.Minute)
+ for {
+ select {
+ case <-stop:
+ return
+ case <-timeout:
+ return
+ case <-c.ctx.Done():
+ return
+ case <-ticker.C:
+ if err := c.session.ChannelTyping(chatID); err != nil {
+ logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
+ }
+ }
+ }
+ }()
+}
+
+// stopTyping stops the typing indicator loop for the given chatID.
+func (c *DiscordChannel) stopTyping(chatID string) {
+ c.typingMu.Lock()
+ defer c.typingMu.Unlock()
+ if stop, ok := c.typingStop[chatID]; ok {
+ close(stop)
+ delete(c.typingStop, chatID)
+ }
+}
+
+// StartTyping implements channels.TypingCapable.
+// It starts a continuous typing indicator and returns an idempotent stop function.
+func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
+ c.startTyping(chatID)
+ return func() { c.stopTyping(chatID) }, nil
+}
+
+func (c *DiscordChannel) downloadAttachment(url, filename string) string {
+ return utils.DownloadFile(url, filename, utils.DownloadOptions{
+ LoggerPrefix: "discord",
+ ProxyURL: c.config.Proxy,
+ })
+}
+
+func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error {
+ var proxyFunc func(*http.Request) (*url.URL, error)
+ if proxyAddr != "" {
+ proxyURL, err := url.Parse(proxyAddr)
+ if err != nil {
+ return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err)
+ }
+ proxyFunc = http.ProxyURL(proxyURL)
+ } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
+ proxyFunc = http.ProxyFromEnvironment
+ }
+
+ if proxyFunc == nil {
+ return nil
+ }
+
+ transport := &http.Transport{Proxy: proxyFunc}
+ session.Client = &http.Client{
+ Timeout: sendTimeout,
+ Transport: transport,
+ }
+
+ if session.Dialer != nil {
+ dialerCopy := *session.Dialer
+ dialerCopy.Proxy = proxyFunc
+ session.Dialer = &dialerCopy
+ } else {
+ session.Dialer = &websocket.Dialer{Proxy: proxyFunc}
+ }
+
+ return nil
+}
+
+// stripBotMention removes the bot mention from the message content.
+// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
+func (c *DiscordChannel) stripBotMention(text string) string {
+ if c.botUserID == "" {
+ return text
+ }
+ // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID>
+ text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "")
+ text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
+ return strings.TrimSpace(text)
+}
diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go
new file mode 100644
index 000000000..0cd5328f4
--- /dev/null
+++ b/pkg/channels/discord/discord_test.go
@@ -0,0 +1,91 @@
+package discord
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/bwmarrin/discordgo"
+)
+
+func TestApplyDiscordProxy_CustomProxy(t *testing.T) {
+ session, err := discordgo.New("Bot test-token")
+ if err != nil {
+ t.Fatalf("discordgo.New() error: %v", err)
+ }
+
+ if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil {
+ t.Fatalf("applyDiscordProxy() error: %v", err)
+ }
+
+ req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
+ if err != nil {
+ t.Fatalf("http.NewRequest() error: %v", err)
+ }
+
+ restProxy := session.Client.Transport.(*http.Transport).Proxy
+ restProxyURL, err := restProxy(req)
+ if err != nil {
+ t.Fatalf("rest proxy func error: %v", err)
+ }
+ if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want {
+ t.Fatalf("REST proxy = %q, want %q", got, want)
+ }
+
+ wsProxyURL, err := session.Dialer.Proxy(req)
+ if err != nil {
+ t.Fatalf("ws proxy func error: %v", err)
+ }
+ if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want {
+ t.Fatalf("WS proxy = %q, want %q", got, want)
+ }
+}
+
+func TestApplyDiscordProxy_FromEnvironment(t *testing.T) {
+ t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
+ t.Setenv("http_proxy", "http://127.0.0.1:8888")
+ t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
+ t.Setenv("https_proxy", "http://127.0.0.1:8888")
+ t.Setenv("ALL_PROXY", "")
+ t.Setenv("all_proxy", "")
+ t.Setenv("NO_PROXY", "")
+ t.Setenv("no_proxy", "")
+
+ session, err := discordgo.New("Bot test-token")
+ if err != nil {
+ t.Fatalf("discordgo.New() error: %v", err)
+ }
+
+ if err = applyDiscordProxy(session, ""); err != nil {
+ t.Fatalf("applyDiscordProxy() error: %v", err)
+ }
+
+ req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil)
+ if err != nil {
+ t.Fatalf("http.NewRequest() error: %v", err)
+ }
+
+ gotURL, err := session.Dialer.Proxy(req)
+ if err != nil {
+ t.Fatalf("ws proxy func error: %v", err)
+ }
+
+ wantURL, err := url.Parse("http://127.0.0.1:8888")
+ if err != nil {
+ t.Fatalf("url.Parse() error: %v", err)
+ }
+ if gotURL.String() != wantURL.String() {
+ t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String())
+ }
+}
+
+func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) {
+ session, err := discordgo.New("Bot test-token")
+ if err != nil {
+ t.Fatalf("discordgo.New() error: %v", err)
+ }
+
+ if err = applyDiscordProxy(session, "://bad-proxy"); err == nil {
+ t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil")
+ }
+}
diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go
new file mode 100644
index 000000000..15a539804
--- /dev/null
+++ b/pkg/channels/discord/init.go
@@ -0,0 +1,13 @@
+package discord
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewDiscordChannel(cfg.Channels.Discord, b)
+ })
+}
diff --git a/pkg/channels/errors.go b/pkg/channels/errors.go
new file mode 100644
index 000000000..09ee88b3f
--- /dev/null
+++ b/pkg/channels/errors.go
@@ -0,0 +1,21 @@
+package channels
+
+import "errors"
+
+var (
+ // ErrNotRunning indicates the channel is not running.
+ // Manager will not retry.
+ ErrNotRunning = errors.New("channel not running")
+
+ // ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429).
+ // Manager will wait a fixed delay and retry.
+ ErrRateLimit = errors.New("rate limited")
+
+ // ErrTemporary indicates a transient failure (e.g. network timeout, 5xx).
+ // Manager will use exponential backoff and retry.
+ ErrTemporary = errors.New("temporary failure")
+
+ // ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429).
+ // Manager will not retry.
+ ErrSendFailed = errors.New("send failed")
+)
diff --git a/pkg/channels/errors_test.go b/pkg/channels/errors_test.go
new file mode 100644
index 000000000..e5592345a
--- /dev/null
+++ b/pkg/channels/errors_test.go
@@ -0,0 +1,56 @@
+package channels
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+)
+
+func TestErrorsIs(t *testing.T) {
+ wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit)
+ if !errors.Is(wrapped, ErrRateLimit) {
+ t.Error("wrapped ErrRateLimit should match")
+ }
+ if errors.Is(wrapped, ErrTemporary) {
+ t.Error("wrapped ErrRateLimit should not match ErrTemporary")
+ }
+}
+
+func TestErrorsIsAllTypes(t *testing.T) {
+ sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed}
+
+ for _, sentinel := range sentinels {
+ wrapped := fmt.Errorf("context: %w", sentinel)
+ if !errors.Is(wrapped, sentinel) {
+ t.Errorf("wrapped %v should match itself", sentinel)
+ }
+
+ // Verify it doesn't match other sentinel errors
+ for _, other := range sentinels {
+ if other == sentinel {
+ continue
+ }
+ if errors.Is(wrapped, other) {
+ t.Errorf("wrapped %v should not match %v", sentinel, other)
+ }
+ }
+ }
+}
+
+func TestErrorMessages(t *testing.T) {
+ tests := []struct {
+ err error
+ want string
+ }{
+ {ErrNotRunning, "channel not running"},
+ {ErrRateLimit, "rate limited"},
+ {ErrTemporary, "temporary failure"},
+ {ErrSendFailed, "send failed"},
+ }
+
+ for _, tt := range tests {
+ if got := tt.err.Error(); got != tt.want {
+ t.Errorf("error message = %q, want %q", got, tt.want)
+ }
+ }
+}
diff --git a/pkg/channels/errutil.go b/pkg/channels/errutil.go
new file mode 100644
index 000000000..319e3c980
--- /dev/null
+++ b/pkg/channels/errutil.go
@@ -0,0 +1,30 @@
+package channels
+
+import (
+ "fmt"
+ "net/http"
+)
+
+// ClassifySendError wraps a raw error with the appropriate sentinel based on
+// an HTTP status code. Channels that perform HTTP API calls should use this
+// in their Send path.
+func ClassifySendError(statusCode int, rawErr error) error {
+ switch {
+ case statusCode == http.StatusTooManyRequests:
+ return fmt.Errorf("%w: %v", ErrRateLimit, rawErr)
+ case statusCode >= 500:
+ return fmt.Errorf("%w: %v", ErrTemporary, rawErr)
+ case statusCode >= 400:
+ return fmt.Errorf("%w: %v", ErrSendFailed, rawErr)
+ default:
+ return rawErr
+ }
+}
+
+// ClassifyNetError wraps a network/timeout error as ErrTemporary.
+func ClassifyNetError(err error) error {
+ if err == nil {
+ return nil
+ }
+ return fmt.Errorf("%w: %v", ErrTemporary, err)
+}
diff --git a/pkg/channels/errutil_test.go b/pkg/channels/errutil_test.go
new file mode 100644
index 000000000..e3d35f65b
--- /dev/null
+++ b/pkg/channels/errutil_test.go
@@ -0,0 +1,97 @@
+package channels
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+)
+
+func TestClassifySendError(t *testing.T) {
+ raw := fmt.Errorf("some API error")
+
+ tests := []struct {
+ name string
+ statusCode int
+ wantIs error
+ wantNil bool
+ }{
+ {"429 -> ErrRateLimit", 429, ErrRateLimit, false},
+ {"500 -> ErrTemporary", 500, ErrTemporary, false},
+ {"502 -> ErrTemporary", 502, ErrTemporary, false},
+ {"503 -> ErrTemporary", 503, ErrTemporary, false},
+ {"400 -> ErrSendFailed", 400, ErrSendFailed, false},
+ {"403 -> ErrSendFailed", 403, ErrSendFailed, false},
+ {"404 -> ErrSendFailed", 404, ErrSendFailed, false},
+ {"200 -> raw error", 200, nil, false},
+ {"201 -> raw error", 201, nil, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ClassifySendError(tt.statusCode, raw)
+ if err == nil {
+ t.Fatal("expected non-nil error")
+ }
+ if tt.wantIs != nil {
+ if !errors.Is(err, tt.wantIs) {
+ t.Errorf("errors.Is(err, %v) = false, want true; err = %v", tt.wantIs, err)
+ }
+ } else {
+ // Should return the raw error unchanged
+ if err != raw {
+ t.Errorf("expected raw error to be returned unchanged for status %d, got %v", tt.statusCode, err)
+ }
+ }
+ })
+ }
+}
+
+func TestClassifySendErrorNoFalsePositive(t *testing.T) {
+ raw := fmt.Errorf("some error")
+
+ // 429 should NOT match ErrTemporary or ErrSendFailed
+ err := ClassifySendError(429, raw)
+ if errors.Is(err, ErrTemporary) {
+ t.Error("429 should not match ErrTemporary")
+ }
+ if errors.Is(err, ErrSendFailed) {
+ t.Error("429 should not match ErrSendFailed")
+ }
+
+ // 500 should NOT match ErrRateLimit or ErrSendFailed
+ err = ClassifySendError(500, raw)
+ if errors.Is(err, ErrRateLimit) {
+ t.Error("500 should not match ErrRateLimit")
+ }
+ if errors.Is(err, ErrSendFailed) {
+ t.Error("500 should not match ErrSendFailed")
+ }
+
+ // 400 should NOT match ErrRateLimit or ErrTemporary
+ err = ClassifySendError(400, raw)
+ if errors.Is(err, ErrRateLimit) {
+ t.Error("400 should not match ErrRateLimit")
+ }
+ if errors.Is(err, ErrTemporary) {
+ t.Error("400 should not match ErrTemporary")
+ }
+}
+
+func TestClassifyNetError(t *testing.T) {
+ t.Run("nil error returns nil", func(t *testing.T) {
+ if err := ClassifyNetError(nil); err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ })
+
+ t.Run("non-nil error wraps as ErrTemporary", func(t *testing.T) {
+ raw := fmt.Errorf("connection refused")
+ err := ClassifyNetError(raw)
+ if err == nil {
+ t.Fatal("expected non-nil error")
+ }
+ if !errors.Is(err, ErrTemporary) {
+ t.Errorf("errors.Is(err, ErrTemporary) = false, want true; err = %v", err)
+ }
+ })
+}
diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go
new file mode 100644
index 000000000..fbe085b73
--- /dev/null
+++ b/pkg/channels/feishu/common.go
@@ -0,0 +1,86 @@
+package feishu
+
+import (
+ "encoding/json"
+ "regexp"
+ "strings"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
+var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`)
+
+// stringValue safely dereferences a *string pointer.
+func stringValue(v *string) string {
+ if v == nil {
+ return ""
+ }
+ return *v
+}
+
+// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content.
+// JSON 2.0 cards support full CommonMark standard markdown syntax.
+func buildMarkdownCard(content string) (string, error) {
+ card := map[string]any{
+ "schema": "2.0",
+ "body": map[string]any{
+ "elements": []map[string]any{
+ {
+ "tag": "markdown",
+ "content": content,
+ },
+ },
+ },
+ }
+ data, err := json.Marshal(card)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+// extractJSONStringField unmarshals content as JSON and returns the value of the given string field.
+// Returns "" if the content is invalid JSON or the field is missing/empty.
+func extractJSONStringField(content, field string) string {
+ var m map[string]json.RawMessage
+ if err := json.Unmarshal([]byte(content), &m); err != nil {
+ return ""
+ }
+ raw, ok := m[field]
+ if !ok {
+ return ""
+ }
+ var s string
+ if err := json.Unmarshal(raw, &s); err != nil {
+ return ""
+ }
+ return s
+}
+
+// extractImageKey extracts the image_key from a Feishu image message content JSON.
+// Format: {"image_key": "img_xxx"}
+func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
+
+// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
+// Format: {"file_key": "file_xxx", "file_name": "...", ...}
+func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
+
+// extractFileName extracts the file_name from a Feishu file message content JSON.
+func extractFileName(content string) string { return extractJSONStringField(content, "file_name") }
+
+// stripMentionPlaceholders removes @_user_N placeholders from the text content.
+// These are inserted by Feishu when users @mention someone in a message.
+func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string {
+ if len(mentions) == 0 {
+ return content
+ }
+ for _, m := range mentions {
+ if m.Key != nil && *m.Key != "" {
+ content = strings.ReplaceAll(content, *m.Key, "")
+ }
+ }
+ // Also clean up any remaining @_user_N patterns
+ content = mentionPlaceholderRegex.ReplaceAllString(content, "")
+ return strings.TrimSpace(content)
+}
diff --git a/pkg/channels/feishu/common_test.go b/pkg/channels/feishu/common_test.go
new file mode 100644
index 000000000..fefc9f7c1
--- /dev/null
+++ b/pkg/channels/feishu/common_test.go
@@ -0,0 +1,292 @@
+package feishu
+
+import (
+ "encoding/json"
+ "testing"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+func TestExtractJSONStringField(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ field string
+ want string
+ }{
+ {
+ name: "valid field",
+ content: `{"image_key": "img_v2_xxx"}`,
+ field: "image_key",
+ want: "img_v2_xxx",
+ },
+ {
+ name: "missing field",
+ content: `{"image_key": "img_v2_xxx"}`,
+ field: "file_key",
+ want: "",
+ },
+ {
+ name: "invalid JSON",
+ content: `not json at all`,
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "empty content",
+ content: "",
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "non-string field value",
+ content: `{"count": 42}`,
+ field: "count",
+ want: "",
+ },
+ {
+ name: "empty string value",
+ content: `{"image_key": ""}`,
+ field: "image_key",
+ want: "",
+ },
+ {
+ name: "multiple fields",
+ content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`,
+ field: "file_name",
+ want: "test.pdf",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractJSONStringField(tt.content, tt.field)
+ if got != tt.want {
+ t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractImageKey(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"image_key": "img_v2_abc123"}`,
+ want: "img_v2_abc123",
+ },
+ {
+ name: "missing key",
+ content: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `{broken`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractImageKey(tt.content)
+ if got != tt.want {
+ t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractFileKey(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`,
+ want: "file_v2_abc123",
+ },
+ {
+ name: "missing key",
+ content: `{"image_key": "img_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `not json`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFileKey(tt.content)
+ if got != tt.want {
+ t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExtractFileName(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ want string
+ }{
+ {
+ name: "normal",
+ content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
+ want: "report.pdf",
+ },
+ {
+ name: "missing name",
+ content: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "malformed JSON",
+ content: `{bad`,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFileName(tt.content)
+ if got != tt.want {
+ t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestBuildMarkdownCard(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ }{
+ {
+ name: "normal content",
+ content: "Hello **world**",
+ },
+ {
+ name: "empty content",
+ content: "",
+ },
+ {
+ name: "special characters",
+ content: `Code: "foo" & 'baz'`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := buildMarkdownCard(tt.content)
+ if err != nil {
+ t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err)
+ }
+
+ // Verify valid JSON
+ var parsed map[string]any
+ if err := json.Unmarshal([]byte(result), &parsed); err != nil {
+ t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err)
+ }
+
+ // Verify schema
+ if parsed["schema"] != "2.0" {
+ t.Errorf("schema = %v, want %q", parsed["schema"], "2.0")
+ }
+
+ // Verify body.elements[0].content == input
+ body, ok := parsed["body"].(map[string]any)
+ if !ok {
+ t.Fatal("missing body in card JSON")
+ }
+ elements, ok := body["elements"].([]any)
+ if !ok || len(elements) == 0 {
+ t.Fatal("missing or empty elements in card JSON")
+ }
+ elem, ok := elements[0].(map[string]any)
+ if !ok {
+ t.Fatal("first element is not an object")
+ }
+ if elem["tag"] != "markdown" {
+ t.Errorf("tag = %v, want %q", elem["tag"], "markdown")
+ }
+ if elem["content"] != tt.content {
+ t.Errorf("content = %v, want %q", elem["content"], tt.content)
+ }
+ })
+ }
+}
+
+func TestStripMentionPlaceholders(t *testing.T) {
+ strPtr := func(s string) *string { return &s }
+
+ tests := []struct {
+ name string
+ content string
+ mentions []*larkim.MentionEvent
+ want string
+ }{
+ {
+ name: "no mentions",
+ content: "Hello world",
+ mentions: nil,
+ want: "Hello world",
+ },
+ {
+ name: "single mention",
+ content: "@_user_1 hello",
+ mentions: []*larkim.MentionEvent{
+ {Key: strPtr("@_user_1")},
+ },
+ want: "hello",
+ },
+ {
+ name: "multiple mentions",
+ content: "@_user_1 @_user_2 hey",
+ mentions: []*larkim.MentionEvent{
+ {Key: strPtr("@_user_1")},
+ {Key: strPtr("@_user_2")},
+ },
+ want: "hey",
+ },
+ {
+ name: "empty content",
+ content: "",
+ mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}},
+ want: "",
+ },
+ {
+ name: "empty mentions slice",
+ content: "@_user_1 test",
+ mentions: []*larkim.MentionEvent{},
+ want: "@_user_1 test",
+ },
+ {
+ name: "mention with nil key",
+ content: "@_user_1 test",
+ mentions: []*larkim.MentionEvent{
+ {Key: nil},
+ },
+ want: "test",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := stripMentionPlaceholders(tt.content, tt.mentions)
+ if got != tt.want {
+ t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu/feishu_32.go
similarity index 50%
rename from pkg/channels/feishu_32.go
rename to pkg/channels/feishu/feishu_32.go
index 5109b8195..f5e3aa224 100644
--- a/pkg/channels/feishu_32.go
+++ b/pkg/channels/feishu/feishu_32.go
@@ -1,20 +1,23 @@
//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64
-package channels
+package feishu
import (
"context"
"errors"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
// FeishuChannel is a stub implementation for 32-bit architectures
type FeishuChannel struct {
- *BaseChannel
+ *channels.BaseChannel
}
+var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
+
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
return nil, errors.New(
@@ -24,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
// Start is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Start(ctx context.Context) error {
- return nil
+ return errUnsupported
}
// Stop is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Stop(ctx context.Context) error {
- return nil
+ return errUnsupported
}
// Send is a stub method to satisfy the Channel interface
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
- return errors.New("feishu channel is not supported on 32-bit architectures")
+ return errUnsupported
+}
+
+// EditMessage is a stub method to satisfy MessageEditor
+func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ return errUnsupported
+}
+
+// SendPlaceholder is a stub method to satisfy PlaceholderCapable
+func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ return "", errUnsupported
+}
+
+// ReactToMessage is a stub method to satisfy ReactionCapable
+func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ return func() {}, errUnsupported
+}
+
+// SendMedia is a stub method to satisfy MediaSender
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ return errUnsupported
}
diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go
new file mode 100644
index 000000000..00f73064d
--- /dev/null
+++ b/pkg/channels/feishu/feishu_64.go
@@ -0,0 +1,818 @@
+//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
+
+package feishu
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "sync"
+ "sync/atomic"
+
+ lark "github.com/larksuite/oapi-sdk-go/v3"
+ larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
+ larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+ larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+type FeishuChannel struct {
+ *channels.BaseChannel
+ config config.FeishuConfig
+ client *lark.Client
+ wsClient *larkws.Client
+
+ botOpenID atomic.Value // stores string; populated lazily for @mention detection
+
+ mu sync.Mutex
+ cancel context.CancelFunc
+}
+
+func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
+ base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom,
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+
+ ch := &FeishuChannel{
+ BaseChannel: base,
+ config: cfg,
+ client: lark.NewClient(cfg.AppID, cfg.AppSecret),
+ }
+ ch.SetOwner(ch)
+ return ch, nil
+}
+
+func (c *FeishuChannel) Start(ctx context.Context) error {
+ if c.config.AppID == "" || c.config.AppSecret == "" {
+ return fmt.Errorf("feishu app_id or app_secret is empty")
+ }
+
+ // Fetch bot open_id via API for reliable @mention detection.
+ if err := c.fetchBotOpenID(ctx); err != nil {
+ logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{
+ "error": err.Error(),
+ })
+ }
+
+ dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
+ OnP2MessageReceiveV1(c.handleMessageReceive)
+
+ runCtx, cancel := context.WithCancel(ctx)
+
+ c.mu.Lock()
+ c.cancel = cancel
+ c.wsClient = larkws.NewClient(
+ c.config.AppID,
+ c.config.AppSecret,
+ larkws.WithEventHandler(dispatcher),
+ )
+ wsClient := c.wsClient
+ c.mu.Unlock()
+
+ c.SetRunning(true)
+ logger.InfoC("feishu", "Feishu channel started (websocket mode)")
+
+ go func() {
+ if err := wsClient.Start(runCtx); err != nil {
+ logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ return nil
+}
+
+func (c *FeishuChannel) Stop(ctx context.Context) error {
+ c.mu.Lock()
+ if c.cancel != nil {
+ c.cancel()
+ c.cancel = nil
+ }
+ c.wsClient = nil
+ c.mu.Unlock()
+
+ c.SetRunning(false)
+ logger.InfoC("feishu", "Feishu channel stopped")
+ return nil
+}
+
+// Send sends a message using Interactive Card format for markdown rendering.
+func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ if msg.ChatID == "" {
+ return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
+ }
+
+ // Build interactive card with markdown content
+ cardContent, err := buildMarkdownCard(msg.Content)
+ if err != nil {
+ return fmt.Errorf("feishu send: card build failed: %w", err)
+ }
+ return c.sendCard(ctx, msg.ChatID, cardContent)
+}
+
+// EditMessage implements channels.MessageEditor.
+// Uses Message.Patch to update an interactive card message.
+func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ cardContent, err := buildMarkdownCard(content)
+ if err != nil {
+ return fmt.Errorf("feishu edit: card build failed: %w", err)
+ }
+
+ req := larkim.NewPatchMessageReqBuilder().
+ MessageId(messageID).
+ Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Patch(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu edit: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+ return nil
+}
+
+// SendPlaceholder implements channels.PlaceholderCapable.
+// Sends an interactive card with placeholder text and returns its message ID.
+func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ if !c.config.Placeholder.Enabled {
+ logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
+ "chat_id": chatID,
+ })
+ return "", nil
+ }
+
+ text := c.config.Placeholder.Text
+ if text == "" {
+ text = "Thinking..."
+ }
+
+ cardContent, err := buildMarkdownCard(text)
+ if err != nil {
+ return "", fmt.Errorf("feishu placeholder: card build failed: %w", err)
+ }
+
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeInteractive).
+ Content(cardContent).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return "", fmt.Errorf("feishu placeholder send: %w", err)
+ }
+ if !resp.Success() {
+ return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+
+ if resp.Data != nil && resp.Data.MessageId != nil {
+ return *resp.Data.MessageId, nil
+ }
+ return "", nil
+}
+
+// ReactToMessage implements channels.ReactionCapable.
+// Adds an "Pin" reaction and returns an undo function to remove it.
+func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ req := larkim.NewCreateMessageReactionReqBuilder().
+ MessageId(messageID).
+ Body(larkim.NewCreateMessageReactionReqBodyBuilder().
+ ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
+ "message_id": messageID,
+ "error": err.Error(),
+ })
+ return func() {}, fmt.Errorf("feishu react: %w", err)
+ }
+ if !resp.Success() {
+ logger.ErrorCF("feishu", "Reaction API error", map[string]any{
+ "message_id": messageID,
+ "code": resp.Code,
+ "msg": resp.Msg,
+ })
+ return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+
+ var reactionID string
+ if resp.Data != nil && resp.Data.ReactionId != nil {
+ reactionID = *resp.Data.ReactionId
+ }
+ if reactionID == "" {
+ return func() {}, nil
+ }
+
+ var undone atomic.Bool
+ undo := func() {
+ if !undone.CompareAndSwap(false, true) {
+ return
+ }
+ delReq := larkim.NewDeleteMessageReactionReqBuilder().
+ MessageId(messageID).
+ ReactionId(reactionID).
+ Build()
+ _, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq)
+ }
+ return undo, nil
+}
+
+// SendMedia implements channels.MediaSender.
+// Uploads images/files via Feishu API then sends as messages.
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ if msg.ChatID == "" {
+ return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// sendMediaPart resolves and sends a single media part.
+func (c *FeishuChannel) sendMediaPart(
+ ctx context.Context,
+ chatID string,
+ part bus.MediaPart,
+ store media.MediaStore,
+) error {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ return nil // skip this part
+ }
+
+ file, err := os.Open(localPath)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ return nil // skip this part
+ }
+ defer file.Close()
+
+ switch part.Type {
+ case "image":
+ err = c.sendImage(ctx, chatID, file)
+ default:
+ filename := part.Filename
+ if filename == "" {
+ filename = "file"
+ }
+ err = c.sendFile(ctx, chatID, file, filename, part.Type)
+ }
+
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to send media", map[string]any{
+ "type": part.Type,
+ "error": err.Error(),
+ })
+ return fmt.Errorf("feishu send media: %w", channels.ErrTemporary)
+ }
+ return nil
+}
+
+// --- Inbound message handling ---
+
+func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
+ if event == nil || event.Event == nil || event.Event.Message == nil {
+ return nil
+ }
+
+ message := event.Event.Message
+ sender := event.Event.Sender
+
+ chatID := stringValue(message.ChatId)
+ if chatID == "" {
+ return nil
+ }
+
+ senderID := extractFeishuSenderID(sender)
+ if senderID == "" {
+ senderID = "unknown"
+ }
+
+ messageType := stringValue(message.MessageType)
+ messageID := stringValue(message.MessageId)
+ rawContent := stringValue(message.Content)
+
+ // Check allowlist early to avoid downloading media for rejected senders.
+ // BaseChannel.HandleMessage will check again, but this avoids wasted network I/O.
+ senderInfo := bus.SenderInfo{
+ Platform: "feishu",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("feishu", senderID),
+ }
+ if !c.IsAllowedSender(senderInfo) {
+ return nil
+ }
+
+ // Extract content based on message type
+ content := extractContent(messageType, rawContent)
+
+ // Handle media messages (download and store)
+ var mediaRefs []string
+ if store := c.GetMediaStore(); store != nil && messageID != "" {
+ mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
+ }
+
+ // Append media tags to content (like Telegram does)
+ content = appendMediaTags(content, messageType, mediaRefs)
+
+ if content == "" {
+ content = "[empty message]"
+ }
+
+ metadata := map[string]string{}
+ if messageID != "" {
+ metadata["message_id"] = messageID
+ }
+ if messageType != "" {
+ metadata["message_type"] = messageType
+ }
+ chatType := stringValue(message.ChatType)
+ if chatType != "" {
+ metadata["chat_type"] = chatType
+ }
+ if sender != nil && sender.TenantKey != nil {
+ metadata["tenant_key"] = *sender.TenantKey
+ }
+
+ var peer bus.Peer
+ if chatType == "p2p" {
+ peer = bus.Peer{Kind: "direct", ID: senderID}
+ } else {
+ peer = bus.Peer{Kind: "group", ID: chatID}
+
+ // Check if bot was mentioned
+ isMentioned := c.isBotMentioned(message)
+
+ // Strip mention placeholders from content before group trigger check
+ if len(message.Mentions) > 0 {
+ content = stripMentionPlaceholders(content, message.Mentions)
+ }
+
+ // In group chats, apply unified group trigger filtering
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
+ if !respond {
+ return nil
+ }
+ content = cleaned
+ }
+
+ logger.InfoCF("feishu", "Feishu message received", map[string]any{
+ "sender_id": senderID,
+ "chat_id": chatID,
+ "message_id": messageID,
+ "preview": utils.Truncate(content, 80),
+ })
+
+ c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
+ return nil
+}
+
+// --- Internal helpers ---
+
+// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id.
+func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
+ resp, err := c.client.Do(ctx, &larkcore.ApiReq{
+ HttpMethod: http.MethodGet,
+ ApiPath: "/open-apis/bot/v3/info",
+ SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant},
+ })
+ if err != nil {
+ return fmt.Errorf("bot info request: %w", err)
+ }
+
+ var result struct {
+ Code int `json:"code"`
+ Bot struct {
+ OpenID string `json:"open_id"`
+ } `json:"bot"`
+ }
+ if err := json.Unmarshal(resp.RawBody, &result); err != nil {
+ return fmt.Errorf("bot info parse: %w", err)
+ }
+ if result.Code != 0 {
+ return fmt.Errorf("bot info api error (code=%d)", result.Code)
+ }
+ if result.Bot.OpenID == "" {
+ return fmt.Errorf("bot info: empty open_id")
+ }
+
+ c.botOpenID.Store(result.Bot.OpenID)
+ logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{
+ "open_id": result.Bot.OpenID,
+ })
+ return nil
+}
+
+// isBotMentioned checks if the bot was @mentioned in the message.
+func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool {
+ if message.Mentions == nil {
+ return false
+ }
+
+ knownID, _ := c.botOpenID.Load().(string)
+ if knownID == "" {
+ logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil)
+ return false
+ }
+
+ for _, m := range message.Mentions {
+ if m.Id == nil {
+ continue
+ }
+ if m.Id.OpenId != nil && *m.Id.OpenId == knownID {
+ return true
+ }
+ }
+ return false
+}
+
+// extractContent extracts text content from different message types.
+func extractContent(messageType, rawContent string) string {
+ if rawContent == "" {
+ return ""
+ }
+
+ switch messageType {
+ case larkim.MsgTypeText:
+ var textPayload struct {
+ Text string `json:"text"`
+ }
+ if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil {
+ return textPayload.Text
+ }
+ return rawContent
+
+ case larkim.MsgTypePost:
+ // Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
+ return rawContent
+
+ case larkim.MsgTypeImage:
+ // Image messages don't have text content
+ return ""
+
+ case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
+ // File/audio/video messages may have a filename
+ name := extractFileName(rawContent)
+ if name != "" {
+ return name
+ }
+ return ""
+
+ default:
+ return rawContent
+ }
+}
+
+// downloadInboundMedia downloads media from inbound messages and stores in MediaStore.
+func (c *FeishuChannel) downloadInboundMedia(
+ ctx context.Context,
+ chatID, messageID, messageType, rawContent string,
+ store media.MediaStore,
+) []string {
+ var refs []string
+ scope := channels.BuildMediaScope("feishu", chatID, messageID)
+
+ switch messageType {
+ case larkim.MsgTypeImage:
+ imageKey := extractImageKey(rawContent)
+ if imageKey == "" {
+ return nil
+ }
+ ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
+ if ref != "" {
+ refs = append(refs, ref)
+ }
+
+ case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
+ fileKey := extractFileKey(rawContent)
+ if fileKey == "" {
+ return nil
+ }
+ // Derive a fallback extension from the message type.
+ var ext string
+ switch messageType {
+ case larkim.MsgTypeAudio:
+ ext = ".ogg"
+ case larkim.MsgTypeMedia:
+ ext = ".mp4"
+ default:
+ ext = "" // generic file — rely on resp.FileName
+ }
+ ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope)
+ if ref != "" {
+ refs = append(refs, ref)
+ }
+ }
+
+ return refs
+}
+
+// downloadResource downloads a message resource (image/file) from Feishu,
+// writes it to the project media directory, and stores the reference in MediaStore.
+// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
+func (c *FeishuChannel) downloadResource(
+ ctx context.Context,
+ messageID, fileKey, resourceType, fallbackExt string,
+ store media.MediaStore,
+ scope string,
+) string {
+ req := larkim.NewGetMessageResourceReqBuilder().
+ MessageId(messageID).
+ FileKey(fileKey).
+ Type(resourceType).
+ Build()
+
+ resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
+ "message_id": messageID,
+ "file_key": fileKey,
+ "error": err.Error(),
+ })
+ return ""
+ }
+ if !resp.Success() {
+ logger.ErrorCF("feishu", "Resource download api error", map[string]any{
+ "code": resp.Code,
+ "msg": resp.Msg,
+ })
+ return ""
+ }
+
+ if resp.File == nil {
+ return ""
+ }
+ // Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
+ if closer, ok := resp.File.(io.Closer); ok {
+ defer closer.Close()
+ }
+
+ filename := resp.FileName
+ if filename == "" {
+ filename = fileKey
+ }
+ // If filename still has no extension, append the fallback (like Telegram's ext parameter).
+ if filepath.Ext(filename) == "" && fallbackExt != "" {
+ filename += fallbackExt
+ }
+
+ // Write to the shared picoclaw_media directory using a unique name to avoid collisions.
+ mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
+ if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
+ logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
+ "error": mkdirErr.Error(),
+ })
+ return ""
+ }
+ ext := filepath.Ext(filename)
+ localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext))
+
+ out, err := os.Create(localPath)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{
+ "error": err.Error(),
+ })
+ return ""
+ }
+
+ if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
+ out.Close()
+ os.Remove(localPath)
+ logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
+ "error": copyErr.Error(),
+ })
+ return ""
+ }
+ out.Close()
+
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "feishu",
+ }, scope)
+ if err != nil {
+ logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{
+ "file_key": fileKey,
+ "error": err.Error(),
+ })
+ os.Remove(localPath)
+ return ""
+ }
+
+ return ref
+}
+
+// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
+func appendMediaTags(content, messageType string, mediaRefs []string) string {
+ if len(mediaRefs) == 0 {
+ return content
+ }
+
+ var tag string
+ switch messageType {
+ case larkim.MsgTypeImage:
+ tag = "[image: photo]"
+ case larkim.MsgTypeAudio:
+ tag = "[audio]"
+ case larkim.MsgTypeMedia:
+ tag = "[video]"
+ case larkim.MsgTypeFile:
+ tag = "[file]"
+ default:
+ tag = "[attachment]"
+ }
+
+ if content == "" {
+ return tag
+ }
+ return content + " " + tag
+}
+
+// sendCard sends an interactive card message to a chat.
+func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeInteractive).
+ Content(cardContent).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
+ }
+
+ if !resp.Success() {
+ return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
+ }
+
+ logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
+ "chat_id": chatID,
+ })
+
+ return nil
+}
+
+// sendImage uploads an image and sends it as a message.
+func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error {
+ // Upload image to get image_key
+ uploadReq := larkim.NewCreateImageReqBuilder().
+ Body(larkim.NewCreateImageReqBodyBuilder().
+ ImageType("message").
+ Image(file).
+ Build()).
+ Build()
+
+ uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq)
+ if err != nil {
+ return fmt.Errorf("feishu image upload: %w", err)
+ }
+ if !uploadResp.Success() {
+ return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
+ }
+ if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
+ return fmt.Errorf("feishu image upload: no image_key returned")
+ }
+
+ imageKey := *uploadResp.Data.ImageKey
+
+ // Send image message
+ content, _ := json.Marshal(map[string]string{"image_key": imageKey})
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeImage).
+ Content(string(content)).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu image send: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+ return nil
+}
+
+// sendFile uploads a file and sends it as a message.
+func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error {
+ // Map part type to Feishu file type
+ feishuFileType := "stream"
+ switch fileType {
+ case "audio":
+ feishuFileType = "opus"
+ case "video":
+ feishuFileType = "mp4"
+ }
+
+ // Upload file to get file_key
+ uploadReq := larkim.NewCreateFileReqBuilder().
+ Body(larkim.NewCreateFileReqBodyBuilder().
+ FileType(feishuFileType).
+ FileName(filename).
+ File(file).
+ Build()).
+ Build()
+
+ uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq)
+ if err != nil {
+ return fmt.Errorf("feishu file upload: %w", err)
+ }
+ if !uploadResp.Success() {
+ return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
+ }
+ if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
+ return fmt.Errorf("feishu file upload: no file_key returned")
+ }
+
+ fileKey := *uploadResp.Data.FileKey
+
+ // Send file message
+ content, _ := json.Marshal(map[string]string{"file_key": fileKey})
+ req := larkim.NewCreateMessageReqBuilder().
+ ReceiveIdType(larkim.ReceiveIdTypeChatId).
+ Body(larkim.NewCreateMessageReqBodyBuilder().
+ ReceiveId(chatID).
+ MsgType(larkim.MsgTypeFile).
+ Content(string(content)).
+ Build()).
+ Build()
+
+ resp, err := c.client.Im.V1.Message.Create(ctx, req)
+ if err != nil {
+ return fmt.Errorf("feishu file send: %w", err)
+ }
+ if !resp.Success() {
+ return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
+ }
+ return nil
+}
+
+func extractFeishuSenderID(sender *larkim.EventSender) string {
+ if sender == nil || sender.SenderId == nil {
+ return ""
+ }
+
+ if sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" {
+ return *sender.SenderId.UserId
+ }
+ if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" {
+ return *sender.SenderId.OpenId
+ }
+ if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" {
+ return *sender.SenderId.UnionId
+ }
+
+ return ""
+}
diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go
new file mode 100644
index 000000000..dc3eab2e7
--- /dev/null
+++ b/pkg/channels/feishu/feishu_64_test.go
@@ -0,0 +1,256 @@
+//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
+
+package feishu
+
+import (
+ "testing"
+
+ larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+)
+
+func TestExtractContent(t *testing.T) {
+ tests := []struct {
+ name string
+ messageType string
+ rawContent string
+ want string
+ }{
+ {
+ name: "text message",
+ messageType: "text",
+ rawContent: `{"text": "hello world"}`,
+ want: "hello world",
+ },
+ {
+ name: "text message invalid JSON",
+ messageType: "text",
+ rawContent: `not json`,
+ want: "not json",
+ },
+ {
+ name: "post message returns raw JSON",
+ messageType: "post",
+ rawContent: `{"title": "test post"}`,
+ want: `{"title": "test post"}`,
+ },
+ {
+ name: "image message returns empty",
+ messageType: "image",
+ rawContent: `{"image_key": "img_xxx"}`,
+ want: "",
+ },
+ {
+ name: "file message with filename",
+ messageType: "file",
+ rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
+ want: "report.pdf",
+ },
+ {
+ name: "file message without filename",
+ messageType: "file",
+ rawContent: `{"file_key": "file_xxx"}`,
+ want: "",
+ },
+ {
+ name: "audio message with filename",
+ messageType: "audio",
+ rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`,
+ want: "recording.ogg",
+ },
+ {
+ name: "media message with filename",
+ messageType: "media",
+ rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`,
+ want: "video.mp4",
+ },
+ {
+ name: "unknown message type returns raw",
+ messageType: "sticker",
+ rawContent: `{"sticker_id": "sticker_xxx"}`,
+ want: `{"sticker_id": "sticker_xxx"}`,
+ },
+ {
+ name: "empty raw content",
+ messageType: "text",
+ rawContent: "",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractContent(tt.messageType, tt.rawContent)
+ if got != tt.want {
+ t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAppendMediaTags(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ messageType string
+ mediaRefs []string
+ want string
+ }{
+ {
+ name: "no refs returns content unchanged",
+ content: "hello",
+ messageType: "image",
+ mediaRefs: nil,
+ want: "hello",
+ },
+ {
+ name: "empty refs returns content unchanged",
+ content: "hello",
+ messageType: "image",
+ mediaRefs: []string{},
+ want: "hello",
+ },
+ {
+ name: "image with content",
+ content: "check this",
+ messageType: "image",
+ mediaRefs: []string{"ref1"},
+ want: "check this [image: photo]",
+ },
+ {
+ name: "image empty content",
+ content: "",
+ messageType: "image",
+ mediaRefs: []string{"ref1"},
+ want: "[image: photo]",
+ },
+ {
+ name: "audio",
+ content: "listen",
+ messageType: "audio",
+ mediaRefs: []string{"ref1"},
+ want: "listen [audio]",
+ },
+ {
+ name: "media/video",
+ content: "watch",
+ messageType: "media",
+ mediaRefs: []string{"ref1"},
+ want: "watch [video]",
+ },
+ {
+ name: "file",
+ content: "report.pdf",
+ messageType: "file",
+ mediaRefs: []string{"ref1"},
+ want: "report.pdf [file]",
+ },
+ {
+ name: "unknown type",
+ content: "something",
+ messageType: "sticker",
+ mediaRefs: []string{"ref1"},
+ want: "something [attachment]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs)
+ if got != tt.want {
+ t.Errorf(
+ "appendMediaTags(%q, %q, %v) = %q, want %q",
+ tt.content,
+ tt.messageType,
+ tt.mediaRefs,
+ got,
+ tt.want,
+ )
+ }
+ })
+ }
+}
+
+func TestExtractFeishuSenderID(t *testing.T) {
+ strPtr := func(s string) *string { return &s }
+
+ tests := []struct {
+ name string
+ sender *larkim.EventSender
+ want string
+ }{
+ {
+ name: "nil sender",
+ sender: nil,
+ want: "",
+ },
+ {
+ name: "nil sender ID",
+ sender: &larkim.EventSender{SenderId: nil},
+ want: "",
+ },
+ {
+ name: "userId preferred",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr("u_abc123"),
+ OpenId: strPtr("ou_def456"),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "u_abc123",
+ },
+ {
+ name: "openId fallback",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr("ou_def456"),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "ou_def456",
+ },
+ {
+ name: "unionId fallback",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr(""),
+ UnionId: strPtr("on_ghi789"),
+ },
+ },
+ want: "on_ghi789",
+ },
+ {
+ name: "all empty strings",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: strPtr(""),
+ OpenId: strPtr(""),
+ UnionId: strPtr(""),
+ },
+ },
+ want: "",
+ },
+ {
+ name: "nil userId pointer falls through",
+ sender: &larkim.EventSender{
+ SenderId: &larkim.UserId{
+ UserId: nil,
+ OpenId: strPtr("ou_def456"),
+ UnionId: nil,
+ },
+ },
+ want: "ou_def456",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := extractFeishuSenderID(tt.sender)
+ if got != tt.want {
+ t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go
new file mode 100644
index 000000000..7e5a62dae
--- /dev/null
+++ b/pkg/channels/feishu/init.go
@@ -0,0 +1,13 @@
+package feishu
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewFeishuChannel(cfg.Channels.Feishu, b)
+ })
+}
diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu_64.go
deleted file mode 100644
index 42e74980f..000000000
--- a/pkg/channels/feishu_64.go
+++ /dev/null
@@ -1,227 +0,0 @@
-//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
-
-package channels
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "sync"
- "time"
-
- lark "github.com/larksuite/oapi-sdk-go/v3"
- larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
- larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
- larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
-
- "github.com/sipeed/picoclaw/pkg/bus"
- "github.com/sipeed/picoclaw/pkg/config"
- "github.com/sipeed/picoclaw/pkg/logger"
- "github.com/sipeed/picoclaw/pkg/utils"
-)
-
-type FeishuChannel struct {
- *BaseChannel
- config config.FeishuConfig
- client *lark.Client
- wsClient *larkws.Client
-
- mu sync.Mutex
- cancel context.CancelFunc
-}
-
-func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
- base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom)
-
- return &FeishuChannel{
- BaseChannel: base,
- config: cfg,
- client: lark.NewClient(cfg.AppID, cfg.AppSecret),
- }, nil
-}
-
-func (c *FeishuChannel) Start(ctx context.Context) error {
- if c.config.AppID == "" || c.config.AppSecret == "" {
- return fmt.Errorf("feishu app_id or app_secret is empty")
- }
-
- dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
- OnP2MessageReceiveV1(c.handleMessageReceive)
-
- runCtx, cancel := context.WithCancel(ctx)
-
- c.mu.Lock()
- c.cancel = cancel
- c.wsClient = larkws.NewClient(
- c.config.AppID,
- c.config.AppSecret,
- larkws.WithEventHandler(dispatcher),
- )
- wsClient := c.wsClient
- c.mu.Unlock()
-
- c.setRunning(true)
- logger.InfoC("feishu", "Feishu channel started (websocket mode)")
-
- go func() {
- if err := wsClient.Start(runCtx); err != nil {
- logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{
- "error": err.Error(),
- })
- }
- }()
-
- return nil
-}
-
-func (c *FeishuChannel) Stop(ctx context.Context) error {
- c.mu.Lock()
- if c.cancel != nil {
- c.cancel()
- c.cancel = nil
- }
- c.wsClient = nil
- c.mu.Unlock()
-
- c.setRunning(false)
- logger.InfoC("feishu", "Feishu channel stopped")
- return nil
-}
-
-func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
- if !c.IsRunning() {
- return fmt.Errorf("feishu channel not running")
- }
-
- if msg.ChatID == "" {
- return fmt.Errorf("chat ID is empty")
- }
-
- payload, err := json.Marshal(map[string]string{"text": msg.Content})
- if err != nil {
- return fmt.Errorf("failed to marshal feishu content: %w", err)
- }
-
- req := larkim.NewCreateMessageReqBuilder().
- ReceiveIdType(larkim.ReceiveIdTypeChatId).
- Body(larkim.NewCreateMessageReqBodyBuilder().
- ReceiveId(msg.ChatID).
- MsgType(larkim.MsgTypeText).
- Content(string(payload)).
- Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
- Build()).
- Build()
-
- resp, err := c.client.Im.V1.Message.Create(ctx, req)
- if err != nil {
- return fmt.Errorf("failed to send feishu message: %w", err)
- }
-
- if !resp.Success() {
- return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg)
- }
-
- logger.DebugCF("feishu", "Feishu message sent", map[string]any{
- "chat_id": msg.ChatID,
- })
-
- return nil
-}
-
-func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error {
- if event == nil || event.Event == nil || event.Event.Message == nil {
- return nil
- }
-
- message := event.Event.Message
- sender := event.Event.Sender
-
- chatID := stringValue(message.ChatId)
- if chatID == "" {
- return nil
- }
-
- senderID := extractFeishuSenderID(sender)
- if senderID == "" {
- senderID = "unknown"
- }
-
- content := extractFeishuMessageContent(message)
- if content == "" {
- content = "[empty message]"
- }
-
- metadata := map[string]string{}
- if messageID := stringValue(message.MessageId); messageID != "" {
- metadata["message_id"] = messageID
- }
- if messageType := stringValue(message.MessageType); messageType != "" {
- metadata["message_type"] = messageType
- }
- if chatType := stringValue(message.ChatType); chatType != "" {
- metadata["chat_type"] = chatType
- }
- if sender != nil && sender.TenantKey != nil {
- metadata["tenant_key"] = *sender.TenantKey
- }
-
- chatType := stringValue(message.ChatType)
- if chatType == "p2p" {
- metadata["peer_kind"] = "direct"
- metadata["peer_id"] = senderID
- } else {
- metadata["peer_kind"] = "group"
- metadata["peer_id"] = chatID
- }
-
- logger.InfoCF("feishu", "Feishu message received", map[string]any{
- "sender_id": senderID,
- "chat_id": chatID,
- "preview": utils.Truncate(content, 80),
- })
-
- c.HandleMessage(senderID, chatID, content, nil, metadata)
- return nil
-}
-
-func extractFeishuSenderID(sender *larkim.EventSender) string {
- if sender == nil || sender.SenderId == nil {
- return ""
- }
-
- if sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" {
- return *sender.SenderId.UserId
- }
- if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" {
- return *sender.SenderId.OpenId
- }
- if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" {
- return *sender.SenderId.UnionId
- }
-
- return ""
-}
-
-func extractFeishuMessageContent(message *larkim.EventMessage) string {
- if message == nil || message.Content == nil || *message.Content == "" {
- return ""
- }
-
- if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
- var textPayload struct {
- Text string `json:"text"`
- }
- if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
- return textPayload.Text
- }
- }
-
- return *message.Content
-}
-
-func stringValue(v *string) string {
- if v == nil {
- return ""
- }
- return *v
-}
diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go
new file mode 100644
index 000000000..74caeeac5
--- /dev/null
+++ b/pkg/channels/interfaces.go
@@ -0,0 +1,41 @@
+package channels
+
+import "context"
+
+// TypingCapable — channels that can show a typing/thinking indicator.
+// StartTyping begins the indicator and returns a stop function.
+// The stop function MUST be idempotent and safe to call multiple times.
+type TypingCapable interface {
+ StartTyping(ctx context.Context, chatID string) (stop func(), err error)
+}
+
+// MessageEditor — channels that can edit an existing message.
+// messageID is always string; channels convert platform-specific types internally.
+type MessageEditor interface {
+ EditMessage(ctx context.Context, chatID string, messageID string, content string) error
+}
+
+// ReactionCapable — channels that can add a reaction (e.g. 👀) to an inbound message.
+// ReactToMessage adds a reaction and returns an undo function to remove it.
+// The undo function MUST be idempotent and safe to call multiple times.
+type ReactionCapable interface {
+ ReactToMessage(ctx context.Context, chatID, messageID string) (undo func(), err error)
+}
+
+// PlaceholderCapable — channels that can send a placeholder message
+// (e.g. "Thinking... 💭") that will later be edited to the actual response.
+// The channel MUST also implement MessageEditor for the placeholder to be useful.
+// SendPlaceholder returns the platform message ID of the placeholder so that
+// Manager.preSend can later edit it via MessageEditor.EditMessage.
+type PlaceholderCapable interface {
+ SendPlaceholder(ctx context.Context, chatID string) (messageID string, err error)
+}
+
+// PlaceholderRecorder is injected into channels by Manager.
+// Channels call these methods on inbound to register typing/placeholder state.
+// Manager uses the registered state on outbound to stop typing and edit placeholders.
+type PlaceholderRecorder interface {
+ RecordPlaceholder(channel, chatID, placeholderID string)
+ RecordTypingStop(channel, chatID string, stop func())
+ RecordReactionUndo(channel, chatID string, undo func())
+}
diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go
new file mode 100644
index 000000000..9265575cc
--- /dev/null
+++ b/pkg/channels/line/init.go
@@ -0,0 +1,13 @@
+package line
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewLINEChannel(cfg.Channels.LINE, b)
+ })
+}
diff --git a/pkg/channels/line.go b/pkg/channels/line/line.go
similarity index 72%
rename from pkg/channels/line.go
rename to pkg/channels/line/line.go
index 44134996f..398f12e6b 100644
--- a/pkg/channels/line.go
+++ b/pkg/channels/line/line.go
@@ -1,4 +1,4 @@
-package channels
+package line
import (
"bytes"
@@ -10,14 +10,16 @@ import (
"fmt"
"io"
"net/http"
- "os"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -41,14 +43,15 @@ type replyTokenEntry struct {
// using the LINE Messaging API with HTTP webhook for receiving messages
// and REST API for sending messages.
type LINEChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.LINEConfig
- httpServer *http.Server
- botUserID string // Bot's user ID
- botBasicID string // Bot's basic ID (e.g. @216ru...)
- botDisplayName string // Bot's display name for text-based mention detection
- replyTokens sync.Map // chatID -> replyTokenEntry
- quoteTokens sync.Map // chatID -> quoteToken (string)
+ infoClient *http.Client // for bot info lookups (short timeout)
+ apiClient *http.Client // for messaging API calls
+ botUserID string // Bot's user ID
+ botBasicID string // Bot's basic ID (e.g. @216ru...)
+ botDisplayName string // Bot's display name for text-based mention detection
+ replyTokens sync.Map // chatID -> replyTokenEntry
+ quoteTokens sync.Map // chatID -> quoteToken (string)
ctx context.Context
cancel context.CancelFunc
}
@@ -59,15 +62,21 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
}
- base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(5000),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &LINEChannel{
BaseChannel: base,
config: cfg,
+ infoClient: &http.Client{Timeout: 10 * time.Second},
+ apiClient: &http.Client{Timeout: 30 * time.Second},
}, nil
}
-// Start launches the HTTP webhook server.
+// Start initializes the LINE channel.
func (c *LINEChannel) Start(ctx context.Context) error {
logger.InfoC("line", "Starting LINE channel (Webhook Mode)")
@@ -86,32 +95,7 @@ func (c *LINEChannel) Start(ctx context.Context) error {
})
}
- mux := http.NewServeMux()
- path := c.config.WebhookPath
- if path == "" {
- path = "/webhook/line"
- }
- mux.HandleFunc(path, c.webhookHandler)
-
- addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
- c.httpServer = &http.Server{
- Addr: addr,
- Handler: mux,
- }
-
- go func() {
- logger.InfoCF("line", "LINE webhook server listening", map[string]any{
- "addr": addr,
- "path": path,
- })
- if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.ErrorCF("line", "Webhook server error", map[string]any{
- "error": err.Error(),
- })
- }
- }()
-
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoC("line", "LINE channel started (Webhook Mode)")
return nil
}
@@ -124,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error {
}
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
- client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.infoClient.Do(req)
if err != nil {
return err
}
@@ -150,7 +133,7 @@ func (c *LINEChannel) fetchBotInfo() error {
return nil
}
-// Stop gracefully shuts down the HTTP server.
+// Stop gracefully stops the LINE channel.
func (c *LINEChannel) Stop(ctx context.Context) error {
logger.InfoC("line", "Stopping LINE channel")
@@ -158,21 +141,24 @@ func (c *LINEChannel) Stop(ctx context.Context) error {
c.cancel()
}
- if c.httpServer != nil {
- shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
- if err := c.httpServer.Shutdown(shutdownCtx); err != nil {
- logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{
- "error": err.Error(),
- })
- }
- }
-
- c.setRunning(false)
+ c.SetRunning(false)
logger.InfoC("line", "LINE channel stopped")
return nil
}
+// WebhookPath returns the path for registering on the shared HTTP server.
+func (c *LINEChannel) WebhookPath() string {
+ if c.config.WebhookPath != "" {
+ return c.config.WebhookPath
+ }
+ return "/webhook/line"
+}
+
+// ServeHTTP implements http.Handler for the shared HTTP server.
+func (c *LINEChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ c.webhookHandler(w, r)
+}
+
// webhookHandler handles incoming LINE webhook requests.
func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -284,14 +270,6 @@ func (c *LINEChannel) processEvent(event lineEvent) {
return
}
- // In group chats, only respond when the bot is mentioned
- if isGroup && !c.isBotMentioned(msg) {
- logger.DebugCF("line", "Ignoring group message without mention", map[string]any{
- "chat_id": chatID,
- })
- return
- }
-
// Store reply token for later use
if event.ReplyToken != "" {
c.replyTokens.Store(chatID, replyTokenEntry{
@@ -307,18 +285,22 @@ func (c *LINEChannel) processEvent(event lineEvent) {
var content string
var mediaPaths []string
- localFiles := []string{}
- defer func() {
- for _, file := range localFiles {
- if err := os.Remove(file); err != nil {
- logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{
- "file": file,
- "error": err.Error(),
- })
+ scope := channels.BuildMediaScope("line", chatID, msg.ID)
+
+ // Helper to register a local file with the media store
+ storeMedia := func(localPath, filename string) string {
+ if store := c.GetMediaStore(); store != nil {
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "line",
+ }, scope)
+ if err == nil {
+ return ref
}
}
- }()
+ return localPath // fallback
+ }
switch msg.Type {
case "text":
@@ -330,22 +312,19 @@ func (c *LINEChannel) processEvent(event lineEvent) {
case "image":
localPath := c.downloadContent(msg.ID, "image.jpg")
if localPath != "" {
- localFiles = append(localFiles, localPath)
- mediaPaths = append(mediaPaths, localPath)
+ mediaPaths = append(mediaPaths, storeMedia(localPath, "image.jpg"))
content = "[image]"
}
case "audio":
localPath := c.downloadContent(msg.ID, "audio.m4a")
if localPath != "" {
- localFiles = append(localFiles, localPath)
- mediaPaths = append(mediaPaths, localPath)
+ mediaPaths = append(mediaPaths, storeMedia(localPath, "audio.m4a"))
content = "[audio]"
}
case "video":
localPath := c.downloadContent(msg.ID, "video.mp4")
if localPath != "" {
- localFiles = append(localFiles, localPath)
- mediaPaths = append(mediaPaths, localPath)
+ mediaPaths = append(mediaPaths, storeMedia(localPath, "video.mp4"))
content = "[video]"
}
case "file":
@@ -360,18 +339,29 @@ func (c *LINEChannel) processEvent(event lineEvent) {
return
}
+ // In group chats, apply unified group trigger filtering
+ if isGroup {
+ isMentioned := c.isBotMentioned(msg)
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
+ if !respond {
+ logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{
+ "chat_id": chatID,
+ })
+ return
+ }
+ content = cleaned
+ }
+
metadata := map[string]string{
"platform": "line",
"source_type": event.Source.Type,
- "message_id": msg.ID,
}
+ var peer bus.Peer
if isGroup {
- metadata["peer_kind"] = "group"
- metadata["peer_id"] = chatID
+ peer = bus.Peer{Kind: "group", ID: chatID}
} else {
- metadata["peer_kind"] = "direct"
- metadata["peer_id"] = senderID
+ peer = bus.Peer{Kind: "direct", ID: senderID}
}
logger.DebugCF("line", "Received message", map[string]any{
@@ -382,10 +372,17 @@ func (c *LINEChannel) processEvent(event lineEvent) {
"preview": utils.Truncate(content, 50),
})
- // Show typing/loading indicator (requires user ID, not group ID)
- c.sendLoading(senderID)
+ sender := bus.SenderInfo{
+ Platform: "line",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("line", senderID),
+ }
- c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
+ c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender)
}
// isBotMentioned checks if the bot is mentioned in the message.
@@ -491,7 +488,7 @@ func (c *LINEChannel) resolveChatID(source lineSource) string {
// using a cached reply token, then falls back to the Push API.
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("line channel not running")
+ return channels.ErrNotRunning
}
// Load and consume quote token for this chat
@@ -519,6 +516,36 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken)
}
+// SendMedia implements the channels.MediaSender interface.
+// LINE requires media to be accessible via public URL; since we only have local files,
+// we fall back to sending a text message with the filename/caption.
+// For full support, an external file hosting service would be needed.
+func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ // LINE Messaging API requires publicly accessible URLs for media messages.
+ // Since we only have local file paths, send caption text as fallback.
+ for _, part := range msg.Parts {
+ caption := part.Caption
+ if caption == "" {
+ caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
+ }
+
+ if err := c.sendPush(ctx, msg.ChatID, caption, ""); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// buildTextMessage creates a text message object, optionally with quoteToken.
func buildTextMessage(content, quoteToken string) map[string]string {
msg := map[string]string{
@@ -551,17 +578,58 @@ func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken stri
return c.callAPI(ctx, linePushEndpoint, payload)
}
+// StartTyping implements channels.TypingCapable using LINE's loading animation.
+//
+// NOTE: The LINE loading animation API only works for 1:1 chats.
+// Group/room chat IDs (starting with "C" or "R") are detected automatically;
+// for these, a no-op stop function is returned without calling the API.
+func (c *LINEChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
+ if chatID == "" {
+ return func() {}, nil
+ }
+
+ // Group/room chats: LINE loading animation is 1:1 only.
+ if strings.HasPrefix(chatID, "C") || strings.HasPrefix(chatID, "R") {
+ return func() {}, nil
+ }
+
+ typingCtx, cancel := context.WithCancel(ctx)
+ var once sync.Once
+ stop := func() { once.Do(cancel) }
+
+ // Send immediately, then refresh periodically for long-running tasks.
+ if err := c.sendLoading(typingCtx, chatID); err != nil {
+ stop()
+ return stop, err
+ }
+
+ ticker := time.NewTicker(50 * time.Second)
+ go func() {
+ defer ticker.Stop()
+ for {
+ select {
+ case <-typingCtx.Done():
+ return
+ case <-ticker.C:
+ if err := c.sendLoading(typingCtx, chatID); err != nil {
+ logger.DebugCF("line", "Failed to refresh loading indicator", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }
+ }
+ }()
+
+ return stop, nil
+}
+
// sendLoading sends a loading animation indicator to the chat.
-func (c *LINEChannel) sendLoading(chatID string) {
+func (c *LINEChannel) sendLoading(ctx context.Context, chatID string) error {
payload := map[string]any{
"chatId": chatID,
"loadingSeconds": 60,
}
- if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil {
- logger.DebugCF("line", "Failed to send loading indicator", map[string]any{
- "error": err.Error(),
- })
- }
+ return c.callAPI(ctx, lineLoadingEndpoint, payload)
}
// callAPI makes an authenticated POST request to the LINE API.
@@ -579,16 +647,15 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
- client := &http.Client{Timeout: 30 * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.apiClient.Do(req)
if err != nil {
- return fmt.Errorf("API request failed: %w", err)
+ return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody))
+ return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody)))
}
return nil
diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go
new file mode 100644
index 000000000..5a269b22b
--- /dev/null
+++ b/pkg/channels/maixcam/init.go
@@ -0,0 +1,13 @@
+package maixcam
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewMaixCamChannel(cfg.Channels.MaixCam, b)
+ })
+}
diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam/maixcam.go
similarity index 78%
rename from pkg/channels/maixcam.go
rename to pkg/channels/maixcam/maixcam.go
index 34ce62b20..ff9a3ed1a 100644
--- a/pkg/channels/maixcam.go
+++ b/pkg/channels/maixcam/maixcam.go
@@ -1,4 +1,4 @@
-package channels
+package maixcam
import (
"context"
@@ -6,16 +6,21 @@ import (
"fmt"
"net"
"sync"
+ "time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
type MaixCamChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.MaixCamConfig
listener net.Listener
+ ctx context.Context
+ cancel context.CancelFunc
clients map[net.Conn]bool
clientsMux sync.RWMutex
}
@@ -28,7 +33,13 @@ type MaixCamMessage struct {
}
func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) {
- base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom)
+ base := channels.NewBaseChannel(
+ "maixcam",
+ cfg,
+ bus,
+ cfg.AllowFrom,
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &MaixCamChannel{
BaseChannel: base,
@@ -40,37 +51,40 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
func (c *MaixCamChannel) Start(ctx context.Context) error {
logger.InfoC("maixcam", "Starting MaixCam channel server")
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
listener, err := net.Listen("tcp", addr)
if err != nil {
+ c.cancel()
return fmt.Errorf("failed to listen on %s: %w", addr, err)
}
c.listener = listener
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{
"host": c.config.Host,
"port": c.config.Port,
})
- go c.acceptConnections(ctx)
+ go c.acceptConnections()
return nil
}
-func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
+func (c *MaixCamChannel) acceptConnections() {
logger.DebugC("maixcam", "Starting connection acceptor")
for {
select {
- case <-ctx.Done():
+ case <-c.ctx.Done():
logger.InfoC("maixcam", "Stopping connection acceptor")
return
default:
conn, err := c.listener.Accept()
if err != nil {
- if c.running {
+ if c.IsRunning() {
logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{
"error": err.Error(),
})
@@ -86,12 +100,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
c.clients[conn] = true
c.clientsMux.Unlock()
- go c.handleConnection(conn, ctx)
+ go c.handleConnection(conn)
}
}
}
-func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
+func (c *MaixCamChannel) handleConnection(conn net.Conn) {
logger.DebugC("maixcam", "Handling MaixCam connection")
defer func() {
@@ -106,7 +120,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
for {
select {
- case <-ctx.Done():
+ case <-c.ctx.Done():
return
default:
var msg MaixCamMessage
@@ -170,11 +184,29 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
"y": fmt.Sprintf("%.0f", y),
"w": fmt.Sprintf("%.0f", w),
"h": fmt.Sprintf("%.0f", h),
- "peer_kind": "channel",
- "peer_id": "default",
}
- c.HandleMessage(senderID, chatID, content, []string{}, metadata)
+ sender := bus.SenderInfo{
+ Platform: "maixcam",
+ PlatformID: "maixcam",
+ CanonicalID: identity.BuildCanonicalID("maixcam", "maixcam"),
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
+ c.HandleMessage(
+ c.ctx,
+ bus.Peer{Kind: "channel", ID: "default"},
+ "",
+ senderID,
+ chatID,
+ content,
+ []string{},
+ metadata,
+ sender,
+ )
}
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
@@ -185,7 +217,12 @@ func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
func (c *MaixCamChannel) Stop(ctx context.Context) error {
logger.InfoC("maixcam", "Stopping MaixCam channel")
- c.setRunning(false)
+ c.SetRunning(false)
+
+ // Cancel context first to signal goroutines to exit
+ if c.cancel != nil {
+ c.cancel()
+ }
if c.listener != nil {
c.listener.Close()
@@ -205,7 +242,14 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error {
func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("maixcam channel not running")
+ return channels.ErrNotRunning
+ }
+
+ // Check ctx before entering write path
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
}
c.clientsMux.RLock()
@@ -230,13 +274,15 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
var sendErr error
for conn := range c.clients {
+ _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := conn.Write(data); err != nil {
logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{
"client": conn.RemoteAddr().String(),
"error": err.Error(),
})
- sendErr = err
+ sendErr = fmt.Errorf("maixcam send: %w", channels.ErrTemporary)
}
+ _ = conn.SetWriteDeadline(time.Time{})
}
return sendErr
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index 75edaf49e..fdd6d0c1f 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -8,32 +8,152 @@ package channels
import (
"context"
+ "errors"
"fmt"
+ "math"
+ "net/http"
"sync"
+ "time"
+
+ "golang.org/x/time/rate"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
+ "github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
)
+const (
+ defaultChannelQueueSize = 16
+ defaultRateLimit = 10 // default 10 msg/s
+ maxRetries = 3
+ rateLimitDelay = 1 * time.Second
+ baseBackoff = 500 * time.Millisecond
+ maxBackoff = 8 * time.Second
+
+ janitorInterval = 10 * time.Second
+ typingStopTTL = 5 * time.Minute
+ placeholderTTL = 10 * time.Minute
+)
+
+// typingEntry wraps a typing stop function with a creation timestamp for TTL eviction.
+type typingEntry struct {
+ stop func()
+ createdAt time.Time
+}
+
+// reactionEntry wraps a reaction undo function with a creation timestamp for TTL eviction.
+type reactionEntry struct {
+ undo func()
+ createdAt time.Time
+}
+
+// placeholderEntry wraps a placeholder ID with a creation timestamp for TTL eviction.
+type placeholderEntry struct {
+ id string
+ createdAt time.Time
+}
+
+// channelRateConfig maps channel name to per-second rate limit.
+var channelRateConfig = map[string]float64{
+ "telegram": 20,
+ "discord": 1,
+ "slack": 1,
+ "line": 10,
+}
+
+type channelWorker struct {
+ ch Channel
+ queue chan bus.OutboundMessage
+ mediaQueue chan bus.OutboundMediaMessage
+ done chan struct{}
+ mediaDone chan struct{}
+ limiter *rate.Limiter
+}
+
type Manager struct {
- channels map[string]Channel
- bus *bus.MessageBus
- config *config.Config
- dispatchTask *asyncTask
- mu sync.RWMutex
+ channels map[string]Channel
+ workers map[string]*channelWorker
+ bus *bus.MessageBus
+ config *config.Config
+ mediaStore media.MediaStore
+ dispatchTask *asyncTask
+ mux *http.ServeMux
+ httpServer *http.Server
+ mu sync.RWMutex
+ placeholders sync.Map // "channel:chatID" → placeholderID (string)
+ typingStops sync.Map // "channel:chatID" → func()
+ reactionUndos sync.Map // "channel:chatID" → reactionEntry
}
type asyncTask struct {
cancel context.CancelFunc
}
-func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) {
+// RecordPlaceholder registers a placeholder message for later editing.
+// Implements PlaceholderRecorder.
+func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
+ key := channel + ":" + chatID
+ m.placeholders.Store(key, placeholderEntry{id: placeholderID, createdAt: time.Now()})
+}
+
+// RecordTypingStop registers a typing stop function for later invocation.
+// Implements PlaceholderRecorder.
+func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) {
+ key := channel + ":" + chatID
+ m.typingStops.Store(key, typingEntry{stop: stop, createdAt: time.Now()})
+}
+
+// RecordReactionUndo registers a reaction undo function for later invocation.
+// Implements PlaceholderRecorder.
+func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
+ key := channel + ":" + chatID
+ m.reactionUndos.Store(key, reactionEntry{undo: undo, createdAt: time.Now()})
+}
+
+// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
+// Returns true if the message was edited into a placeholder (skip Send).
+func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool {
+ key := name + ":" + msg.ChatID
+
+ // 1. Stop typing
+ if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
+ if entry, ok := v.(typingEntry); ok {
+ entry.stop() // idempotent, safe
+ }
+ }
+
+ // 2. Undo reaction
+ if v, loaded := m.reactionUndos.LoadAndDelete(key); loaded {
+ if entry, ok := v.(reactionEntry); ok {
+ entry.undo() // idempotent, safe
+ }
+ }
+
+ // 3. Try editing placeholder
+ if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
+ if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
+ if editor, ok := ch.(MessageEditor); ok {
+ if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
+ return true // edited successfully, skip Send
+ }
+ // edit failed → fall through to normal Send
+ }
+ }
+ }
+
+ return false
+}
+
+func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
m := &Manager{
- channels: make(map[string]Channel),
- bus: messageBus,
- config: cfg,
+ channels: make(map[string]Channel),
+ workers: make(map[string]*channelWorker),
+ bus: messageBus,
+ config: cfg,
+ mediaStore: store,
}
if err := m.initChannels(); err != nil {
@@ -43,163 +163,108 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error
return m, nil
}
+// initChannel is a helper that looks up a factory by name and creates the channel.
+func (m *Manager) initChannel(name, displayName string) {
+ f, ok := getFactory(name)
+ if !ok {
+ logger.WarnCF("channels", "Factory not registered", map[string]any{
+ "channel": displayName,
+ })
+ return
+ }
+ logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{
+ "channel": displayName,
+ })
+ ch, err := f(m.config, m.bus)
+ if err != nil {
+ logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{
+ "channel": displayName,
+ "error": err.Error(),
+ })
+ } else {
+ // Inject MediaStore if channel supports it
+ if m.mediaStore != nil {
+ if setter, ok := ch.(interface{ SetMediaStore(s media.MediaStore) }); ok {
+ setter.SetMediaStore(m.mediaStore)
+ }
+ }
+ // Inject PlaceholderRecorder if channel supports it
+ if setter, ok := ch.(interface{ SetPlaceholderRecorder(r PlaceholderRecorder) }); ok {
+ setter.SetPlaceholderRecorder(m)
+ }
+ // Inject owner reference so BaseChannel.HandleMessage can auto-trigger typing/reaction
+ if setter, ok := ch.(interface{ SetOwner(ch Channel) }); ok {
+ setter.SetOwner(ch)
+ }
+ m.channels[name] = ch
+ logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
+ "channel": displayName,
+ })
+ }
+}
+
func (m *Manager) initChannels() error {
logger.InfoC("channels", "Initializing channel manager")
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
- logger.DebugC("channels", "Attempting to initialize Telegram channel")
- telegram, err := NewTelegramChannel(m.config, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["telegram"] = telegram
- logger.InfoC("channels", "Telegram channel enabled successfully")
- }
+ m.initChannel("telegram", "Telegram")
}
- if m.config.Channels.WhatsApp.Enabled && m.config.Channels.WhatsApp.BridgeURL != "" {
- logger.DebugC("channels", "Attempting to initialize WhatsApp channel")
- whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["whatsapp"] = whatsapp
- logger.InfoC("channels", "WhatsApp channel enabled successfully")
+ if m.config.Channels.WhatsApp.Enabled {
+ waCfg := m.config.Channels.WhatsApp
+ if waCfg.UseNative {
+ m.initChannel("whatsapp_native", "WhatsApp Native")
+ } else if waCfg.BridgeURL != "" {
+ m.initChannel("whatsapp", "WhatsApp")
}
}
if m.config.Channels.Feishu.Enabled {
- logger.DebugC("channels", "Attempting to initialize Feishu channel")
- feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["feishu"] = feishu
- logger.InfoC("channels", "Feishu channel enabled successfully")
- }
+ m.initChannel("feishu", "Feishu")
}
if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" {
- logger.DebugC("channels", "Attempting to initialize Discord channel")
- discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["discord"] = discord
- logger.InfoC("channels", "Discord channel enabled successfully")
- }
+ m.initChannel("discord", "Discord")
}
if m.config.Channels.MaixCam.Enabled {
- logger.DebugC("channels", "Attempting to initialize MaixCam channel")
- maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["maixcam"] = maixcam
- logger.InfoC("channels", "MaixCam channel enabled successfully")
- }
+ m.initChannel("maixcam", "MaixCam")
}
if m.config.Channels.QQ.Enabled {
- logger.DebugC("channels", "Attempting to initialize QQ channel")
- qq, err := NewQQChannel(m.config.Channels.QQ, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["qq"] = qq
- logger.InfoC("channels", "QQ channel enabled successfully")
- }
+ m.initChannel("qq", "QQ")
}
if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" {
- logger.DebugC("channels", "Attempting to initialize DingTalk channel")
- dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["dingtalk"] = dingtalk
- logger.InfoC("channels", "DingTalk channel enabled successfully")
- }
+ m.initChannel("dingtalk", "DingTalk")
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
- logger.DebugC("channels", "Attempting to initialize Slack channel")
- slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["slack"] = slackCh
- logger.InfoC("channels", "Slack channel enabled successfully")
- }
+ m.initChannel("slack", "Slack")
}
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
- logger.DebugC("channels", "Attempting to initialize LINE channel")
- line, err := NewLINEChannel(m.config.Channels.LINE, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["line"] = line
- logger.InfoC("channels", "LINE channel enabled successfully")
- }
+ m.initChannel("line", "LINE")
}
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
- logger.DebugC("channels", "Attempting to initialize OneBot channel")
- onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["onebot"] = onebot
- logger.InfoC("channels", "OneBot channel enabled successfully")
- }
+ m.initChannel("onebot", "OneBot")
}
if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" {
- logger.DebugC("channels", "Attempting to initialize WeCom channel")
- wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["wecom"] = wecom
- logger.InfoC("channels", "WeCom channel enabled successfully")
- }
+ m.initChannel("wecom", "WeCom")
+ }
+
+ if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
+ m.initChannel("wecom_aibot", "WeCom AI Bot")
}
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
- logger.DebugC("channels", "Attempting to initialize WeCom App channel")
- wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus)
- if err != nil {
- logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{
- "error": err.Error(),
- })
- } else {
- m.channels["wecom_app"] = wecomApp
- logger.InfoC("channels", "WeCom App channel enabled successfully")
- }
+ m.initChannel("wecom_app", "WeCom App")
+ }
+
+ if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" {
+ m.initChannel("pico", "Pico")
}
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
@@ -209,13 +274,50 @@ func (m *Manager) initChannels() error {
return nil
}
+// SetupHTTPServer creates a shared HTTP server with the given listen address.
+// It registers health endpoints from the health server and discovers channels
+// that implement WebhookHandler and/or HealthChecker to register their handlers.
+func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
+ m.mux = http.NewServeMux()
+
+ // Register health endpoints
+ if healthServer != nil {
+ healthServer.RegisterOnMux(m.mux)
+ }
+
+ // Discover and register webhook handlers and health checkers
+ for name, ch := range m.channels {
+ if wh, ok := ch.(WebhookHandler); ok {
+ m.mux.Handle(wh.WebhookPath(), wh)
+ logger.InfoCF("channels", "Webhook handler registered", map[string]any{
+ "channel": name,
+ "path": wh.WebhookPath(),
+ })
+ }
+ if hc, ok := ch.(HealthChecker); ok {
+ m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler)
+ logger.InfoCF("channels", "Health endpoint registered", map[string]any{
+ "channel": name,
+ "path": hc.HealthPath(),
+ })
+ }
+ }
+
+ m.httpServer = &http.Server{
+ Addr: addr,
+ Handler: m.mux,
+ ReadTimeout: 30 * time.Second,
+ WriteTimeout: 30 * time.Second,
+ }
+}
+
func (m *Manager) StartAll(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
- return nil
+ return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
@@ -223,8 +325,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
dispatchCtx, cancel := context.WithCancel(ctx)
m.dispatchTask = &asyncTask{cancel: cancel}
- go m.dispatchOutbound(dispatchCtx)
-
for name, channel := range m.channels {
logger.InfoCF("channels", "Starting channel", map[string]any{
"channel": name,
@@ -234,7 +334,34 @@ func (m *Manager) StartAll(ctx context.Context) error {
"channel": name,
"error": err.Error(),
})
+ continue
}
+ // Lazily create worker only after channel starts successfully
+ w := newChannelWorker(name, channel)
+ m.workers[name] = w
+ go m.runWorker(dispatchCtx, name, w)
+ go m.runMediaWorker(dispatchCtx, name, w)
+ }
+
+ // Start the dispatcher that reads from the bus and routes to workers
+ go m.dispatchOutbound(dispatchCtx)
+ go m.dispatchOutboundMedia(dispatchCtx)
+
+ // Start the TTL janitor that cleans up stale typing/placeholder entries
+ go m.runTTLJanitor(dispatchCtx)
+
+ // Start shared HTTP server if configured
+ if m.httpServer != nil {
+ go func() {
+ logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{
+ "addr": m.httpServer.Addr,
+ })
+ if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
}
logger.InfoC("channels", "All channels started")
@@ -247,11 +374,48 @@ func (m *Manager) StopAll(ctx context.Context) error {
logger.InfoC("channels", "Stopping all channels")
+ // Shutdown shared HTTP server first
+ if m.httpServer != nil {
+ shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ if err := m.httpServer.Shutdown(shutdownCtx); err != nil {
+ logger.ErrorCF("channels", "Shared HTTP server shutdown error", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ m.httpServer = nil
+ }
+
+ // Cancel dispatcher
if m.dispatchTask != nil {
m.dispatchTask.cancel()
m.dispatchTask = nil
}
+ // Close all worker queues and wait for them to drain
+ for _, w := range m.workers {
+ if w != nil {
+ close(w.queue)
+ }
+ }
+ for _, w := range m.workers {
+ if w != nil {
+ <-w.done
+ }
+ }
+ // Close all media worker queues and wait for them to drain
+ for _, w := range m.workers {
+ if w != nil {
+ close(w.mediaQueue)
+ }
+ }
+ for _, w := range m.workers {
+ if w != nil {
+ <-w.mediaDone
+ }
+ }
+
+ // Stop all channels
for name, channel := range m.channels {
logger.InfoCF("channels", "Stopping channel", map[string]any{
"channel": name,
@@ -268,42 +432,318 @@ func (m *Manager) StopAll(ctx context.Context) error {
return nil
}
+// newChannelWorker creates a channelWorker with a rate limiter configured
+// for the given channel name.
+func newChannelWorker(name string, ch Channel) *channelWorker {
+ rateVal := float64(defaultRateLimit)
+ if r, ok := channelRateConfig[name]; ok {
+ rateVal = r
+ }
+ burst := int(math.Max(1, math.Ceil(rateVal/2)))
+
+ return &channelWorker{
+ ch: ch,
+ queue: make(chan bus.OutboundMessage, defaultChannelQueueSize),
+ mediaQueue: make(chan bus.OutboundMediaMessage, defaultChannelQueueSize),
+ done: make(chan struct{}),
+ mediaDone: make(chan struct{}),
+ limiter: rate.NewLimiter(rate.Limit(rateVal), burst),
+ }
+}
+
+// runWorker processes outbound messages for a single channel, splitting
+// messages that exceed the channel's maximum message length.
+func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) {
+ defer close(w.done)
+ for {
+ select {
+ case msg, ok := <-w.queue:
+ if !ok {
+ return
+ }
+ maxLen := 0
+ if mlp, ok := w.ch.(MessageLengthProvider); ok {
+ maxLen = mlp.MaxMessageLength()
+ }
+ if maxLen > 0 && len([]rune(msg.Content)) > maxLen {
+ chunks := SplitMessage(msg.Content, maxLen)
+ for _, chunk := range chunks {
+ chunkMsg := msg
+ chunkMsg.Content = chunk
+ m.sendWithRetry(ctx, name, w, chunkMsg)
+ }
+ } else {
+ m.sendWithRetry(ctx, name, w, msg)
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// sendWithRetry sends a message through the channel with rate limiting and
+// retry logic. It classifies errors to determine the retry strategy:
+// - ErrNotRunning / ErrSendFailed: permanent, no retry
+// - ErrRateLimit: fixed delay retry
+// - ErrTemporary / unknown: exponential backoff retry
+func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) {
+ // Rate limit: wait for token
+ if err := w.limiter.Wait(ctx); err != nil {
+ // ctx canceled, shutting down
+ return
+ }
+
+ // Pre-send: stop typing and try to edit placeholder
+ if m.preSend(ctx, name, msg, w.ch) {
+ return // placeholder was edited successfully, skip Send
+ }
+
+ var lastErr error
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ lastErr = w.ch.Send(ctx, msg)
+ if lastErr == nil {
+ return
+ }
+
+ // Permanent failures — don't retry
+ if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) {
+ break
+ }
+
+ // Last attempt exhausted — don't sleep
+ if attempt == maxRetries {
+ break
+ }
+
+ // Rate limit error — fixed delay
+ if errors.Is(lastErr, ErrRateLimit) {
+ select {
+ case <-time.After(rateLimitDelay):
+ continue
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ // ErrTemporary or unknown error — exponential backoff
+ backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff)
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ // All retries exhausted or permanent failure
+ logger.ErrorCF("channels", "Send failed", map[string]any{
+ "channel": name,
+ "chat_id": msg.ChatID,
+ "error": lastErr.Error(),
+ "retries": maxRetries,
+ })
+}
+
+func dispatchLoop[M any](
+ ctx context.Context,
+ m *Manager,
+ subscribe func(context.Context) (M, bool),
+ getChannel func(M) string,
+ enqueue func(context.Context, *channelWorker, M) bool,
+ startMsg, stopMsg, unknownMsg, noWorkerMsg string,
+) {
+ logger.InfoC("channels", startMsg)
+
+ for {
+ msg, ok := subscribe(ctx)
+ if !ok {
+ logger.InfoC("channels", stopMsg)
+ return
+ }
+
+ channel := getChannel(msg)
+
+ // Silently skip internal channels
+ if constants.IsInternalChannel(channel) {
+ continue
+ }
+
+ m.mu.RLock()
+ _, exists := m.channels[channel]
+ w, wExists := m.workers[channel]
+ m.mu.RUnlock()
+
+ if !exists {
+ logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
+ continue
+ }
+
+ if wExists && w != nil {
+ if !enqueue(ctx, w, msg) {
+ return
+ }
+ } else if exists {
+ logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
+ }
+ }
+}
+
func (m *Manager) dispatchOutbound(ctx context.Context) {
- logger.InfoC("channels", "Outbound dispatcher started")
+ dispatchLoop(
+ ctx, m,
+ m.bus.SubscribeOutbound,
+ func(msg bus.OutboundMessage) string { return msg.Channel },
+ func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
+ select {
+ case w.queue <- msg:
+ return true
+ case <-ctx.Done():
+ return false
+ }
+ },
+ "Outbound dispatcher started",
+ "Outbound dispatcher stopped",
+ "Unknown channel for outbound message",
+ "Channel has no active worker, skipping message",
+ )
+}
+
+func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
+ dispatchLoop(
+ ctx, m,
+ m.bus.SubscribeOutboundMedia,
+ func(msg bus.OutboundMediaMessage) string { return msg.Channel },
+ func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
+ select {
+ case w.mediaQueue <- msg:
+ return true
+ case <-ctx.Done():
+ return false
+ }
+ },
+ "Outbound media dispatcher started",
+ "Outbound media dispatcher stopped",
+ "Unknown channel for outbound media message",
+ "Channel has no active worker, skipping media message",
+ )
+}
+
+// runMediaWorker processes outbound media messages for a single channel.
+func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWorker) {
+ defer close(w.mediaDone)
+ for {
+ select {
+ case msg, ok := <-w.mediaQueue:
+ if !ok {
+ return
+ }
+ m.sendMediaWithRetry(ctx, name, w, msg)
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// sendMediaWithRetry sends a media message through the channel with rate limiting and
+// retry logic. If the channel does not implement MediaSender, it silently skips.
+func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) {
+ ms, ok := w.ch.(MediaSender)
+ if !ok {
+ logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{
+ "channel": name,
+ })
+ return
+ }
+
+ // Rate limit: wait for token
+ if err := w.limiter.Wait(ctx); err != nil {
+ return
+ }
+
+ var lastErr error
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ lastErr = ms.SendMedia(ctx, msg)
+ if lastErr == nil {
+ return
+ }
+
+ // Permanent failures — don't retry
+ if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) {
+ break
+ }
+
+ // Last attempt exhausted — don't sleep
+ if attempt == maxRetries {
+ break
+ }
+
+ // Rate limit error — fixed delay
+ if errors.Is(lastErr, ErrRateLimit) {
+ select {
+ case <-time.After(rateLimitDelay):
+ continue
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ // ErrTemporary or unknown error — exponential backoff
+ backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff)
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ // All retries exhausted or permanent failure
+ logger.ErrorCF("channels", "SendMedia failed", map[string]any{
+ "channel": name,
+ "chat_id": msg.ChatID,
+ "error": lastErr.Error(),
+ "retries": maxRetries,
+ })
+}
+
+// runTTLJanitor periodically scans the typingStops and placeholders maps
+// and evicts entries that have exceeded their TTL. This prevents memory
+// accumulation when outbound paths fail to trigger preSend (e.g. LLM errors).
+func (m *Manager) runTTLJanitor(ctx context.Context) {
+ ticker := time.NewTicker(janitorInterval)
+ defer ticker.Stop()
for {
select {
case <-ctx.Done():
- logger.InfoC("channels", "Outbound dispatcher stopped")
return
- default:
- msg, ok := m.bus.SubscribeOutbound(ctx)
- if !ok {
- continue
- }
-
- // Silently skip internal channels
- if constants.IsInternalChannel(msg.Channel) {
- continue
- }
-
- m.mu.RLock()
- channel, exists := m.channels[msg.Channel]
- m.mu.RUnlock()
-
- if !exists {
- logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
- "channel": msg.Channel,
- })
- continue
- }
-
- if err := channel.Send(ctx, msg); err != nil {
- logger.ErrorCF("channels", "Error sending message to channel", map[string]any{
- "channel": msg.Channel,
- "error": err.Error(),
- })
- }
+ case now := <-ticker.C:
+ m.typingStops.Range(func(key, value any) bool {
+ if entry, ok := value.(typingEntry); ok {
+ if now.Sub(entry.createdAt) > typingStopTTL {
+ if _, loaded := m.typingStops.LoadAndDelete(key); loaded {
+ entry.stop() // idempotent, safe
+ }
+ }
+ }
+ return true
+ })
+ m.reactionUndos.Range(func(key, value any) bool {
+ if entry, ok := value.(reactionEntry); ok {
+ if now.Sub(entry.createdAt) > typingStopTTL {
+ if _, loaded := m.reactionUndos.LoadAndDelete(key); loaded {
+ entry.undo() // idempotent, safe
+ }
+ }
+ }
+ return true
+ })
+ m.placeholders.Range(func(key, value any) bool {
+ if entry, ok := value.(placeholderEntry); ok {
+ if now.Sub(entry.createdAt) > placeholderTTL {
+ m.placeholders.Delete(key)
+ }
+ }
+ return true
+ })
}
}
}
@@ -349,12 +789,20 @@ func (m *Manager) RegisterChannel(name string, channel Channel) {
func (m *Manager) UnregisterChannel(name string) {
m.mu.Lock()
defer m.mu.Unlock()
+ if w, ok := m.workers[name]; ok && w != nil {
+ close(w.queue)
+ <-w.done
+ close(w.mediaQueue)
+ <-w.mediaDone
+ }
+ delete(m.workers, name)
delete(m.channels, name)
}
func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error {
m.mu.RLock()
- channel, exists := m.channels[channelName]
+ _, exists := m.channels[channelName]
+ w, wExists := m.workers[channelName]
m.mu.RUnlock()
if !exists {
@@ -367,5 +815,16 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
Content: content,
}
+ if wExists && w != nil {
+ select {
+ case w.queue <- msg:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+
+ // Fallback: direct send (should not happen)
+ channel, _ := m.channels[channelName]
return channel.Send(ctx, msg)
}
diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go
new file mode 100644
index 000000000..f09ecfe2f
--- /dev/null
+++ b/pkg/channels/manager_test.go
@@ -0,0 +1,862 @@
+package channels
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/time/rate"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+)
+
+// mockChannel is a test double that delegates Send to a configurable function.
+type mockChannel struct {
+ BaseChannel
+ sendFn func(ctx context.Context, msg bus.OutboundMessage) error
+}
+
+func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ return m.sendFn(ctx, msg)
+}
+
+func (m *mockChannel) Start(ctx context.Context) error { return nil }
+func (m *mockChannel) Stop(ctx context.Context) error { return nil }
+
+// newTestManager creates a minimal Manager suitable for unit tests.
+func newTestManager() *Manager {
+ return &Manager{
+ channels: make(map[string]Channel),
+ workers: make(map[string]*channelWorker),
+ }
+}
+
+func TestSendWithRetry_Success(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ return nil
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ if callCount != 1 {
+ t.Fatalf("expected 1 Send call, got %d", callCount)
+ }
+}
+
+func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ if callCount <= 2 {
+ return fmt.Errorf("network error: %w", ErrTemporary)
+ }
+ return nil
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ if callCount != 3 {
+ t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount)
+ }
+}
+
+func TestSendWithRetry_PermanentFailure(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ return fmt.Errorf("bad chat ID: %w", ErrSendFailed)
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ if callCount != 1 {
+ t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount)
+ }
+}
+
+func TestSendWithRetry_NotRunning(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ return ErrNotRunning
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ if callCount != 1 {
+ t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount)
+ }
+}
+
+func TestSendWithRetry_RateLimitRetry(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ if callCount == 1 {
+ return fmt.Errorf("429: %w", ErrRateLimit)
+ }
+ return nil
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ start := time.Now()
+ m.sendWithRetry(ctx, "test", w, msg)
+ elapsed := time.Since(start)
+
+ if callCount != 2 {
+ t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount)
+ }
+ // Should have waited at least rateLimitDelay (1s) but allow some slack
+ if elapsed < 900*time.Millisecond {
+ t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed)
+ }
+}
+
+func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ return fmt.Errorf("timeout: %w", ErrTemporary)
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ expected := maxRetries + 1 // initial attempt + maxRetries retries
+ if callCount != expected {
+ t.Fatalf("expected %d Send calls, got %d", expected, callCount)
+ }
+}
+
+func TestSendWithRetry_UnknownError(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ if callCount == 1 {
+ return errors.New("random unexpected error")
+ }
+ return nil
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ if callCount != 2 {
+ t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount)
+ }
+}
+
+func TestSendWithRetry_ContextCancelled(t *testing.T) {
+ m := newTestManager()
+ var callCount int
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ return fmt.Errorf("timeout: %w", ErrTemporary)
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ // Cancel context after first Send attempt returns
+ ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
+ callCount++
+ cancel()
+ return fmt.Errorf("timeout: %w", ErrTemporary)
+ }
+
+ m.sendWithRetry(ctx, "test", w, msg)
+
+ // Should have called Send once, then noticed ctx canceled during backoff
+ if callCount != 1 {
+ t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount)
+ }
+}
+
+func TestWorkerRateLimiter(t *testing.T) {
+ m := newTestManager()
+
+ var mu sync.Mutex
+ var sendTimes []time.Time
+
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ mu.Lock()
+ sendTimes = append(sendTimes, time.Now())
+ mu.Unlock()
+ return nil
+ },
+ }
+
+ // Create a worker with a low rate: 2 msg/s, burst 1
+ w := &channelWorker{
+ ch: ch,
+ queue: make(chan bus.OutboundMessage, 10),
+ done: make(chan struct{}),
+ limiter: rate.NewLimiter(2, 1),
+ }
+
+ ctx := t.Context()
+
+ go m.runWorker(ctx, "test", w)
+
+ // Enqueue 4 messages
+ for i := range 4 {
+ w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
+ }
+
+ // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
+ time.Sleep(3 * time.Second)
+
+ mu.Lock()
+ times := make([]time.Time, len(sendTimes))
+ copy(times, sendTimes)
+ mu.Unlock()
+
+ if len(times) != 4 {
+ t.Fatalf("expected 4 sends, got %d", len(times))
+ }
+
+ // Verify rate limiting: total duration should be at least 1s
+ // (first message immediate, then ~500ms between each subsequent one at 2/s)
+ totalDuration := times[len(times)-1].Sub(times[0])
+ if totalDuration < 1*time.Second {
+ t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration)
+ }
+}
+
+func TestNewChannelWorker_DefaultRate(t *testing.T) {
+ ch := &mockChannel{}
+ w := newChannelWorker("unknown_channel", ch)
+
+ if w.limiter == nil {
+ t.Fatal("expected limiter to be non-nil")
+ }
+ if w.limiter.Limit() != rate.Limit(defaultRateLimit) {
+ t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit())
+ }
+}
+
+func TestNewChannelWorker_ConfiguredRate(t *testing.T) {
+ ch := &mockChannel{}
+
+ for name, expectedRate := range channelRateConfig {
+ w := newChannelWorker(name, ch)
+ if w.limiter.Limit() != rate.Limit(expectedRate) {
+ t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit())
+ }
+ }
+}
+
+func TestRunWorker_MessageSplitting(t *testing.T) {
+ m := newTestManager()
+
+ var mu sync.Mutex
+ var received []string
+
+ ch := &mockChannelWithLength{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
+ mu.Lock()
+ received = append(received, msg.Content)
+ mu.Unlock()
+ return nil
+ },
+ },
+ maxLen: 5,
+ }
+
+ w := &channelWorker{
+ ch: ch,
+ queue: make(chan bus.OutboundMessage, 10),
+ done: make(chan struct{}),
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := t.Context()
+
+ go m.runWorker(ctx, "test", w)
+
+ // Send a message that should be split
+ w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
+
+ time.Sleep(100 * time.Millisecond)
+
+ mu.Lock()
+ count := len(received)
+ mu.Unlock()
+
+ if count < 2 {
+ t.Fatalf("expected message to be split into at least 2 chunks, got %d", count)
+ }
+}
+
+// mockChannelWithLength implements MessageLengthProvider.
+type mockChannelWithLength struct {
+ mockChannel
+ maxLen int
+}
+
+func (m *mockChannelWithLength) MaxMessageLength() int {
+ return m.maxLen
+}
+
+func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
+ m := newTestManager()
+
+ var callTimes []time.Time
+ var callCount atomic.Int32
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ callTimes = append(callTimes, time.Now())
+ callCount.Add(1)
+ return fmt.Errorf("timeout: %w", ErrTemporary)
+ },
+ }
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
+
+ start := time.Now()
+ m.sendWithRetry(ctx, "test", w, msg)
+ totalElapsed := time.Since(start)
+
+ // With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s
+ // Total backoff: 500ms + 1s + 2s = 3.5s
+ // Allow some margin
+ if totalElapsed < 3*time.Second {
+ t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed)
+ }
+
+ if int(callCount.Load()) != maxRetries+1 {
+ t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load())
+ }
+}
+
+// --- Phase 10: preSend orchestration tests ---
+
+// mockMessageEditor is a channel that supports MessageEditor.
+type mockMessageEditor struct {
+ mockChannel
+ editFn func(ctx context.Context, chatID, messageID, content string) error
+}
+
+func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error {
+ return m.editFn(ctx, chatID, messageID, content)
+}
+
+func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
+ m := newTestManager()
+ var sendCalled bool
+ var editCalled bool
+
+ ch := &mockMessageEditor{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ sendCalled = true
+ return nil
+ },
+ },
+ editFn: func(_ context.Context, chatID, messageID, content string) error {
+ editCalled = true
+ if chatID != "123" {
+ t.Fatalf("expected chatID 123, got %s", chatID)
+ }
+ if messageID != "456" {
+ t.Fatalf("expected messageID 456, got %s", messageID)
+ }
+ if content != "hello" {
+ t.Fatalf("expected content 'hello', got %s", content)
+ }
+ return nil
+ },
+ }
+
+ // Register placeholder
+ m.RecordPlaceholder("test", "123", "456")
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ edited := m.preSend(context.Background(), "test", msg, ch)
+
+ if !edited {
+ t.Fatal("expected preSend to return true (placeholder edited)")
+ }
+ if !editCalled {
+ t.Fatal("expected EditMessage to be called")
+ }
+ if sendCalled {
+ t.Fatal("expected Send to NOT be called when placeholder edited")
+ }
+}
+
+func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
+ m := newTestManager()
+
+ ch := &mockMessageEditor{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ },
+ editFn: func(_ context.Context, _, _, _ string) error {
+ return fmt.Errorf("edit failed")
+ },
+ }
+
+ m.RecordPlaceholder("test", "123", "456")
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ edited := m.preSend(context.Background(), "test", msg, ch)
+
+ if edited {
+ t.Fatal("expected preSend to return false when edit fails")
+ }
+}
+
+func TestPreSend_TypingStopCalled(t *testing.T) {
+ m := newTestManager()
+ var stopCalled bool
+
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ }
+
+ m.RecordTypingStop("test", "123", func() {
+ stopCalled = true
+ })
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ m.preSend(context.Background(), "test", msg, ch)
+
+ if !stopCalled {
+ t.Fatal("expected typing stop func to be called")
+ }
+}
+
+func TestPreSend_NoRegisteredState(t *testing.T) {
+ m := newTestManager()
+
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ }
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ edited := m.preSend(context.Background(), "test", msg, ch)
+
+ if edited {
+ t.Fatal("expected preSend to return false with no registered state")
+ }
+}
+
+func TestPreSend_TypingAndPlaceholder(t *testing.T) {
+ m := newTestManager()
+ var stopCalled bool
+ var editCalled bool
+
+ ch := &mockMessageEditor{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ },
+ editFn: func(_ context.Context, _, _, _ string) error {
+ editCalled = true
+ return nil
+ },
+ }
+
+ m.RecordTypingStop("test", "123", func() {
+ stopCalled = true
+ })
+ m.RecordPlaceholder("test", "123", "456")
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ edited := m.preSend(context.Background(), "test", msg, ch)
+
+ if !stopCalled {
+ t.Fatal("expected typing stop to be called")
+ }
+ if !editCalled {
+ t.Fatal("expected EditMessage to be called")
+ }
+ if !edited {
+ t.Fatal("expected preSend to return true")
+ }
+}
+
+func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) {
+ m := newTestManager()
+
+ var wg sync.WaitGroup
+ for i := range 100 {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ chatID := fmt.Sprintf("chat_%d", i%10)
+ m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i))
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestRecordTypingStop_ConcurrentSafe(t *testing.T) {
+ m := newTestManager()
+
+ var wg sync.WaitGroup
+ for i := range 100 {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ chatID := fmt.Sprintf("chat_%d", i%10)
+ m.RecordTypingStop("test", chatID, func() {})
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
+ m := newTestManager()
+ var sendCalled bool
+
+ ch := &mockMessageEditor{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ sendCalled = true
+ return nil
+ },
+ },
+ editFn: func(_ context.Context, _, _, _ string) error {
+ return nil // edit succeeds
+ },
+ }
+
+ m.RecordPlaceholder("test", "123", "456")
+
+ w := &channelWorker{
+ ch: ch,
+ limiter: rate.NewLimiter(rate.Inf, 1),
+ }
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
+ m.sendWithRetry(context.Background(), "test", w, msg)
+
+ if sendCalled {
+ t.Fatal("expected Send to NOT be called when placeholder was edited")
+ }
+}
+
+// --- Dispatcher exit tests (Step 1) ---
+
+func TestDispatcherExitsOnCancel(t *testing.T) {
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ m := &Manager{
+ channels: make(map[string]Channel),
+ workers: make(map[string]*channelWorker),
+ bus: mb,
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ done := make(chan struct{})
+
+ go func() {
+ m.dispatchOutbound(ctx)
+ close(done)
+ }()
+
+ // Cancel context and verify the dispatcher exits quickly
+ cancel()
+
+ select {
+ case <-done:
+ // success
+ case <-time.After(2 * time.Second):
+ t.Fatal("dispatchOutbound did not exit within 2s after context cancel")
+ }
+}
+
+func TestDispatcherMediaExitsOnCancel(t *testing.T) {
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ m := &Manager{
+ channels: make(map[string]Channel),
+ workers: make(map[string]*channelWorker),
+ bus: mb,
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ done := make(chan struct{})
+
+ go func() {
+ m.dispatchOutboundMedia(ctx)
+ close(done)
+ }()
+
+ cancel()
+
+ select {
+ case <-done:
+ // success
+ case <-time.After(2 * time.Second):
+ t.Fatal("dispatchOutboundMedia did not exit within 2s after context cancel")
+ }
+}
+
+// --- TTL Janitor tests (Step 2) ---
+
+func TestTypingStopJanitorEviction(t *testing.T) {
+ m := newTestManager()
+
+ var stopCalled atomic.Bool
+ // Store a typing entry with a creation time far in the past
+ m.typingStops.Store("test:123", typingEntry{
+ stop: func() { stopCalled.Store(true) },
+ createdAt: time.Now().Add(-10 * time.Minute), // well past typingStopTTL
+ })
+
+ // Run janitor with a short-lived context
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Manually trigger the janitor logic once by simulating a tick
+ go func() {
+ // Override janitor to run immediately
+ now := time.Now()
+ m.typingStops.Range(func(key, value any) bool {
+ if entry, ok := value.(typingEntry); ok {
+ if now.Sub(entry.createdAt) > typingStopTTL {
+ if _, loaded := m.typingStops.LoadAndDelete(key); loaded {
+ entry.stop()
+ }
+ }
+ }
+ return true
+ })
+ cancel()
+ }()
+
+ <-ctx.Done()
+
+ if !stopCalled.Load() {
+ t.Fatal("expected typing stop function to be called by janitor eviction")
+ }
+
+ // Verify entry was deleted
+ if _, loaded := m.typingStops.Load("test:123"); loaded {
+ t.Fatal("expected typing entry to be deleted after eviction")
+ }
+}
+
+func TestPlaceholderJanitorEviction(t *testing.T) {
+ m := newTestManager()
+
+ // Store a placeholder entry with a creation time far in the past
+ m.placeholders.Store("test:456", placeholderEntry{
+ id: "msg_old",
+ createdAt: time.Now().Add(-20 * time.Minute), // well past placeholderTTL
+ })
+
+ // Simulate janitor logic
+ now := time.Now()
+ m.placeholders.Range(func(key, value any) bool {
+ if entry, ok := value.(placeholderEntry); ok {
+ if now.Sub(entry.createdAt) > placeholderTTL {
+ m.placeholders.Delete(key)
+ }
+ }
+ return true
+ })
+
+ // Verify entry was deleted
+ if _, loaded := m.placeholders.Load("test:456"); loaded {
+ t.Fatal("expected placeholder entry to be deleted after eviction")
+ }
+}
+
+func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
+ m := newTestManager()
+ var stopCalled bool
+ var editCalled bool
+
+ ch := &mockMessageEditor{
+ mockChannel: mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ },
+ editFn: func(_ context.Context, chatID, messageID, content string) error {
+ editCalled = true
+ if messageID != "ph_id" {
+ t.Fatalf("expected messageID ph_id, got %s", messageID)
+ }
+ return nil
+ },
+ }
+
+ // Use the new wrapped types via the public API
+ m.RecordTypingStop("test", "chat1", func() {
+ stopCalled = true
+ })
+ m.RecordPlaceholder("test", "chat1", "ph_id")
+
+ msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
+ edited := m.preSend(context.Background(), "test", msg, ch)
+
+ if !stopCalled {
+ t.Fatal("expected typing stop to be called via wrapped type")
+ }
+ if !editCalled {
+ t.Fatal("expected EditMessage to be called via wrapped type")
+ }
+ if !edited {
+ t.Fatal("expected preSend to return true")
+ }
+}
+
+// --- Lazy worker creation tests (Step 6) ---
+
+func TestLazyWorkerCreation(t *testing.T) {
+ m := newTestManager()
+
+ ch := &mockChannel{
+ sendFn: func(_ context.Context, _ bus.OutboundMessage) error {
+ return nil
+ },
+ }
+
+ // RegisterChannel should NOT create a worker
+ m.RegisterChannel("lazy", ch)
+
+ m.mu.RLock()
+ _, chExists := m.channels["lazy"]
+ _, wExists := m.workers["lazy"]
+ m.mu.RUnlock()
+
+ if !chExists {
+ t.Fatal("expected channel to be registered")
+ }
+ if wExists {
+ t.Fatal("expected worker to NOT be created by RegisterChannel (lazy creation)")
+ }
+}
+
+// --- FastID uniqueness test (Step 5) ---
+
+func TestBuildMediaScope_FastIDUniqueness(t *testing.T) {
+ seen := make(map[string]bool)
+
+ for range 1000 {
+ scope := BuildMediaScope("test", "chat1", "")
+ if seen[scope] {
+ t.Fatalf("duplicate scope generated: %s", scope)
+ }
+ seen[scope] = true
+ }
+
+ // Verify format: "channel:chatID:id"
+ scope := BuildMediaScope("telegram", "42", "")
+ parts := 0
+ for _, c := range scope {
+ if c == ':' {
+ parts++
+ }
+ }
+ if parts != 2 {
+ t.Fatalf("expected scope to have 2 colons (channel:chatID:id), got: %s", scope)
+ }
+}
+
+func TestBuildMediaScope_WithMessageID(t *testing.T) {
+ scope := BuildMediaScope("discord", "chat99", "msg123")
+ expected := "discord:chat99:msg123"
+ if scope != expected {
+ t.Fatalf("expected %s, got %s", expected, scope)
+ }
+}
diff --git a/pkg/channels/media.go b/pkg/channels/media.go
new file mode 100644
index 000000000..c645a6180
--- /dev/null
+++ b/pkg/channels/media.go
@@ -0,0 +1,15 @@
+package channels
+
+import (
+ "context"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+)
+
+// MediaSender is an optional interface for channels that can send
+// media attachments (images, files, audio, video).
+// Manager discovers channels implementing this interface via type
+// assertion and routes OutboundMediaMessage to them.
+type MediaSender interface {
+ SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error
+}
diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go
new file mode 100644
index 000000000..84c06dfd6
--- /dev/null
+++ b/pkg/channels/onebot/init.go
@@ -0,0 +1,13 @@
+package onebot
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewOneBotChannel(cfg.Channels.OneBot, b)
+ })
+}
diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot/onebot.go
similarity index 75%
rename from pkg/channels/onebot.go
rename to pkg/channels/onebot/onebot.go
index 4576a11ce..62a9eb34a 100644
--- a/pkg/channels/onebot.go
+++ b/pkg/channels/onebot/onebot.go
@@ -1,10 +1,9 @@
-package channels
+package onebot
import (
"context"
"encoding/json"
"fmt"
- "os"
"strconv"
"strings"
"sync"
@@ -14,30 +13,30 @@ import (
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
- "github.com/sipeed/picoclaw/pkg/voice"
)
type OneBotChannel struct {
- *BaseChannel
- config config.OneBotConfig
- conn *websocket.Conn
- ctx context.Context
- cancel context.CancelFunc
- dedup map[string]struct{}
- dedupRing []string
- dedupIdx int
- mu sync.Mutex
- writeMu sync.Mutex
- echoCounter int64
- selfID int64
- pending map[string]chan json.RawMessage
- pendingMu sync.Mutex
- transcriber *voice.GroqTranscriber
- lastMessageID sync.Map
- pendingEmojiMsg sync.Map
+ *channels.BaseChannel
+ config config.OneBotConfig
+ conn *websocket.Conn
+ ctx context.Context
+ cancel context.CancelFunc
+ dedup map[string]struct{}
+ dedupRing []string
+ dedupIdx int
+ mu sync.Mutex
+ writeMu sync.Mutex
+ echoCounter int64
+ selfID int64
+ pending map[string]chan json.RawMessage
+ pendingMu sync.Mutex
+ lastMessageID sync.Map
}
type oneBotRawEvent struct {
@@ -98,7 +97,10 @@ type oneBotMessageSegment struct {
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
- base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom,
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
const dedupSize = 1024
return &OneBotChannel{
@@ -111,10 +113,6 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One
}, nil
}
-func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
- c.transcriber = transcriber
-}
-
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
go func() {
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{
@@ -131,6 +129,22 @@ func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool)
}()
}
+// ReactToMessage implements channels.ReactionCapable.
+// It adds an emoji reaction (ID 289) to group messages and returns an undo function.
+// Private messages return a no-op since reactions are only meaningful in groups.
+func (c *OneBotChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ // Only react in group chats
+ if !strings.HasPrefix(chatID, "group:") {
+ return func() {}, nil
+ }
+
+ c.setMsgEmojiLike(messageID, 289, true)
+
+ return func() {
+ c.setMsgEmojiLike(messageID, 289, false)
+ }, nil
+}
+
func (c *OneBotChannel) Start(ctx context.Context) error {
if c.config.WSUrl == "" {
return fmt.Errorf("OneBot ws_url not configured")
@@ -159,7 +173,7 @@ func (c *OneBotChannel) Start(ctx context.Context) error {
}
}
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoC("onebot", "OneBot channel started successfully")
return nil
@@ -300,7 +314,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
}
c.writeMu.Lock()
+ _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
+ _ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
@@ -309,6 +325,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
select {
case resp := <-ch:
+ if resp == nil {
+ return nil, fmt.Errorf("API request %s: channel stopped", action)
+ }
return resp, nil
case <-time.After(timeout):
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
@@ -318,10 +337,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D
}
func (c *OneBotChannel) reconnectLoop() {
- interval := time.Duration(c.config.ReconnectInterval) * time.Second
- if interval < 5*time.Second {
- interval = 5 * time.Second
- }
+ interval := max(time.Duration(c.config.ReconnectInterval)*time.Second, 5*time.Second)
for {
select {
@@ -349,7 +365,7 @@ func (c *OneBotChannel) reconnectLoop() {
func (c *OneBotChannel) Stop(ctx context.Context) error {
logger.InfoC("onebot", "Stopping OneBot channel")
- c.setRunning(false)
+ c.SetRunning(false)
if c.cancel != nil {
c.cancel()
@@ -357,7 +373,10 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
c.pendingMu.Lock()
for echo, ch := range c.pending {
- close(ch)
+ select {
+ case ch <- nil: // non-blocking wake for blocked sendAPIRequest goroutines
+ default:
+ }
delete(c.pending, echo)
}
c.pendingMu.Unlock()
@@ -374,7 +393,14 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("OneBot channel not running")
+ return channels.ErrNotRunning
+ }
+
+ // Check ctx before entering write path
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
}
c.mu.Lock()
@@ -404,20 +430,127 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
}
c.writeMu.Lock()
+ _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
err = conn.WriteMessage(websocket.TextMessage, data)
+ _ = conn.SetWriteDeadline(time.Time{})
c.writeMu.Unlock()
if err != nil {
logger.ErrorCF("onebot", "Failed to send message", map[string]any{
"error": err.Error(),
})
- return err
+ return fmt.Errorf("onebot send: %w", channels.ErrTemporary)
}
- if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok {
- if mid, ok := msgID.(string); ok && mid != "" {
- c.setMsgEmojiLike(mid, 289, false)
+ return nil
+}
+
+// SendMedia implements the channels.MediaSender interface.
+func (c *OneBotChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
+ c.mu.Lock()
+ conn := c.conn
+ c.mu.Unlock()
+
+ if conn == nil {
+ return fmt.Errorf("OneBot WebSocket not connected")
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ // Build media segments
+ var segments []oneBotMessageSegment
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("onebot", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ continue
}
+
+ var segType string
+ switch part.Type {
+ case "image":
+ segType = "image"
+ case "video":
+ segType = "video"
+ case "audio":
+ segType = "record"
+ default:
+ segType = "file"
+ }
+
+ segments = append(segments, oneBotMessageSegment{
+ Type: segType,
+ Data: map[string]any{"file": "file://" + localPath},
+ })
+
+ if part.Caption != "" {
+ segments = append(segments, oneBotMessageSegment{
+ Type: "text",
+ Data: map[string]any{"text": part.Caption},
+ })
+ }
+ }
+
+ if len(segments) == 0 {
+ return nil
+ }
+
+ chatID := msg.ChatID
+ var action, idKey string
+ var rawID string
+ if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
+ action, idKey, rawID = "send_group_msg", "group_id", rest
+ } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
+ action, idKey, rawID = "send_private_msg", "user_id", rest
+ } else {
+ action, idKey, rawID = "send_private_msg", "user_id", chatID
+ }
+
+ id, err := strconv.ParseInt(rawID, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed)
+ }
+
+ echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
+
+ req := oneBotAPIRequest{
+ Action: action,
+ Params: map[string]any{idKey: id, "message": segments},
+ Echo: echo,
+ }
+
+ data, err := json.Marshal(req)
+ if err != nil {
+ return fmt.Errorf("failed to marshal OneBot request: %w", err)
+ }
+
+ c.writeMu.Lock()
+ _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
+ err = conn.WriteMessage(websocket.TextMessage, data)
+ _ = conn.SetWriteDeadline(time.Time{})
+ c.writeMu.Unlock()
+
+ if err != nil {
+ logger.ErrorCF("onebot", "Failed to send media message", map[string]any{
+ "error": err.Error(),
+ })
+ return fmt.Errorf("onebot send media: %w", channels.ErrTemporary)
}
return nil
@@ -574,11 +707,15 @@ type parseMessageResult struct {
Text string
IsBotMentioned bool
Media []string
- LocalFiles []string
ReplyTo string
}
-func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult {
+func (c *OneBotChannel) parseMessageSegments(
+ raw json.RawMessage,
+ selfID int64,
+ store media.MediaStore,
+ scope string,
+) parseMessageResult {
if len(raw) == 0 {
return parseMessageResult{}
}
@@ -605,10 +742,23 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
var textParts []string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
- var media []string
- var localFiles []string
+ var mediaRefs []string
var replyTo string
+ // Helper to register a local file with the media store
+ storeFile := func(localPath, filename string) string {
+ if store != nil {
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "onebot",
+ }, scope)
+ if err == nil {
+ return ref
+ }
+ }
+ return localPath // fallback
+ }
+
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]any)
@@ -644,8 +794,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
LoggerPrefix: "onebot",
})
if localPath != "" {
- media = append(media, localPath)
- localFiles = append(localFiles, localPath)
+ mediaRefs = append(mediaRefs, storeFile(localPath, filename))
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
}
}
@@ -659,24 +808,8 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
LoggerPrefix: "onebot",
})
if localPath != "" {
- localFiles = append(localFiles, localPath)
- if c.transcriber != nil && c.transcriber.IsAvailable() {
- tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second)
- result, err := c.transcriber.Transcribe(tctx, localPath)
- tcancel()
- if err != nil {
- logger.WarnCF("onebot", "Voice transcription failed", map[string]any{
- "error": err.Error(),
- })
- textParts = append(textParts, "[voice (transcription failed)]")
- media = append(media, localPath)
- } else {
- textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text))
- }
- } else {
- textParts = append(textParts, "[voice]")
- media = append(media, localPath)
- }
+ textParts = append(textParts, "[voice]")
+ mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr"))
}
}
}
@@ -704,8 +837,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
return parseMessageResult{
Text: strings.TrimSpace(strings.Join(textParts, "")),
IsBotMentioned: mentioned,
- Media: media,
- LocalFiles: localFiles,
+ Media: mediaRefs,
ReplyTo: replyTo,
}
}
@@ -714,7 +846,13 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
switch raw.PostType {
case "message":
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
- if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
+ // Build minimal sender for allowlist check
+ sender := bus.SenderInfo{
+ Platform: "onebot",
+ PlatformID: strconv.FormatInt(userID, 10),
+ CanonicalID: identity.BuildCanonicalID("onebot", strconv.FormatInt(userID, 10)),
+ }
+ if !c.IsAllowedSender(sender) {
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{
"user_id": userID,
})
@@ -797,7 +935,17 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
selfID = atomic.LoadInt64(&c.selfID)
}
- parsed := c.parseMessageSegments(raw.Message, selfID)
+ // Compute scope for media store before parsing (parsing may download files)
+ var chatIDForScope string
+ switch raw.MessageType {
+ case "group":
+ chatIDForScope = "group:" + strconv.FormatInt(groupID, 10)
+ default:
+ chatIDForScope = "private:" + strconv.FormatInt(userID, 10)
+ }
+ scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID)
+
+ parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope)
isBotMentioned := parsed.IsBotMentioned
content := raw.RawMessage
@@ -826,20 +974,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
}
}
- // Clean up temp files when done
- if len(parsed.LocalFiles) > 0 {
- defer func() {
- for _, f := range parsed.LocalFiles {
- if err := os.Remove(f); err != nil {
- logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{
- "path": f,
- "error": err.Error(),
- })
- }
- }
- }()
- }
-
if c.isDuplicate(messageID) {
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{
"message_id": messageID,
@@ -857,9 +991,9 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
senderID := strconv.FormatInt(userID, 10)
var chatID string
- metadata := map[string]string{
- "message_id": messageID,
- }
+ var peer bus.Peer
+
+ metadata := map[string]string{}
if parsed.ReplyTo != "" {
metadata["reply_to_message_id"] = parsed.ReplyTo
@@ -868,14 +1002,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
- metadata["peer_kind"] = "direct"
- metadata["peer_id"] = senderID
+ peer = bus.Peer{Kind: "direct", ID: senderID}
case "group":
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
- metadata["peer_kind"] = "group"
- metadata["peer_id"] = groupIDStr
+ peer = bus.Peer{Kind: "group", ID: groupIDStr}
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(sender.UserID)
@@ -889,8 +1021,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
metadata["sender_name"] = sender.Nickname
}
- triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
- if !triggered {
+ respond, strippedContent := c.ShouldRespondInGroup(isBotMentioned, content)
+ if !respond {
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{
"sender": senderID,
"group": groupIDStr,
@@ -925,12 +1057,21 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
c.lastMessageID.Store(chatID, messageID)
- if raw.MessageType == "group" && messageID != "" && messageID != "0" {
- c.setMsgEmojiLike(messageID, 289, true)
- c.pendingEmojiMsg.Store(chatID, messageID)
+ senderInfo := bus.SenderInfo{
+ Platform: "onebot",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("onebot", senderID),
+ DisplayName: sender.Nickname,
}
- c.HandleMessage(senderID, chatID, content, parsed.Media, metadata)
+ if !c.IsAllowedSender(senderInfo) {
+ logger.DebugCF("onebot", "Message rejected by allowlist (senderInfo)", map[string]any{
+ "sender": senderID,
+ })
+ return
+ }
+
+ c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
@@ -962,23 +1103,3 @@ func truncate(s string, n int) string {
}
return string(runes[:n]) + "..."
}
-
-func (c *OneBotChannel) checkGroupTrigger(
- content string,
- isBotMentioned bool,
-) (triggered bool, strippedContent string) {
- if isBotMentioned {
- return true, strings.TrimSpace(content)
- }
-
- for _, prefix := range c.config.GroupTriggerPrefix {
- if prefix == "" {
- continue
- }
- if strings.HasPrefix(content, prefix) {
- return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
- }
- }
-
- return false, content
-}
diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go
new file mode 100644
index 000000000..96d764418
--- /dev/null
+++ b/pkg/channels/pico/init.go
@@ -0,0 +1,13 @@
+package pico
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewPicoChannel(cfg.Channels.Pico, b)
+ })
+}
diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go
new file mode 100644
index 000000000..8d8b62a67
--- /dev/null
+++ b/pkg/channels/pico/pico.go
@@ -0,0 +1,462 @@
+package pico
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/gorilla/websocket"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// picoConn represents a single WebSocket connection.
+type picoConn struct {
+ id string
+ conn *websocket.Conn
+ sessionID string
+ writeMu sync.Mutex
+ closed atomic.Bool
+}
+
+// writeJSON sends a JSON message to the connection with write locking.
+func (pc *picoConn) writeJSON(v any) error {
+ if pc.closed.Load() {
+ return fmt.Errorf("connection closed")
+ }
+ pc.writeMu.Lock()
+ defer pc.writeMu.Unlock()
+ return pc.conn.WriteJSON(v)
+}
+
+// close closes the connection.
+func (pc *picoConn) close() {
+ if pc.closed.CompareAndSwap(false, true) {
+ pc.conn.Close()
+ }
+}
+
+// PicoChannel implements the native Pico Protocol WebSocket channel.
+// It serves as the reference implementation for all optional capability interfaces.
+type PicoChannel struct {
+ *channels.BaseChannel
+ config config.PicoConfig
+ upgrader websocket.Upgrader
+ connections sync.Map // connID → *picoConn
+ connCount atomic.Int32
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+// NewPicoChannel creates a new Pico Protocol channel.
+func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) {
+ if cfg.Token == "" {
+ return nil, fmt.Errorf("pico token is required")
+ }
+
+ base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom)
+
+ allowOrigins := cfg.AllowOrigins
+ checkOrigin := func(r *http.Request) bool {
+ if len(allowOrigins) == 0 {
+ return true // allow all if not configured
+ }
+ origin := r.Header.Get("Origin")
+ for _, allowed := range allowOrigins {
+ if allowed == "*" || allowed == origin {
+ return true
+ }
+ }
+ return false
+ }
+
+ return &PicoChannel{
+ BaseChannel: base,
+ config: cfg,
+ upgrader: websocket.Upgrader{
+ CheckOrigin: checkOrigin,
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ },
+ }, nil
+}
+
+// Start implements Channel.
+func (c *PicoChannel) Start(ctx context.Context) error {
+ logger.InfoC("pico", "Starting Pico Protocol channel")
+ c.ctx, c.cancel = context.WithCancel(ctx)
+ c.SetRunning(true)
+ logger.InfoC("pico", "Pico Protocol channel started")
+ return nil
+}
+
+// Stop implements Channel.
+func (c *PicoChannel) Stop(ctx context.Context) error {
+ logger.InfoC("pico", "Stopping Pico Protocol channel")
+ c.SetRunning(false)
+
+ // Close all connections
+ c.connections.Range(func(key, value any) bool {
+ if pc, ok := value.(*picoConn); ok {
+ pc.close()
+ }
+ c.connections.Delete(key)
+ return true
+ })
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ logger.InfoC("pico", "Pico Protocol channel stopped")
+ return nil
+}
+
+// WebhookPath implements channels.WebhookHandler.
+func (c *PicoChannel) WebhookPath() string { return "/pico/" }
+
+// ServeHTTP implements http.Handler for the shared HTTP server.
+func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ path := strings.TrimPrefix(r.URL.Path, "/pico")
+
+ switch {
+ case path == "/ws" || path == "/ws/":
+ c.handleWebSocket(w, r)
+ default:
+ http.NotFound(w, r)
+ }
+}
+
+// Send implements Channel — sends a message to the appropriate WebSocket connection.
+func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ outMsg := newMessage(TypeMessageCreate, map[string]any{
+ "content": msg.Content,
+ })
+
+ return c.broadcastToSession(msg.ChatID, outMsg)
+}
+
+// EditMessage implements channels.MessageEditor.
+func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
+ outMsg := newMessage(TypeMessageUpdate, map[string]any{
+ "message_id": messageID,
+ "content": content,
+ })
+ return c.broadcastToSession(chatID, outMsg)
+}
+
+// StartTyping implements channels.TypingCapable.
+func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
+ startMsg := newMessage(TypeTypingStart, nil)
+ if err := c.broadcastToSession(chatID, startMsg); err != nil {
+ return func() {}, err
+ }
+ return func() {
+ stopMsg := newMessage(TypeTypingStop, nil)
+ c.broadcastToSession(chatID, stopMsg)
+ }, nil
+}
+
+// SendPlaceholder implements channels.PlaceholderCapable.
+// It sends a placeholder message via the Pico Protocol that will later be
+// edited to the actual response via EditMessage (channels.MessageEditor).
+func (c *PicoChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ if !c.config.Placeholder.Enabled {
+ return "", nil
+ }
+
+ text := c.config.Placeholder.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+
+ msgID := uuid.New().String()
+ outMsg := newMessage(TypeMessageCreate, map[string]any{
+ "content": text,
+ "message_id": msgID,
+ })
+
+ if err := c.broadcastToSession(chatID, outMsg); err != nil {
+ return "", err
+ }
+
+ return msgID, nil
+}
+
+// broadcastToSession sends a message to all connections with a matching session.
+func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error {
+ // chatID format: "pico:"
+ sessionID := strings.TrimPrefix(chatID, "pico:")
+ msg.SessionID = sessionID
+
+ var sent bool
+ c.connections.Range(func(key, value any) bool {
+ pc, ok := value.(*picoConn)
+ if !ok {
+ return true
+ }
+ if pc.sessionID == sessionID {
+ if err := pc.writeJSON(msg); err != nil {
+ logger.DebugCF("pico", "Write to connection failed", map[string]any{
+ "conn_id": pc.id,
+ "error": err.Error(),
+ })
+ } else {
+ sent = true
+ }
+ }
+ return true
+ })
+
+ if !sent {
+ return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed)
+ }
+ return nil
+}
+
+// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle.
+func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+ if !c.IsRunning() {
+ http.Error(w, "channel not running", http.StatusServiceUnavailable)
+ return
+ }
+
+ // Authenticate
+ if !c.authenticate(r) {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+
+ // Check connection limit
+ maxConns := c.config.MaxConnections
+ if maxConns <= 0 {
+ maxConns = 100
+ }
+ if int(c.connCount.Load()) >= maxConns {
+ http.Error(w, "too many connections", http.StatusServiceUnavailable)
+ return
+ }
+
+ conn, err := c.upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
+ "error": err.Error(),
+ })
+ return
+ }
+
+ // Determine session ID from query param or generate one
+ sessionID := r.URL.Query().Get("session_id")
+ if sessionID == "" {
+ sessionID = uuid.New().String()
+ }
+
+ pc := &picoConn{
+ id: uuid.New().String(),
+ conn: conn,
+ sessionID: sessionID,
+ }
+
+ c.connections.Store(pc.id, pc)
+ c.connCount.Add(1)
+
+ logger.InfoCF("pico", "WebSocket client connected", map[string]any{
+ "conn_id": pc.id,
+ "session_id": sessionID,
+ })
+
+ go c.readLoop(pc)
+}
+
+// authenticate checks the Bearer token from the Authorization header.
+// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
+func (c *PicoChannel) authenticate(r *http.Request) bool {
+ token := c.config.Token
+ if token == "" {
+ return false
+ }
+
+ // Check Authorization header
+ auth := r.Header.Get("Authorization")
+ if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
+ if after == token {
+ return true
+ }
+ }
+
+ // Check query parameter only when explicitly allowed
+ if c.config.AllowTokenQuery {
+ if r.URL.Query().Get("token") == token {
+ return true
+ }
+ }
+
+ return false
+}
+
+// readLoop reads messages from a WebSocket connection.
+func (c *PicoChannel) readLoop(pc *picoConn) {
+ defer func() {
+ pc.close()
+ c.connections.Delete(pc.id)
+ c.connCount.Add(-1)
+ logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{
+ "conn_id": pc.id,
+ "session_id": pc.sessionID,
+ })
+ }()
+
+ readTimeout := time.Duration(c.config.ReadTimeout) * time.Second
+ if readTimeout <= 0 {
+ readTimeout = 60 * time.Second
+ }
+
+ _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
+ pc.conn.SetPongHandler(func(appData string) error {
+ _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
+ return nil
+ })
+
+ // Start ping ticker
+ pingInterval := time.Duration(c.config.PingInterval) * time.Second
+ if pingInterval <= 0 {
+ pingInterval = 30 * time.Second
+ }
+ go c.pingLoop(pc, pingInterval)
+
+ for {
+ select {
+ case <-c.ctx.Done():
+ return
+ default:
+ }
+
+ _, rawMsg, err := pc.conn.ReadMessage()
+ if err != nil {
+ if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
+ logger.DebugCF("pico", "WebSocket read error", map[string]any{
+ "conn_id": pc.id,
+ "error": err.Error(),
+ })
+ }
+ return
+ }
+
+ _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout))
+
+ var msg PicoMessage
+ if err := json.Unmarshal(rawMsg, &msg); err != nil {
+ errMsg := newError("invalid_message", "failed to parse message")
+ pc.writeJSON(errMsg)
+ continue
+ }
+
+ c.handleMessage(pc, msg)
+ }
+}
+
+// pingLoop sends periodic ping frames to keep the connection alive.
+func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) {
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-c.ctx.Done():
+ return
+ case <-ticker.C:
+ if pc.closed.Load() {
+ return
+ }
+ pc.writeMu.Lock()
+ err := pc.conn.WriteMessage(websocket.PingMessage, nil)
+ pc.writeMu.Unlock()
+ if err != nil {
+ return
+ }
+ }
+ }
+}
+
+// handleMessage processes an inbound Pico Protocol message.
+func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) {
+ switch msg.Type {
+ case TypePing:
+ pong := newMessage(TypePong, nil)
+ pong.ID = msg.ID
+ pc.writeJSON(pong)
+
+ case TypeMessageSend:
+ c.handleMessageSend(pc, msg)
+
+ default:
+ errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type))
+ pc.writeJSON(errMsg)
+ }
+}
+
+// handleMessageSend processes an inbound message.send from a client.
+func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
+ content, _ := msg.Payload["content"].(string)
+ if strings.TrimSpace(content) == "" {
+ errMsg := newError("empty_content", "message content is empty")
+ pc.writeJSON(errMsg)
+ return
+ }
+
+ sessionID := msg.SessionID
+ if sessionID == "" {
+ sessionID = pc.sessionID
+ }
+
+ chatID := "pico:" + sessionID
+ senderID := "pico-user"
+
+ peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
+
+ metadata := map[string]string{
+ "platform": "pico",
+ "session_id": sessionID,
+ "conn_id": pc.id,
+ }
+
+ logger.DebugCF("pico", "Received message", map[string]any{
+ "session_id": sessionID,
+ "preview": truncate(content, 50),
+ })
+
+ sender := bus.SenderInfo{
+ Platform: "pico",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("pico", senderID),
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
+ c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender)
+}
+
+// truncate truncates a string to maxLen runes.
+func truncate(s string, maxLen int) string {
+ runes := []rune(s)
+ if len(runes) <= maxLen {
+ return s
+ }
+ return string(runes[:maxLen]) + "..."
+}
diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go
new file mode 100644
index 000000000..0a630e193
--- /dev/null
+++ b/pkg/channels/pico/protocol.go
@@ -0,0 +1,46 @@
+package pico
+
+import "time"
+
+// Protocol message types.
+const (
+ // TypeMessageSend is sent from client to server.
+ TypeMessageSend = "message.send"
+ TypeMediaSend = "media.send"
+ TypePing = "ping"
+
+ // TypeMessageCreate is sent from server to client.
+ TypeMessageCreate = "message.create"
+ TypeMessageUpdate = "message.update"
+ TypeMediaCreate = "media.create"
+ TypeTypingStart = "typing.start"
+ TypeTypingStop = "typing.stop"
+ TypeError = "error"
+ TypePong = "pong"
+)
+
+// PicoMessage is the wire format for all Pico Protocol messages.
+type PicoMessage struct {
+ Type string `json:"type"`
+ ID string `json:"id,omitempty"`
+ SessionID string `json:"session_id,omitempty"`
+ Timestamp int64 `json:"timestamp,omitempty"`
+ Payload map[string]any `json:"payload,omitempty"`
+}
+
+// newMessage creates a PicoMessage with the given type and payload.
+func newMessage(msgType string, payload map[string]any) PicoMessage {
+ return PicoMessage{
+ Type: msgType,
+ Timestamp: time.Now().UnixMilli(),
+ Payload: payload,
+ }
+}
+
+// newError creates an error PicoMessage.
+func newError(code, message string) PicoMessage {
+ return newMessage(TypeError, map[string]any{
+ "code": code,
+ "message": message,
+ })
+}
diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go
new file mode 100644
index 000000000..15b955089
--- /dev/null
+++ b/pkg/channels/qq/init.go
@@ -0,0 +1,13 @@
+package qq
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewQQChannel(cfg.Channels.QQ, b)
+ })
+}
diff --git a/pkg/channels/qq.go b/pkg/channels/qq/qq.go
similarity index 74%
rename from pkg/channels/qq.go
rename to pkg/channels/qq/qq.go
index b10776db6..112964143 100644
--- a/pkg/channels/qq.go
+++ b/pkg/channels/qq/qq.go
@@ -1,4 +1,4 @@
-package channels
+package qq
import (
"context"
@@ -14,12 +14,14 @@ import (
"golang.org/x/oauth2"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
)
type QQChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.QQConfig
api openapi.OpenAPI
tokenSource oauth2.TokenSource
@@ -31,7 +33,10 @@ type QQChannel struct {
}
func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) {
- base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom,
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &QQChannel{
BaseChannel: base,
@@ -90,11 +95,11 @@ func (c *QQChannel) Start(ctx context.Context) error {
logger.ErrorCF("qq", "WebSocket session error", map[string]any{
"error": err.Error(),
})
- c.setRunning(false)
+ c.SetRunning(false)
}
}()
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoC("qq", "QQ bot started successfully")
return nil
@@ -102,7 +107,7 @@ func (c *QQChannel) Start(ctx context.Context) error {
func (c *QQChannel) Stop(ctx context.Context) error {
logger.InfoC("qq", "Stopping QQ bot")
- c.setRunning(false)
+ c.SetRunning(false)
if c.cancel != nil {
c.cancel()
@@ -113,7 +118,7 @@ func (c *QQChannel) Stop(ctx context.Context) error {
func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("QQ bot not running")
+ return channels.ErrNotRunning
}
// construct message
@@ -127,7 +132,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{
"error": err.Error(),
})
- return err
+ return fmt.Errorf("qq send: %w", channels.ErrTemporary)
}
return nil
@@ -162,20 +167,35 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
"length": len(content),
})
- // forward to message bus
- metadata := map[string]string{
- "message_id": data.ID,
- "peer_kind": "direct",
- "peer_id": senderID,
+ // 转发到消息总线
+ metadata := map[string]string{}
+
+ sender := bus.SenderInfo{
+ Platform: "qq",
+ PlatformID: data.Author.ID,
+ CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
}
- c.HandleMessage(senderID, senderID, content, []string{}, metadata)
+ if !c.IsAllowedSender(sender) {
+ return nil
+ }
+
+ c.HandleMessage(c.ctx,
+ bus.Peer{Kind: "direct", ID: senderID},
+ data.ID,
+ senderID,
+ senderID,
+ content,
+ []string{},
+ metadata,
+ sender,
+ )
return nil
}
}
-// handleGroupATMessage handles group @messages
+// handleGroupATMessage handles QQ group @ messages
func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error {
// deduplication check
@@ -192,34 +212,57 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
return nil
}
- // extract message content (remove @bot part)
+ // extract message content (remove @ bot part)
content := data.Content
if content == "" {
logger.DebugC("qq", "Received empty group message, ignoring")
return nil
}
+ // GroupAT event means bot is always mentioned; apply group trigger filtering
+ respond, cleaned := c.ShouldRespondInGroup(true, content)
+ if !respond {
+ return nil
+ }
+ content = cleaned
+
logger.InfoCF("qq", "Received group AT message", map[string]any{
"sender": senderID,
"group": data.GroupID,
"length": len(content),
})
- // forward to message bus (use GroupID as ChatID)
+ // 转发到消息总线(使用 GroupID 作为 ChatID)
metadata := map[string]string{
- "message_id": data.ID,
- "group_id": data.GroupID,
- "peer_kind": "group",
- "peer_id": data.GroupID,
+ "group_id": data.GroupID,
}
- c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata)
+ sender := bus.SenderInfo{
+ Platform: "qq",
+ PlatformID: data.Author.ID,
+ CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID),
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return nil
+ }
+
+ c.HandleMessage(c.ctx,
+ bus.Peer{Kind: "group", ID: data.GroupID},
+ data.ID,
+ senderID,
+ data.GroupID,
+ content,
+ []string{},
+ metadata,
+ sender,
+ )
return nil
}
}
-// isDuplicate checks if message is duplicate
+// isDuplicate 检查消息是否重复
func (c *QQChannel) isDuplicate(messageID string) bool {
c.mu.Lock()
defer c.mu.Unlock()
@@ -230,9 +273,9 @@ func (c *QQChannel) isDuplicate(messageID string) bool {
c.processedIDs[messageID] = true
- // simple cleanup: limit map size
+ // 简单清理:限制 map 大小
if len(c.processedIDs) > 10000 {
- // clear half
+ // 清空一半
count := 0
for id := range c.processedIDs {
if count >= 5000 {
diff --git a/pkg/channels/registry.go b/pkg/channels/registry.go
new file mode 100644
index 000000000..36a05bf3e
--- /dev/null
+++ b/pkg/channels/registry.go
@@ -0,0 +1,32 @@
+package channels
+
+import (
+ "sync"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// ChannelFactory is a constructor function that creates a Channel from config and message bus.
+// Each channel subpackage registers one or more factories via init().
+type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error)
+
+var (
+ factoriesMu sync.RWMutex
+ factories = map[string]ChannelFactory{}
+)
+
+// RegisterFactory registers a named channel factory. Called from subpackage init() functions.
+func RegisterFactory(name string, f ChannelFactory) {
+ factoriesMu.Lock()
+ defer factoriesMu.Unlock()
+ factories[name] = f
+}
+
+// getFactory looks up a channel factory by name.
+func getFactory(name string) (ChannelFactory, bool) {
+ factoriesMu.RLock()
+ defer factoriesMu.RUnlock()
+ f, ok := factories[name]
+ return f, ok
+}
diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go
new file mode 100644
index 000000000..c131bb291
--- /dev/null
+++ b/pkg/channels/slack/init.go
@@ -0,0 +1,13 @@
+package slack
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewSlackChannel(cfg.Channels.Slack, b)
+ })
+}
diff --git a/pkg/channels/slack.go b/pkg/channels/slack/slack.go
similarity index 66%
rename from pkg/channels/slack.go
rename to pkg/channels/slack/slack.go
index cfb731b16..024b1b023 100644
--- a/pkg/channels/slack.go
+++ b/pkg/channels/slack/slack.go
@@ -1,32 +1,31 @@
-package channels
+package slack
import (
"context"
"fmt"
- "os"
"strings"
"sync"
- "time"
"github.com/slack-go/slack"
"github.com/slack-go/slack/slackevents"
"github.com/slack-go/slack/socketmode"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/utils"
- "github.com/sipeed/picoclaw/pkg/voice"
)
type SlackChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.SlackConfig
api *slack.Client
socketClient *socketmode.Client
botUserID string
teamID string
- transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
pendingAcks sync.Map
@@ -49,7 +48,11 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack
socketClient := socketmode.New(api)
- base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(40000),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &SlackChannel{
BaseChannel: base,
@@ -59,10 +62,6 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack
}, nil
}
-func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
- c.transcriber = transcriber
-}
-
func (c *SlackChannel) Start(ctx context.Context) error {
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
@@ -92,7 +91,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
}
}()
- c.setRunning(true)
+ c.SetRunning(true)
logger.InfoC("slack", "Slack channel started (Socket Mode)")
return nil
}
@@ -104,14 +103,14 @@ func (c *SlackChannel) Stop(ctx context.Context) error {
c.cancel()
}
- c.setRunning(false)
+ c.SetRunning(false)
logger.InfoC("slack", "Slack channel stopped")
return nil
}
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("slack channel not running")
+ return channels.ErrNotRunning
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
@@ -129,7 +128,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
if err != nil {
- return fmt.Errorf("failed to send slack message: %w", err)
+ return fmt.Errorf("slack send: %w", channels.ErrTemporary)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
@@ -148,6 +147,82 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return nil
}
+// SendMedia implements the channels.MediaSender interface.
+func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ channelID, _ := parseSlackChatID(msg.ChatID)
+ if channelID == "" {
+ return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("slack", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ filename := part.Filename
+ if filename == "" {
+ filename = "file"
+ }
+
+ title := part.Caption
+ if title == "" {
+ title = filename
+ }
+
+ _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{
+ Channel: channelID,
+ File: localPath,
+ Filename: filename,
+ Title: title,
+ })
+ if err != nil {
+ logger.ErrorCF("slack", "Failed to upload media", map[string]any{
+ "filename": filename,
+ "error": err.Error(),
+ })
+ return fmt.Errorf("slack send media: %w", channels.ErrTemporary)
+ }
+ }
+
+ return nil
+}
+
+// ReactToMessage implements channels.ReactionCapable.
+// It adds an "eyes" (👀) reaction to the inbound message and returns an undo function
+// that removes the reaction.
+func (c *SlackChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
+ channelID, _ := parseSlackChatID(chatID)
+ if channelID == "" {
+ return func() {}, nil
+ }
+
+ c.api.AddReaction("eyes", slack.ItemRef{
+ Channel: channelID,
+ Timestamp: messageID,
+ })
+
+ return func() {
+ c.api.RemoveReaction("eyes", slack.ItemRef{
+ Channel: channelID,
+ Timestamp: messageID,
+ })
+ }, nil
+}
+
func (c *SlackChannel) eventLoop() {
for {
select {
@@ -201,7 +276,12 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
}
// check allowlist to avoid downloading attachments for rejected users
- if !c.IsAllowed(ev.User) {
+ sender := bus.SenderInfo{
+ Platform: "slack",
+ PlatformID: ev.User,
+ CanonicalID: identity.BuildCanonicalID("slack", ev.User),
+ }
+ if !c.IsAllowedSender(sender) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{
"user_id": ev.User,
})
@@ -218,11 +298,6 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
chatID = channelID + "/" + threadTS
}
- c.api.AddReaction("eyes", slack.ItemRef{
- Channel: channelID,
- Timestamp: messageTS,
- })
-
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
@@ -231,20 +306,32 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
content := ev.Text
content = c.stripBotMention(content)
- var mediaPaths []string
- localFiles := []string{} // track local files that need cleanup
+ // In non-DM channels, apply group trigger filtering
+ if !strings.HasPrefix(channelID, "D") {
+ respond, cleaned := c.ShouldRespondInGroup(false, content)
+ if !respond {
+ return
+ }
+ content = cleaned
+ }
- // ensure temp files are cleaned up when function returns
- defer func() {
- for _, file := range localFiles {
- if err := os.Remove(file); err != nil {
- logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{
- "file": file,
- "error": err.Error(),
- })
+ var mediaPaths []string
+
+ scope := channels.BuildMediaScope("slack", chatID, messageTS)
+
+ // Helper to register a local file with the media store
+ storeMedia := func(localPath, filename string) string {
+ if store := c.GetMediaStore(); store != nil {
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "slack",
+ }, scope)
+ if err == nil {
+ return ref
}
}
- }()
+ return localPath // fallback
+ }
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
@@ -252,23 +339,8 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if localPath == "" {
continue
}
- localFiles = append(localFiles, localPath)
- mediaPaths = append(mediaPaths, localPath)
-
- if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
- ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
- defer cancel()
- result, err := c.transcriber.Transcribe(ctx, localPath)
-
- if err != nil {
- logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()})
- content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
- } else {
- content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
- }
- } else {
- content += fmt.Sprintf("\n[file: %s]", file.Name)
- }
+ mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name))
+ content += fmt.Sprintf("\n[file: %s]", file.Name)
}
}
@@ -283,13 +355,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
peerID = senderID
}
+ peer := bus.Peer{Kind: peerKind, ID: peerID}
+
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
- "peer_kind": peerKind,
- "peer_id": peerID,
"team_id": c.teamID,
}
@@ -300,7 +372,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
"has_thread": threadTS != "",
})
- c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
+ c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
@@ -308,7 +380,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
- if !c.IsAllowed(ev.User) {
+ if !c.IsAllowedSender(bus.SenderInfo{
+ Platform: "slack",
+ PlatformID: ev.User,
+ CanonicalID: identity.BuildCanonicalID("slack", ev.User),
+ }) {
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{
"user_id": ev.User,
})
@@ -316,6 +392,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
}
senderID := ev.User
+ mentionSender := bus.SenderInfo{
+ Platform: "slack",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("slack", senderID),
+ }
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
@@ -327,11 +408,6 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
chatID = channelID + "/" + messageTS
}
- c.api.AddReaction("eyes", slack.ItemRef{
- Channel: channelID,
- Timestamp: messageTS,
- })
-
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
@@ -350,18 +426,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
mentionPeerID = senderID
}
+ mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
+
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
- "peer_kind": mentionPeerKind,
- "peer_id": mentionPeerID,
"team_id": c.teamID,
}
- c.HandleMessage(senderID, chatID, content, nil, metadata)
+ c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
@@ -374,7 +450,12 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
c.socketClient.Ack(*event.Request)
}
- if !c.IsAllowed(cmd.UserID) {
+ cmdSender := bus.SenderInfo{
+ Platform: "slack",
+ PlatformID: cmd.UserID,
+ CanonicalID: identity.BuildCanonicalID("slack", cmd.UserID),
+ }
+ if !c.IsAllowedSender(cmdSender) {
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{
"user_id": cmd.UserID,
})
@@ -395,8 +476,6 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
- "peer_kind": "channel",
- "peer_id": channelID,
"team_id": c.teamID,
}
@@ -406,7 +485,17 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"text": utils.Truncate(content, 50),
})
- c.HandleMessage(senderID, chatID, content, nil, metadata)
+ c.HandleMessage(
+ c.ctx,
+ bus.Peer{Kind: "channel", ID: channelID},
+ "",
+ senderID,
+ chatID,
+ content,
+ nil,
+ metadata,
+ cmdSender,
+ )
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack/slack_test.go
similarity index 99%
rename from pkg/channels/slack_test.go
rename to pkg/channels/slack/slack_test.go
index 3707c2703..30e0d2d73 100644
--- a/pkg/channels/slack_test.go
+++ b/pkg/channels/slack/slack_test.go
@@ -1,4 +1,4 @@
-package channels
+package slack
import (
"testing"
diff --git a/pkg/channels/split.go b/pkg/channels/split.go
new file mode 100644
index 000000000..bb26c6d8f
--- /dev/null
+++ b/pkg/channels/split.go
@@ -0,0 +1,208 @@
+package channels
+
+import (
+ "strings"
+)
+
+// SplitMessage splits long messages into chunks, preserving code block integrity.
+// The maxLen parameter is measured in runes (Unicode characters), not bytes.
+// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks,
+// but may extend to maxLen when needed.
+// Call SplitMessage with the full text content and the maximum allowed length of a single message;
+// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks.
+func SplitMessage(content string, maxLen int) []string {
+ if maxLen <= 0 {
+ if content == "" {
+ return nil
+ }
+ return []string{content}
+ }
+
+ runes := []rune(content)
+ totalLen := len(runes)
+ var messages []string
+
+ // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
+ codeBlockBuffer := max(maxLen/10, 50)
+ if codeBlockBuffer > maxLen/2 {
+ codeBlockBuffer = maxLen / 2
+ }
+
+ start := 0
+ for start < totalLen {
+ remaining := totalLen - start
+ if remaining <= maxLen {
+ messages = append(messages, string(runes[start:totalLen]))
+ break
+ }
+
+ // Effective split point: maxLen minus buffer, to leave room for code blocks
+ effectiveLimit := max(maxLen-codeBlockBuffer, maxLen/2)
+
+ end := start + effectiveLimit
+
+ // Find natural split point within the effective limit
+ msgEnd := findLastNewlineInRange(runes, start, end, 200)
+ if msgEnd <= start {
+ msgEnd = findLastSpaceInRange(runes, start, end, 100)
+ }
+ if msgEnd <= start {
+ msgEnd = end
+ }
+
+ // Check if this would end with an incomplete code block
+ unclosedIdx := findLastUnclosedCodeBlockInRange(runes, start, msgEnd)
+
+ if unclosedIdx >= 0 {
+ // Message would end with incomplete code block
+ // Try to extend up to maxLen to include the closing ```
+ if totalLen > msgEnd {
+ closingIdx := findNextClosingCodeBlockInRange(runes, msgEnd, totalLen)
+ if closingIdx > 0 && closingIdx-start <= maxLen {
+ // Extend to include the closing ```
+ msgEnd = closingIdx
+ } else {
+ // Code block is too long to fit in one chunk or missing closing fence.
+ // Try to split inside by injecting closing and reopening fences.
+ headerEnd := findNewlineFrom(runes, unclosedIdx)
+ var header string
+ if headerEnd == -1 {
+ header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3]))
+ } else {
+ header = strings.TrimSpace(string(runes[unclosedIdx:headerEnd]))
+ }
+ headerEndIdx := unclosedIdx + len([]rune(header))
+ if headerEnd != -1 {
+ headerEndIdx = headerEnd
+ }
+
+ // If we have a reasonable amount of content after the header, split inside
+ if msgEnd > headerEndIdx+20 {
+ // Find a better split point closer to maxLen
+ innerLimit := min(
+ // Leave room for "\n```"
+ start+maxLen-5, totalLen)
+ betterEnd := findLastNewlineInRange(runes, start, innerLimit, 200)
+ if betterEnd > headerEndIdx {
+ msgEnd = betterEnd
+ } else {
+ msgEnd = innerLimit
+ }
+ chunk := strings.TrimRight(string(runes[start:msgEnd]), " \t\n\r") + "\n```"
+ messages = append(messages, chunk)
+ remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:totalLen]))
+ // Replace the tail of runes with the reconstructed remaining
+ runes = []rune(remaining)
+ totalLen = len(runes)
+ start = 0
+ continue
+ }
+
+ // Otherwise, try to split before the code block starts
+ newEnd := findLastNewlineInRange(runes, start, unclosedIdx, 200)
+ if newEnd <= start {
+ newEnd = findLastSpaceInRange(runes, start, unclosedIdx, 100)
+ }
+ if newEnd > start {
+ msgEnd = newEnd
+ } else {
+ // If we can't split before, we MUST split inside (last resort)
+ if unclosedIdx-start > 20 {
+ msgEnd = unclosedIdx
+ } else {
+ splitAt := min(start+maxLen-5, totalLen)
+ chunk := strings.TrimRight(string(runes[start:splitAt]), " \t\n\r") + "\n```"
+ messages = append(messages, chunk)
+ remaining := strings.TrimSpace(header + "\n" + string(runes[splitAt:totalLen]))
+ runes = []rune(remaining)
+ totalLen = len(runes)
+ start = 0
+ continue
+ }
+ }
+ }
+ }
+ }
+
+ if msgEnd <= start {
+ msgEnd = start + effectiveLimit
+ }
+
+ messages = append(messages, string(runes[start:msgEnd]))
+ // Advance start, skipping leading whitespace of next chunk
+ start = msgEnd
+ for start < totalLen && (runes[start] == ' ' || runes[start] == '\t' || runes[start] == '\n' || runes[start] == '\r') {
+ start++
+ }
+ }
+
+ return messages
+}
+
+// findLastUnclosedCodeBlockInRange finds the last opening ``` that doesn't have a closing ```
+// within runes[start:end]. Returns the absolute rune index or -1.
+func findLastUnclosedCodeBlockInRange(runes []rune, start, end int) int {
+ inCodeBlock := false
+ lastOpenIdx := -1
+
+ for i := start; i < end; i++ {
+ if i+2 < end && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' {
+ if !inCodeBlock {
+ lastOpenIdx = i
+ }
+ inCodeBlock = !inCodeBlock
+ i += 2
+ }
+ }
+
+ if inCodeBlock {
+ return lastOpenIdx
+ }
+ return -1
+}
+
+// findNextClosingCodeBlockInRange finds the next closing ``` starting from startIdx
+// within runes[startIdx:end]. Returns the absolute index after the closing ``` or -1.
+func findNextClosingCodeBlockInRange(runes []rune, startIdx, end int) int {
+ for i := startIdx; i < end; i++ {
+ if i+2 < end && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' {
+ return i + 3
+ }
+ }
+ return -1
+}
+
+// findNewlineFrom finds the first newline character starting from the given index.
+// Returns the absolute index or -1 if not found.
+func findNewlineFrom(runes []rune, from int) int {
+ for i := from; i < len(runes); i++ {
+ if runes[i] == '\n' {
+ return i
+ }
+ }
+ return -1
+}
+
+// findLastNewlineInRange finds the last newline within the last searchWindow runes
+// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
+func findLastNewlineInRange(runes []rune, start, end, searchWindow int) int {
+ searchStart := max(end-searchWindow, start)
+ for i := end - 1; i >= searchStart; i-- {
+ if runes[i] == '\n' {
+ return i
+ }
+ }
+ return start - 1
+}
+
+// findLastSpaceInRange finds the last space/tab within the last searchWindow runes
+// of the range runes[start:end]. Returns the absolute index or start-1 (indicating not found).
+func findLastSpaceInRange(runes []rune, start, end, searchWindow int) int {
+ searchStart := max(end-searchWindow, start)
+ for i := end - 1; i >= searchStart; i-- {
+ if runes[i] == ' ' || runes[i] == '\t' {
+ return i
+ }
+ }
+ return start - 1
+}
diff --git a/pkg/channels/split_test.go b/pkg/channels/split_test.go
new file mode 100644
index 000000000..a922f9558
--- /dev/null
+++ b/pkg/channels/split_test.go
@@ -0,0 +1,362 @@
+package channels
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestSplitMessage(t *testing.T) {
+ longText := strings.Repeat("a", 2500)
+ longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars
+
+ tests := []struct {
+ name string
+ content string
+ maxLen int
+ expectChunks int // Check number of chunks
+ checkContent func(t *testing.T, chunks []string) // Custom validation
+ }{
+ {
+ name: "Empty message",
+ content: "",
+ maxLen: 2000,
+ expectChunks: 0,
+ },
+ {
+ name: "Short message fits in one chunk",
+ content: "Hello world",
+ maxLen: 2000,
+ expectChunks: 1,
+ },
+ {
+ name: "Simple split regular text",
+ content: longText,
+ maxLen: 2000,
+ expectChunks: 2,
+ checkContent: func(t *testing.T, chunks []string) {
+ if len([]rune(chunks[0])) > 2000 {
+ t.Errorf("Chunk 0 too large: %d runes", len([]rune(chunks[0])))
+ }
+ if len([]rune(chunks[0]))+len([]rune(chunks[1])) != len([]rune(longText)) {
+ t.Errorf(
+ "Total rune length mismatch. Got %d, want %d",
+ len([]rune(chunks[0]))+len([]rune(chunks[1])),
+ len([]rune(longText)),
+ )
+ }
+ },
+ },
+ {
+ name: "Split at newline",
+ // 1750 chars then newline, then more chars.
+ // Dynamic buffer: 2000 / 10 = 200.
+ // Effective limit: 2000 - 200 = 1800.
+ // Split should happen at newline because it's at 1750 (< 1800).
+ // Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051.
+ content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300),
+ maxLen: 2000,
+ expectChunks: 2,
+ checkContent: func(t *testing.T, chunks []string) {
+ if len([]rune(chunks[0])) != 1750 {
+ t.Errorf("Expected chunk 0 to be 1750 runes (split at newline), got %d", len([]rune(chunks[0])))
+ }
+ if chunks[1] != strings.Repeat("b", 300) {
+ t.Errorf("Chunk 1 content mismatch. Len: %d", len([]rune(chunks[1])))
+ }
+ },
+ },
+ {
+ name: "Long code block split",
+ content: "Prefix\n" + longCode,
+ maxLen: 2000,
+ expectChunks: 2,
+ checkContent: func(t *testing.T, chunks []string) {
+ // Check that first chunk ends with closing fence
+ if !strings.HasSuffix(chunks[0], "\n```") {
+ t.Error("First chunk should end with injected closing fence")
+ }
+ // Check that second chunk starts with execution header
+ if !strings.HasPrefix(chunks[1], "```go") {
+ t.Error("Second chunk should start with injected code block header")
+ }
+ },
+ },
+ {
+ name: "Preserve Unicode characters (rune-aware)",
+ content: strings.Repeat("\u4e16", 2500), // 2500 runes, 7500 bytes
+ maxLen: 2000,
+ expectChunks: 2,
+ checkContent: func(t *testing.T, chunks []string) {
+ // Verify chunks contain valid unicode and don't split mid-rune
+ for i, chunk := range chunks {
+ runeCount := len([]rune(chunk))
+ if runeCount > 2000 {
+ t.Errorf("Chunk %d has %d runes, exceeds maxLen 2000", i, runeCount)
+ }
+ if !strings.Contains(chunk, "\u4e16") {
+ t.Errorf("Chunk %d should contain unicode characters", i)
+ }
+ }
+ // Verify total rune count is preserved
+ totalRunes := 0
+ for _, chunk := range chunks {
+ totalRunes += len([]rune(chunk))
+ }
+ if totalRunes != 2500 {
+ t.Errorf("Total rune count mismatch. Got %d, want 2500", totalRunes)
+ }
+ },
+ },
+ {
+ name: "Zero maxLen returns single chunk",
+ content: "Hello world",
+ maxLen: 0,
+ expectChunks: 1,
+ checkContent: func(t *testing.T, chunks []string) {
+ if chunks[0] != "Hello world" {
+ t.Errorf("Expected original content, got %q", chunks[0])
+ }
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := SplitMessage(tc.content, tc.maxLen)
+
+ if tc.expectChunks == 0 {
+ if len(got) != 0 {
+ t.Errorf("Expected 0 chunks, got %d", len(got))
+ }
+ return
+ }
+
+ if len(got) != tc.expectChunks {
+ t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got))
+ // Log sizes for debugging
+ for i, c := range got {
+ t.Logf("Chunk %d length: %d", i, len(c))
+ }
+ return // Stop further checks if count assumes specific split
+ }
+
+ if tc.checkContent != nil {
+ tc.checkContent(t, got)
+ }
+ })
+ }
+}
+
+// --- Helper function tests for index-based rune operations ---
+
+func TestFindLastNewlineInRange(t *testing.T) {
+ runes := []rune("aaa\nbbb\nccc")
+ // Indices: 0123 4567 89 10
+
+ tests := []struct {
+ name string
+ start, end int
+ searchWindow int
+ want int
+ }{
+ {"finds last newline in full range", 0, 11, 200, 7},
+ {"finds newline within search window", 0, 11, 4, 7},
+ {"narrow window misses newline outside window", 4, 11, 3, 3}, // returns start-1 (not found)
+ {"no newline in range", 0, 3, 200, -1}, // start-1 = -1
+ {"range limited to first segment", 0, 4, 200, 3},
+ {"search window of 1 at newline", 0, 8, 1, 7},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := findLastNewlineInRange(runes, tc.start, tc.end, tc.searchWindow)
+ if got != tc.want {
+ t.Errorf("findLastNewlineInRange(runes, %d, %d, %d) = %d, want %d",
+ tc.start, tc.end, tc.searchWindow, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestFindLastSpaceInRange(t *testing.T) {
+ runes := []rune("abc def\tghi")
+ // Indices: 0123 4567 89 10
+
+ tests := []struct {
+ name string
+ start, end int
+ searchWindow int
+ want int
+ }{
+ {"finds tab as last space/tab", 0, 11, 200, 7},
+ {"finds space when tab out of window", 0, 7, 200, 3},
+ {"no space in range", 0, 3, 200, -1},
+ {"narrow window finds tab", 5, 11, 4, 7},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := findLastSpaceInRange(runes, tc.start, tc.end, tc.searchWindow)
+ if got != tc.want {
+ t.Errorf("findLastSpaceInRange(runes, %d, %d, %d) = %d, want %d",
+ tc.start, tc.end, tc.searchWindow, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestFindNewlineFrom(t *testing.T) {
+ runes := []rune("hello\nworld\n")
+
+ tests := []struct {
+ name string
+ from int
+ want int
+ }{
+ {"from start", 0, 5},
+ {"from after first newline", 6, 11},
+ {"from past all newlines", 12, -1},
+ {"from newline itself", 5, 5},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := findNewlineFrom(runes, tc.from)
+ if got != tc.want {
+ t.Errorf("findNewlineFrom(runes, %d) = %d, want %d", tc.from, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestFindLastUnclosedCodeBlockInRange(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ start, end int
+ want int
+ }{
+ {
+ name: "no code blocks",
+ content: "hello world",
+ start: 0, end: 11,
+ want: -1,
+ },
+ {
+ name: "complete code block",
+ content: "```go\ncode\n```",
+ start: 0, end: 14,
+ want: -1,
+ },
+ {
+ name: "unclosed code block",
+ content: "text\n```go\ncode here",
+ start: 0, end: 20,
+ want: 5,
+ },
+ {
+ name: "closed then unclosed",
+ content: "```a\n```\n```b\ncode",
+ start: 0, end: 17,
+ want: 9,
+ },
+ {
+ name: "search within subrange",
+ content: "```a\n```\n```b\ncode",
+ start: 9, end: 17,
+ want: 9,
+ },
+ {
+ name: "subrange with no code blocks",
+ content: "```a\n```\nhello",
+ start: 9, end: 14,
+ want: -1,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ runes := []rune(tc.content)
+ got := findLastUnclosedCodeBlockInRange(runes, tc.start, tc.end)
+ if got != tc.want {
+ t.Errorf("findLastUnclosedCodeBlockInRange(%q, %d, %d) = %d, want %d",
+ tc.content, tc.start, tc.end, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestFindNextClosingCodeBlockInRange(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ startIdx int
+ end int
+ want int
+ }{
+ {
+ name: "finds closing fence",
+ content: "code\n```\nmore",
+ startIdx: 0, end: 13,
+ want: 8, // position after ```
+ },
+ {
+ name: "no closing fence",
+ content: "just code here",
+ startIdx: 0, end: 14,
+ want: -1,
+ },
+ {
+ name: "fence at start of search",
+ content: "```end",
+ startIdx: 0, end: 6,
+ want: 3,
+ },
+ {
+ name: "fence outside range",
+ content: "code\n```",
+ startIdx: 0, end: 4,
+ want: -1,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ runes := []rune(tc.content)
+ got := findNextClosingCodeBlockInRange(runes, tc.startIdx, tc.end)
+ if got != tc.want {
+ t.Errorf("findNextClosingCodeBlockInRange(%q, %d, %d) = %d, want %d",
+ tc.content, tc.startIdx, tc.end, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestSplitMessage_CodeBlockIntegrity(t *testing.T) {
+ // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting
+
+ // 60 chars total approximately
+ content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```"
+ maxLen := 40
+
+ chunks := SplitMessage(content, maxLen)
+
+ if len(chunks) != 2 {
+ t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks)
+ }
+
+ // First chunk must end with "\n```"
+ if !strings.HasSuffix(chunks[0], "\n```") {
+ t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0])
+ }
+
+ // Second chunk must start with the header "```go"
+ if !strings.HasPrefix(chunks[1], "```go") {
+ t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1])
+ }
+
+ // First chunk should contain meaningful content
+ if len([]rune(chunks[0])) > 40 {
+ t.Errorf("First chunk exceeded maxLen: length %d runes", len([]rune(chunks[0])))
+ }
+}
diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go
deleted file mode 100644
index 524494849..000000000
--- a/pkg/channels/telegram.go
+++ /dev/null
@@ -1,529 +0,0 @@
-package channels
-
-import (
- "context"
- "fmt"
- "net/http"
- "net/url"
- "os"
- "regexp"
- "strings"
- "sync"
- "time"
-
- "github.com/mymmrac/telego"
- "github.com/mymmrac/telego/telegohandler"
- th "github.com/mymmrac/telego/telegohandler"
- tu "github.com/mymmrac/telego/telegoutil"
-
- "github.com/sipeed/picoclaw/pkg/bus"
- "github.com/sipeed/picoclaw/pkg/config"
- "github.com/sipeed/picoclaw/pkg/logger"
- "github.com/sipeed/picoclaw/pkg/utils"
- "github.com/sipeed/picoclaw/pkg/voice"
-)
-
-type TelegramChannel struct {
- *BaseChannel
- bot *telego.Bot
- commands TelegramCommander
- config *config.Config
- chatIDs map[string]int64
- transcriber *voice.GroqTranscriber
- placeholders sync.Map // chatID -> messageID
- stopThinking sync.Map // chatID -> thinkingCancel
-}
-
-type thinkingCancel struct {
- fn context.CancelFunc
-}
-
-func (c *thinkingCancel) Cancel() {
- if c != nil && c.fn != nil {
- c.fn()
- }
-}
-
-func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
- var opts []telego.BotOption
- telegramCfg := cfg.Channels.Telegram
-
- if telegramCfg.Proxy != "" {
- proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
- if parseErr != nil {
- return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
- }
- opts = append(opts, telego.WithHTTPClient(&http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyURL(proxyURL),
- },
- }))
- } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
- // Use environment proxy if configured
- opts = append(opts, telego.WithHTTPClient(&http.Client{
- Transport: &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- },
- }))
- }
-
- bot, err := telego.NewBot(telegramCfg.Token, opts...)
- if err != nil {
- return nil, fmt.Errorf("failed to create telegram bot: %w", err)
- }
-
- base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
-
- return &TelegramChannel{
- BaseChannel: base,
- commands: NewTelegramCommands(bot, cfg),
- bot: bot,
- config: cfg,
- chatIDs: make(map[string]int64),
- transcriber: nil,
- placeholders: sync.Map{},
- stopThinking: sync.Map{},
- }, nil
-}
-
-func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
- c.transcriber = transcriber
-}
-
-func (c *TelegramChannel) Start(ctx context.Context) error {
- logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
-
- updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
- Timeout: 30,
- })
- if err != nil {
- return fmt.Errorf("failed to start long polling: %w", err)
- }
-
- bh, err := telegohandler.NewBotHandler(c.bot, updates)
- if err != nil {
- return fmt.Errorf("failed to create bot handler: %w", err)
- }
-
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- c.commands.Help(ctx, message)
- return nil
- }, th.CommandEqual("help"))
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- return c.commands.Start(ctx, message)
- }, th.CommandEqual("start"))
-
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- return c.commands.Show(ctx, message)
- }, th.CommandEqual("show"))
-
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- return c.commands.List(ctx, message)
- }, th.CommandEqual("list"))
-
- bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
- return c.handleMessage(ctx, &message)
- }, th.AnyMessage())
-
- c.setRunning(true)
- logger.InfoCF("telegram", "Telegram bot connected", map[string]any{
- "username": c.bot.Username(),
- })
-
- go bh.Start()
-
- go func() {
- <-ctx.Done()
- bh.Stop()
- }()
-
- return nil
-}
-
-func (c *TelegramChannel) Stop(ctx context.Context) error {
- logger.InfoC("telegram", "Stopping Telegram bot...")
- c.setRunning(false)
- return nil
-}
-
-func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
- if !c.IsRunning() {
- return fmt.Errorf("telegram bot not running")
- }
-
- chatID, err := parseChatID(msg.ChatID)
- if err != nil {
- return fmt.Errorf("invalid chat ID: %w", err)
- }
-
- // Stop thinking animation
- if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
- if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
- cf.Cancel()
- }
- c.stopThinking.Delete(msg.ChatID)
- }
-
- htmlContent := markdownToTelegramHTML(msg.Content)
-
- // Try to edit placeholder
- if pID, ok := c.placeholders.Load(msg.ChatID); ok {
- c.placeholders.Delete(msg.ChatID)
- editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
- editMsg.ParseMode = telego.ModeHTML
-
- if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
- return nil
- }
- // Fallback to new message if edit fails
- }
-
- tgMsg := tu.Message(tu.ID(chatID), htmlContent)
- tgMsg.ParseMode = telego.ModeHTML
-
- if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
- logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
- "error": err.Error(),
- })
- tgMsg.ParseMode = ""
- _, err = c.bot.SendMessage(ctx, tgMsg)
- return err
- }
-
- return nil
-}
-
-func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
- if message == nil {
- return fmt.Errorf("message is nil")
- }
-
- user := message.From
- if user == nil {
- return fmt.Errorf("message sender (user) is nil")
- }
-
- senderID := fmt.Sprintf("%d", user.ID)
- if user.Username != "" {
- senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
- }
-
- // check allowlist to avoid downloading attachments for rejected users
- if !c.IsAllowed(senderID) {
- logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{
- "user_id": senderID,
- })
- return nil
- }
-
- chatID := message.Chat.ID
- c.chatIDs[senderID] = chatID
-
- content := ""
- mediaPaths := []string{}
- localFiles := []string{} // track local files that need cleanup
-
- // ensure temp files are cleaned up when function returns
- defer func() {
- for _, file := range localFiles {
- if err := os.Remove(file); err != nil {
- logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{
- "file": file,
- "error": err.Error(),
- })
- }
- }
- }()
-
- if message.Text != "" {
- content += message.Text
- }
-
- if message.Caption != "" {
- if content != "" {
- content += "\n"
- }
- content += message.Caption
- }
-
- if len(message.Photo) > 0 {
- photo := message.Photo[len(message.Photo)-1]
- photoPath := c.downloadPhoto(ctx, photo.FileID)
- if photoPath != "" {
- localFiles = append(localFiles, photoPath)
- mediaPaths = append(mediaPaths, photoPath)
- if content != "" {
- content += "\n"
- }
- content += "[image: photo]"
- }
- }
-
- if message.Voice != nil {
- voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
- if voicePath != "" {
- localFiles = append(localFiles, voicePath)
- mediaPaths = append(mediaPaths, voicePath)
-
- var transcribedText string
- if c.transcriber != nil && c.transcriber.IsAvailable() {
- transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
- defer cancel()
-
- result, err := c.transcriber.Transcribe(transcriberCtx, voicePath)
- if err != nil {
- logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{
- "error": err.Error(),
- "path": voicePath,
- })
- transcribedText = "[voice (transcription failed)]"
- } else {
- transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
- logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{
- "text": result.Text,
- })
- }
- } else {
- transcribedText = "[voice]"
- }
-
- if content != "" {
- content += "\n"
- }
- content += transcribedText
- }
- }
-
- if message.Audio != nil {
- audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
- if audioPath != "" {
- localFiles = append(localFiles, audioPath)
- mediaPaths = append(mediaPaths, audioPath)
- if content != "" {
- content += "\n"
- }
- content += "[audio]"
- }
- }
-
- if message.Document != nil {
- docPath := c.downloadFile(ctx, message.Document.FileID, "")
- if docPath != "" {
- localFiles = append(localFiles, docPath)
- mediaPaths = append(mediaPaths, docPath)
- if content != "" {
- content += "\n"
- }
- content += "[file]"
- }
- }
-
- if content == "" {
- content = "[empty message]"
- }
-
- logger.DebugCF("telegram", "Received message", map[string]any{
- "sender_id": senderID,
- "chat_id": fmt.Sprintf("%d", chatID),
- "preview": utils.Truncate(content, 50),
- })
-
- // Thinking indicator
- err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
- if err != nil {
- logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{
- "error": err.Error(),
- })
- }
-
- // Stop any previous thinking animation
- chatIDStr := fmt.Sprintf("%d", chatID)
- if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
- if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
- cf.Cancel()
- }
- }
-
- // Create cancel function for thinking state
- _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
- c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
-
- pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
- if err == nil {
- pID := pMsg.MessageID
- c.placeholders.Store(chatIDStr, pID)
- }
-
- peerKind := "direct"
- peerID := fmt.Sprintf("%d", user.ID)
- if message.Chat.Type != "private" {
- peerKind = "group"
- peerID = fmt.Sprintf("%d", chatID)
- }
-
- metadata := map[string]string{
- "message_id": fmt.Sprintf("%d", message.MessageID),
- "user_id": fmt.Sprintf("%d", user.ID),
- "username": user.Username,
- "first_name": user.FirstName,
- "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
- "peer_kind": peerKind,
- "peer_id": peerID,
- }
-
- c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
- return nil
-}
-
-func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
- file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
- if err != nil {
- logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{
- "error": err.Error(),
- })
- return ""
- }
-
- return c.downloadFileWithInfo(file, ".jpg")
-}
-
-func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
- if file.FilePath == "" {
- return ""
- }
-
- url := c.bot.FileDownloadURL(file.FilePath)
- logger.DebugCF("telegram", "File URL", map[string]any{"url": url})
-
- // Use FilePath as filename for better identification
- filename := file.FilePath + ext
- return utils.DownloadFile(url, filename, utils.DownloadOptions{
- LoggerPrefix: "telegram",
- })
-}
-
-func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
- file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
- if err != nil {
- logger.ErrorCF("telegram", "Failed to get file", map[string]any{
- "error": err.Error(),
- })
- return ""
- }
-
- return c.downloadFileWithInfo(file, ext)
-}
-
-func parseChatID(chatIDStr string) (int64, error) {
- var id int64
- _, err := fmt.Sscanf(chatIDStr, "%d", &id)
- return id, err
-}
-
-func markdownToTelegramHTML(text string) string {
- if text == "" {
- return ""
- }
-
- codeBlocks := extractCodeBlocks(text)
- text = codeBlocks.text
-
- inlineCodes := extractInlineCodes(text)
- text = inlineCodes.text
-
- text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1")
-
- text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1")
-
- text = escapeHTML(text)
-
- text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `$1`)
-
- text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "$1")
-
- text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "$1")
-
- reItalic := regexp.MustCompile(`_([^_]+)_`)
- text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
- match := reItalic.FindStringSubmatch(s)
- if len(match) < 2 {
- return s
- }
- return "" + match[1] + ""
- })
-
- text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "$1")
-
- text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ")
-
- for i, code := range inlineCodes.codes {
- escaped := escapeHTML(code)
- text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("%s", escaped))
- }
-
- for i, code := range codeBlocks.codes {
- escaped := escapeHTML(code)
- text = strings.ReplaceAll(
- text,
- fmt.Sprintf("\x00CB%d\x00", i),
- fmt.Sprintf("%s
", escaped),
- )
- }
-
- return text
-}
-
-type codeBlockMatch struct {
- text string
- codes []string
-}
-
-func extractCodeBlocks(text string) codeBlockMatch {
- re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```")
- matches := re.FindAllStringSubmatch(text, -1)
-
- codes := make([]string, 0, len(matches))
- for _, match := range matches {
- codes = append(codes, match[1])
- }
-
- i := 0
- text = re.ReplaceAllStringFunc(text, func(m string) string {
- placeholder := fmt.Sprintf("\x00CB%d\x00", i)
- i++
- return placeholder
- })
-
- return codeBlockMatch{text: text, codes: codes}
-}
-
-type inlineCodeMatch struct {
- text string
- codes []string
-}
-
-func extractInlineCodes(text string) inlineCodeMatch {
- re := regexp.MustCompile("`([^`]+)`")
- matches := re.FindAllStringSubmatch(text, -1)
-
- codes := make([]string, 0, len(matches))
- for _, match := range matches {
- codes = append(codes, match[1])
- }
-
- i := 0
- text = re.ReplaceAllStringFunc(text, func(m string) string {
- placeholder := fmt.Sprintf("\x00IC%d\x00", i)
- i++
- return placeholder
- })
-
- return inlineCodeMatch{text: text, codes: codes}
-}
-
-func escapeHTML(text string) string {
- text = strings.ReplaceAll(text, "&", "&")
- text = strings.ReplaceAll(text, "<", "<")
- text = strings.ReplaceAll(text, ">", ">")
- return text
-}
diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go
new file mode 100644
index 000000000..ac87bb805
--- /dev/null
+++ b/pkg/channels/telegram/init.go
@@ -0,0 +1,13 @@
+package telegram
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewTelegramChannel(cfg, b)
+ })
+}
diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go
new file mode 100644
index 000000000..f328f32b8
--- /dev/null
+++ b/pkg/channels/telegram/telegram.go
@@ -0,0 +1,769 @@
+package telegram
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/url"
+ "os"
+ "regexp"
+ "slices"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/mymmrac/telego"
+ th "github.com/mymmrac/telego/telegohandler"
+ tu "github.com/mymmrac/telego/telegoutil"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+var (
+ reHeading = regexp.MustCompile(`^#{1,6}\s+(.+)$`)
+ reBlockquote = regexp.MustCompile(`^>\s*(.*)$`)
+ reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
+ reBoldStar = regexp.MustCompile(`\*\*(.+?)\*\*`)
+ reBoldUnder = regexp.MustCompile(`__(.+?)__`)
+ reItalic = regexp.MustCompile(`_([^_]+)_`)
+ reStrike = regexp.MustCompile(`~~(.+?)~~`)
+ reListItem = regexp.MustCompile(`^[-*]\s+`)
+ reCodeBlock = regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```")
+ reInlineCode = regexp.MustCompile("`([^`]+)`")
+)
+
+type TelegramChannel struct {
+ *channels.BaseChannel
+ bot *telego.Bot
+ bh *th.BotHandler
+ commands TelegramCommander
+ config *config.Config
+ chatIDs map[string]int64
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
+ var opts []telego.BotOption
+ telegramCfg := cfg.Channels.Telegram
+
+ if telegramCfg.Proxy != "" {
+ proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
+ if parseErr != nil {
+ return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
+ }
+ opts = append(opts, telego.WithHTTPClient(&http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyURL(proxyURL),
+ },
+ }))
+ } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" {
+ // Use environment proxy if configured
+ opts = append(opts, telego.WithHTTPClient(&http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ },
+ }))
+ }
+
+ if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
+ opts = append(opts, telego.WithAPIServer(baseURL))
+ }
+
+ bot, err := telego.NewBot(telegramCfg.Token, opts...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create telegram bot: %w", err)
+ }
+
+ base := channels.NewBaseChannel(
+ "telegram",
+ telegramCfg,
+ bus,
+ telegramCfg.AllowFrom,
+ channels.WithMaxMessageLength(4096),
+ channels.WithGroupTrigger(telegramCfg.GroupTrigger),
+ channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID),
+ )
+
+ return &TelegramChannel{
+ BaseChannel: base,
+ commands: NewTelegramCommands(bot, cfg),
+ bot: bot,
+ config: cfg,
+ chatIDs: make(map[string]int64),
+ }, nil
+}
+
+func (c *TelegramChannel) Start(ctx context.Context) error {
+ logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ if err := c.initBotCommands(c.ctx); err != nil {
+ logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
+ "error": err.Error(),
+ })
+ }
+
+ updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
+ Timeout: 30,
+ })
+ if err != nil {
+ c.cancel()
+ return fmt.Errorf("failed to start long polling: %w", err)
+ }
+
+ bh, err := th.NewBotHandler(c.bot, updates)
+ if err != nil {
+ c.cancel()
+ return fmt.Errorf("failed to create bot handler: %w", err)
+ }
+ c.bh = bh
+
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.commands.Start(ctx, message)
+ }, th.CommandEqual("start"))
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.commands.Help(ctx, message)
+ }, th.CommandEqual("help"))
+
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.commands.Show(ctx, message)
+ }, th.CommandEqual("show"))
+
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.commands.List(ctx, message)
+ }, th.CommandEqual("list"))
+
+ bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
+ return c.handleMessage(ctx, &message)
+ }, th.AnyMessage())
+
+ c.SetRunning(true)
+ logger.InfoCF("telegram", "Telegram bot connected", map[string]any{
+ "username": c.bot.Username(),
+ })
+
+ go func() {
+ if err = bh.Start(); err != nil {
+ logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ return nil
+}
+
+func (c *TelegramChannel) Stop(ctx context.Context) error {
+ logger.InfoC("telegram", "Stopping Telegram bot...")
+ c.SetRunning(false)
+
+ // Stop the bot handler
+ if c.bh != nil {
+ _ = c.bh.StopWithContext(ctx)
+ }
+
+ // Cancel our context (stops long polling)
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ return nil
+}
+
+func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
+ currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
+ Scope: tu.ScopeDefault(),
+ })
+ if err != nil {
+ return fmt.Errorf("get commands: %w", err)
+ }
+
+ commands := []telego.BotCommand{
+ {
+ Command: "start",
+ Description: "Start the bot",
+ },
+ {
+ Command: "help",
+ Description: "Show a help message",
+ },
+ {
+ Command: "show",
+ Description: "Show current configuration",
+ },
+ {
+ Command: "list",
+ Description: "List available options",
+ },
+ }
+
+ // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
+ if !slices.Equal(currentCommands, commands) {
+ logger.InfoC("telegram", "Updating bot commands")
+
+ err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
+ Commands: commands,
+ Scope: tu.ScopeDefault(),
+ })
+ if err != nil {
+ return fmt.Errorf("set commands: %w", err)
+ }
+ } else {
+ logger.DebugC("telegram", "Bot commands are up to date")
+ }
+
+ return nil
+}
+
+func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ chatID, err := parseChatID(msg.ChatID)
+ if err != nil {
+ return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
+ }
+
+ htmlContent := markdownToTelegramHTML(msg.Content)
+
+ // Typing/placeholder handled by Manager.preSend — just send the message
+ tgMsg := tu.Message(tu.ID(chatID), htmlContent)
+ tgMsg.ParseMode = telego.ModeHTML
+
+ if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
+ logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
+ "error": err.Error(),
+ })
+ tgMsg.ParseMode = ""
+ if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
+ return fmt.Errorf("telegram send: %w", channels.ErrTemporary)
+ }
+ }
+
+ return nil
+}
+
+// StartTyping implements channels.TypingCapable.
+// It sends ChatAction(typing) immediately and then repeats every 4 seconds
+// (Telegram's typing indicator expires after ~5s) in a background goroutine.
+// The returned stop function is idempotent and cancels the goroutine.
+func (c *TelegramChannel) StartTyping(ctx context.Context, chatID string) (func(), error) {
+ cid, err := parseChatID(chatID)
+ if err != nil {
+ return func() {}, err
+ }
+
+ // Send the first typing action immediately
+ _ = c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
+
+ typingCtx, cancel := context.WithCancel(ctx)
+ go func() {
+ ticker := time.NewTicker(4 * time.Second)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-typingCtx.Done():
+ return
+ case <-ticker.C:
+ _ = c.bot.SendChatAction(typingCtx, tu.ChatAction(tu.ID(cid), telego.ChatActionTyping))
+ }
+ }
+ }()
+
+ return cancel, nil
+}
+
+// EditMessage implements channels.MessageEditor.
+func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error {
+ cid, err := parseChatID(chatID)
+ if err != nil {
+ return err
+ }
+ mid, err := strconv.Atoi(messageID)
+ if err != nil {
+ return err
+ }
+ htmlContent := markdownToTelegramHTML(content)
+ editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent)
+ editMsg.ParseMode = telego.ModeHTML
+ _, err = c.bot.EditMessageText(ctx, editMsg)
+ return err
+}
+
+// SendPlaceholder implements channels.PlaceholderCapable.
+// It sends a placeholder message (e.g. "Thinking... 💭") that will later be
+// edited to the actual response via EditMessage (channels.MessageEditor).
+func (c *TelegramChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
+ phCfg := c.config.Channels.Telegram.Placeholder
+ if !phCfg.Enabled {
+ return "", nil
+ }
+
+ text := phCfg.Text
+ if text == "" {
+ text = "Thinking... 💭"
+ }
+
+ cid, err := parseChatID(chatID)
+ if err != nil {
+ return "", err
+ }
+
+ pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(cid), text))
+ if err != nil {
+ return "", err
+ }
+
+ return fmt.Sprintf("%d", pMsg.MessageID), nil
+}
+
+// SendMedia implements the channels.MediaSender interface.
+func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ chatID, err := parseChatID(msg.ChatID)
+ if err != nil {
+ return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("telegram", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ file, err := os.Open(localPath)
+ if err != nil {
+ logger.ErrorCF("telegram", "Failed to open media file", map[string]any{
+ "path": localPath,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ switch part.Type {
+ case "image":
+ params := &telego.SendPhotoParams{
+ ChatID: tu.ID(chatID),
+ Photo: telego.InputFile{File: file},
+ Caption: part.Caption,
+ }
+ _, err = c.bot.SendPhoto(ctx, params)
+ case "audio":
+ params := &telego.SendAudioParams{
+ ChatID: tu.ID(chatID),
+ Audio: telego.InputFile{File: file},
+ Caption: part.Caption,
+ }
+ _, err = c.bot.SendAudio(ctx, params)
+ case "video":
+ params := &telego.SendVideoParams{
+ ChatID: tu.ID(chatID),
+ Video: telego.InputFile{File: file},
+ Caption: part.Caption,
+ }
+ _, err = c.bot.SendVideo(ctx, params)
+ default: // "file" or unknown types
+ params := &telego.SendDocumentParams{
+ ChatID: tu.ID(chatID),
+ Document: telego.InputFile{File: file},
+ Caption: part.Caption,
+ }
+ _, err = c.bot.SendDocument(ctx, params)
+ }
+
+ file.Close()
+
+ if err != nil {
+ logger.ErrorCF("telegram", "Failed to send media", map[string]any{
+ "type": part.Type,
+ "error": err.Error(),
+ })
+ return fmt.Errorf("telegram send media: %w", channels.ErrTemporary)
+ }
+ }
+
+ return nil
+}
+
+func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
+ if message == nil {
+ return fmt.Errorf("message is nil")
+ }
+
+ user := message.From
+ if user == nil {
+ return fmt.Errorf("message sender (user) is nil")
+ }
+
+ platformID := fmt.Sprintf("%d", user.ID)
+ sender := bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: platformID,
+ CanonicalID: identity.BuildCanonicalID("telegram", platformID),
+ Username: user.Username,
+ DisplayName: user.FirstName,
+ }
+
+ // check allowlist to avoid downloading attachments for rejected users
+ if !c.IsAllowedSender(sender) {
+ logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{
+ "user_id": platformID,
+ })
+ return nil
+ }
+
+ chatID := message.Chat.ID
+ c.chatIDs[platformID] = chatID
+
+ content := ""
+ mediaPaths := []string{}
+
+ chatIDStr := fmt.Sprintf("%d", chatID)
+ messageIDStr := fmt.Sprintf("%d", message.MessageID)
+ scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr)
+
+ // Helper to register a local file with the media store
+ storeMedia := func(localPath, filename string) string {
+ if store := c.GetMediaStore(); store != nil {
+ ref, err := store.Store(localPath, media.MediaMeta{
+ Filename: filename,
+ Source: "telegram",
+ }, scope)
+ if err == nil {
+ return ref
+ }
+ }
+ return localPath // fallback: use raw path
+ }
+
+ if message.Text != "" {
+ content += message.Text
+ }
+
+ if message.Caption != "" {
+ if content != "" {
+ content += "\n"
+ }
+ content += message.Caption
+ }
+
+ if len(message.Photo) > 0 {
+ photo := message.Photo[len(message.Photo)-1]
+ photoPath := c.downloadPhoto(ctx, photo.FileID)
+ if photoPath != "" {
+ mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg"))
+ if content != "" {
+ content += "\n"
+ }
+ content += "[image: photo]"
+ }
+ }
+
+ if message.Voice != nil {
+ voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
+ if voicePath != "" {
+ mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg"))
+
+ if content != "" {
+ content += "\n"
+ }
+ content += "[voice]"
+ }
+ }
+
+ if message.Audio != nil {
+ audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
+ if audioPath != "" {
+ mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3"))
+ if content != "" {
+ content += "\n"
+ }
+ content += "[audio]"
+ }
+ }
+
+ if message.Document != nil {
+ docPath := c.downloadFile(ctx, message.Document.FileID, "")
+ if docPath != "" {
+ mediaPaths = append(mediaPaths, storeMedia(docPath, "document"))
+ if content != "" {
+ content += "\n"
+ }
+ content += "[file]"
+ }
+ }
+
+ if content == "" {
+ content = "[empty message]"
+ }
+
+ // In group chats, apply unified group trigger filtering
+ if message.Chat.Type != "private" {
+ isMentioned := c.isBotMentioned(message)
+ if isMentioned {
+ content = c.stripBotMention(content)
+ }
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
+ if !respond {
+ return nil
+ }
+ content = cleaned
+ }
+
+ logger.DebugCF("telegram", "Received message", map[string]any{
+ "sender_id": sender.CanonicalID,
+ "chat_id": fmt.Sprintf("%d", chatID),
+ "preview": utils.Truncate(content, 50),
+ })
+
+ // Placeholder is now auto-triggered by BaseChannel.HandleMessage via PlaceholderCapable
+
+ peerKind := "direct"
+ peerID := fmt.Sprintf("%d", user.ID)
+ if message.Chat.Type != "private" {
+ peerKind = "group"
+ peerID = fmt.Sprintf("%d", chatID)
+ }
+
+ peer := bus.Peer{Kind: peerKind, ID: peerID}
+ messageID := fmt.Sprintf("%d", message.MessageID)
+
+ metadata := map[string]string{
+ "user_id": fmt.Sprintf("%d", user.ID),
+ "username": user.Username,
+ "first_name": user.FirstName,
+ "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
+ }
+
+ c.HandleMessage(c.ctx,
+ peer,
+ messageID,
+ platformID,
+ fmt.Sprintf("%d", chatID),
+ content,
+ mediaPaths,
+ metadata,
+ sender,
+ )
+ return nil
+}
+
+func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
+ file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
+ if err != nil {
+ logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{
+ "error": err.Error(),
+ })
+ return ""
+ }
+
+ return c.downloadFileWithInfo(file, ".jpg")
+}
+
+func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
+ if file.FilePath == "" {
+ return ""
+ }
+
+ url := c.bot.FileDownloadURL(file.FilePath)
+ logger.DebugCF("telegram", "File URL", map[string]any{"url": url})
+
+ // Use FilePath as filename for better identification
+ filename := file.FilePath + ext
+ return utils.DownloadFile(url, filename, utils.DownloadOptions{
+ LoggerPrefix: "telegram",
+ })
+}
+
+func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
+ file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
+ if err != nil {
+ logger.ErrorCF("telegram", "Failed to get file", map[string]any{
+ "error": err.Error(),
+ })
+ return ""
+ }
+
+ return c.downloadFileWithInfo(file, ext)
+}
+
+func parseChatID(chatIDStr string) (int64, error) {
+ var id int64
+ _, err := fmt.Sscanf(chatIDStr, "%d", &id)
+ return id, err
+}
+
+func markdownToTelegramHTML(text string) string {
+ if text == "" {
+ return ""
+ }
+
+ codeBlocks := extractCodeBlocks(text)
+ text = codeBlocks.text
+
+ inlineCodes := extractInlineCodes(text)
+ text = inlineCodes.text
+
+ text = reHeading.ReplaceAllString(text, "$1")
+
+ text = reBlockquote.ReplaceAllString(text, "$1")
+
+ text = escapeHTML(text)
+
+ text = reLink.ReplaceAllString(text, `$1`)
+
+ text = reBoldStar.ReplaceAllString(text, "$1")
+
+ text = reBoldUnder.ReplaceAllString(text, "$1")
+
+ text = reItalic.ReplaceAllStringFunc(text, func(s string) string {
+ match := reItalic.FindStringSubmatch(s)
+ if len(match) < 2 {
+ return s
+ }
+ return "" + match[1] + ""
+ })
+
+ text = reStrike.ReplaceAllString(text, "$1")
+
+ text = reListItem.ReplaceAllString(text, "• ")
+
+ for i, code := range inlineCodes.codes {
+ escaped := escapeHTML(code)
+ text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("%s", escaped))
+ }
+
+ for i, code := range codeBlocks.codes {
+ escaped := escapeHTML(code)
+ text = strings.ReplaceAll(
+ text,
+ fmt.Sprintf("\x00CB%d\x00", i),
+ fmt.Sprintf("%s
", escaped),
+ )
+ }
+
+ return text
+}
+
+type codeBlockMatch struct {
+ text string
+ codes []string
+}
+
+func extractCodeBlocks(text string) codeBlockMatch {
+ matches := reCodeBlock.FindAllStringSubmatch(text, -1)
+
+ codes := make([]string, 0, len(matches))
+ for _, match := range matches {
+ codes = append(codes, match[1])
+ }
+
+ i := 0
+ text = reCodeBlock.ReplaceAllStringFunc(text, func(m string) string {
+ placeholder := fmt.Sprintf("\x00CB%d\x00", i)
+ i++
+ return placeholder
+ })
+
+ return codeBlockMatch{text: text, codes: codes}
+}
+
+type inlineCodeMatch struct {
+ text string
+ codes []string
+}
+
+func extractInlineCodes(text string) inlineCodeMatch {
+ matches := reInlineCode.FindAllStringSubmatch(text, -1)
+
+ codes := make([]string, 0, len(matches))
+ for _, match := range matches {
+ codes = append(codes, match[1])
+ }
+
+ i := 0
+ text = reInlineCode.ReplaceAllStringFunc(text, func(m string) string {
+ placeholder := fmt.Sprintf("\x00IC%d\x00", i)
+ i++
+ return placeholder
+ })
+
+ return inlineCodeMatch{text: text, codes: codes}
+}
+
+func escapeHTML(text string) string {
+ text = strings.ReplaceAll(text, "&", "&")
+ text = strings.ReplaceAll(text, "<", "<")
+ text = strings.ReplaceAll(text, ">", ">")
+ return text
+}
+
+// isBotMentioned checks if the bot is mentioned in the message via entities.
+func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool {
+ botUsername := c.bot.Username()
+ if botUsername == "" {
+ return false
+ }
+
+ entities := message.Entities
+ if entities == nil {
+ entities = message.CaptionEntities
+ }
+
+ for _, entity := range entities {
+ if entity.Type == "mention" {
+ // Extract the mention text from the message
+ text := message.Text
+ if text == "" {
+ text = message.Caption
+ }
+ runes := []rune(text)
+ end := entity.Offset + entity.Length
+ if end <= len(runes) {
+ mention := string(runes[entity.Offset:end])
+ if strings.EqualFold(mention, "@"+botUsername) {
+ return true
+ }
+ }
+ }
+ if entity.Type == "text_mention" && entity.User != nil {
+ if entity.User.Username == botUsername {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// stripBotMention removes the @bot mention from the content.
+func (c *TelegramChannel) stripBotMention(content string) string {
+ botUsername := c.bot.Username()
+ if botUsername == "" {
+ return content
+ }
+ // Case-insensitive replacement
+ re := regexp.MustCompile(`(?i)@` + regexp.QuoteMeta(botUsername))
+ content = re.ReplaceAllString(content, "")
+ return strings.TrimSpace(content)
+}
diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go
similarity index 98%
rename from pkg/channels/telegram_commands.go
rename to pkg/channels/telegram/telegram_commands.go
index f28434f46..496fc5e4f 100644
--- a/pkg/channels/telegram_commands.go
+++ b/pkg/channels/telegram/telegram_commands.go
@@ -1,4 +1,4 @@
-package channels
+package telegram
import (
"context"
@@ -119,7 +119,7 @@ func (c *cmd) List(ctx context.Context, message telego.Message) error {
if provider == "" {
provider = "configured default"
}
- response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
+ response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.json",
c.config.Agents.Defaults.GetModelName(), provider)
case "channels":
diff --git a/pkg/channels/webhook.go b/pkg/channels/webhook.go
new file mode 100644
index 000000000..3cf27baf6
--- /dev/null
+++ b/pkg/channels/webhook.go
@@ -0,0 +1,20 @@
+package channels
+
+import "net/http"
+
+// WebhookHandler is an optional interface for channels that receive messages
+// via HTTP webhooks. Manager discovers channels implementing this interface
+// and registers them on the shared HTTP server.
+type WebhookHandler interface {
+ // WebhookPath returns the path to mount this handler on the shared server.
+ // Examples: "/webhook/line", "/webhook/wecom"
+ WebhookPath() string
+ http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request)
+}
+
+// HealthChecker is an optional interface for channels that expose
+// a health check endpoint on the shared HTTP server.
+type HealthChecker interface {
+ HealthPath() string
+ HealthHandler(w http.ResponseWriter, r *http.Request)
+}
diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go
new file mode 100644
index 000000000..6c5aca40b
--- /dev/null
+++ b/pkg/channels/wecom/aibot.go
@@ -0,0 +1,1014 @@
+package wecom
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math/big"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人)
+type WeComAIBotChannel struct {
+ *channels.BaseChannel
+ config config.WeComAIBotConfig
+ ctx context.Context
+ cancel context.CancelFunc
+ streamTasks map[string]*streamTask // streamID -> task (for poll lookups)
+ chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO)
+ taskMu sync.RWMutex
+}
+
+// streamTask represents a streaming task for AI Bot.
+//
+// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written
+// while holding WeComAIBotChannel.taskMu. Immutable fields (StreamID, ChatID,
+// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set
+// once at creation and never modified, so they are safe to read without a lock.
+type streamTask struct {
+ // immutable after creation
+ StreamID string
+ ChatID string // used by Send() to find this task
+ ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once)
+ Question string
+ CreatedTime time.Time
+ Deadline time.Time // ~30s, we close the stream here and switch to response_url
+ answerCh chan string // receives agent reply from Send()
+ ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine
+ cancel context.CancelFunc // call on task removal to cancel ctx
+
+ // mutable — guarded by WeComAIBotChannel.taskMu
+ StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url
+ StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup
+ Finished bool // fully done
+}
+
+// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot
+// Ref: https://developer.work.weixin.qq.com/document/path/100719
+type WeComAIBotMessage struct {
+ MsgID string `json:"msgid"`
+ AIBotID string `json:"aibotid"`
+ ChatID string `json:"chatid"` // only for group chat
+ ChatType string `json:"chattype"` // "single" or "group"
+ From struct {
+ UserID string `json:"userid"`
+ } `json:"from"`
+ ResponseURL string `json:"response_url"` // temporary URL for proactive reply
+ MsgType string `json:"msgtype"`
+ // text message
+ Text *struct {
+ Content string `json:"content"`
+ } `json:"text,omitempty"`
+ // stream polling refresh
+ Stream *struct {
+ ID string `json:"id"`
+ } `json:"stream,omitempty"`
+ // image message
+ Image *struct {
+ URL string `json:"url"`
+ } `json:"image,omitempty"`
+ // mixed message (text + image)
+ Mixed *struct {
+ MsgItem []struct {
+ MsgType string `json:"msgtype"`
+ Text *struct {
+ Content string `json:"content"`
+ } `json:"text,omitempty"`
+ Image *struct {
+ URL string `json:"url"`
+ } `json:"image,omitempty"`
+ } `json:"msg_item"`
+ } `json:"mixed,omitempty"`
+ // event field
+ Event *struct {
+ EventType string `json:"eventtype"`
+ } `json:"event,omitempty"`
+}
+
+// WeComAIBotMsgItemImage holds the image payload inside a stream message item.
+type WeComAIBotMsgItemImage struct {
+ Base64 string `json:"base64"`
+ MD5 string `json:"md5"`
+}
+
+// WeComAIBotMsgItem is a single item inside a stream's msg_item list.
+type WeComAIBotMsgItem struct {
+ MsgType string `json:"msgtype"`
+ Image *WeComAIBotMsgItemImage `json:"image,omitempty"`
+}
+
+// WeComAIBotStreamInfo represents the detailed stream content in streaming responses.
+type WeComAIBotStreamInfo struct {
+ ID string `json:"id"`
+ Finish bool `json:"finish"`
+ Content string `json:"content,omitempty"`
+ MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"`
+}
+
+// WeComAIBotStreamResponse represents the streaming response format
+type WeComAIBotStreamResponse struct {
+ MsgType string `json:"msgtype"`
+ Stream WeComAIBotStreamInfo `json:"stream"`
+}
+
+// WeComAIBotEncryptedResponse represents the encrypted response wrapper
+// Fields match WXBizJsonMsgCrypt.generate() in Python SDK
+type WeComAIBotEncryptedResponse struct {
+ Encrypt string `json:"encrypt"`
+ MsgSignature string `json:"msgsignature"`
+ Timestamp string `json:"timestamp"`
+ Nonce string `json:"nonce"`
+}
+
+// NewWeComAIBotChannel creates a new WeCom AI Bot channel instance
+func NewWeComAIBotChannel(
+ cfg config.WeComAIBotConfig,
+ messageBus *bus.MessageBus,
+) (*WeComAIBotChannel, error) {
+ if cfg.Token == "" || cfg.EncodingAESKey == "" {
+ return nil, fmt.Errorf("token and encoding_aes_key are required for WeCom AI Bot")
+ }
+
+ base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(2048),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+
+ return &WeComAIBotChannel{
+ BaseChannel: base,
+ config: cfg,
+ streamTasks: make(map[string]*streamTask),
+ chatTasks: make(map[string][]*streamTask),
+ }, nil
+}
+
+// Name returns the channel name
+func (c *WeComAIBotChannel) Name() string {
+ return "wecom_aibot"
+}
+
+// Start initializes the WeCom AI Bot channel
+func (c *WeComAIBotChannel) Start(ctx context.Context) error {
+ logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // Start cleanup goroutine for old tasks
+ go c.cleanupLoop()
+
+ c.SetRunning(true)
+ logger.InfoC("wecom_aibot", "WeCom AI Bot channel started")
+
+ return nil
+}
+
+// Stop gracefully stops the WeCom AI Bot channel
+func (c *WeComAIBotChannel) Stop(ctx context.Context) error {
+ logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...")
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ c.SetRunning(false)
+ logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped")
+ return nil
+}
+
+// Send delivers the agent reply into the active streamTask for msg.ChatID.
+// It writes into the earliest unfinished task in the queue (FIFO per chatID).
+// If the stream has already closed (deadline passed), it posts directly to response_url.
+func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+ c.taskMu.Lock()
+ queue := c.chatTasks[msg.ChatID]
+ // Only compact Finished tasks at the head of the queue.
+ // Tasks that are Finished in the middle are NOT removed here: doing a full
+ // scan on every Send() call would be O(n) and is unnecessary given that
+ // removeTask() always splices the task out of the queue immediately.
+ // Any Finished task left stranded in the middle (e.g. due to an unexpected
+ // code path) will be collected by cleanupOldTasks.
+ for len(queue) > 0 && queue[0].Finished {
+ queue = queue[1:]
+ }
+ c.chatTasks[msg.ChatID] = queue
+ var task *streamTask
+ var streamClosed bool
+ var responseURL string
+ if len(queue) > 0 {
+ task = queue[0]
+ // Read mutable fields while holding c.taskMu to avoid data races.
+ streamClosed = task.StreamClosed
+ responseURL = task.ResponseURL
+ }
+ c.taskMu.Unlock()
+
+ if task == nil {
+ logger.DebugCF(
+ "wecom_aibot",
+ "Send: no active task for chat (may have timed out)",
+ map[string]any{
+ "chat_id": msg.ChatID,
+ },
+ )
+ return nil
+ }
+
+ if streamClosed {
+ // Stream already ended with a "please wait" notice; send the real reply via response_url.
+ // Note: task.StreamID and task.ChatID are immutable, safe to read without a lock.
+ logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{
+ "stream_id": task.StreamID,
+ "chat_id": msg.ChatID,
+ })
+ if responseURL != "" {
+ if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{
+ "error": err,
+ "stream_id": task.StreamID,
+ })
+ c.removeTask(task)
+ return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed)
+ }
+ } else {
+ logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{
+ "stream_id": task.StreamID,
+ })
+ }
+ c.removeTask(task)
+ return nil
+ }
+
+ // Stream still open: deliver via answerCh for the next poll response.
+ select {
+ case task.answerCh <- msg.Content:
+ case <-task.ctx.Done():
+ // Task was canceled (cleanup removed it); silently drop the reply.
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ return nil
+}
+
+// WebhookPath returns the path for registering on the shared HTTP server
+func (c *WeComAIBotChannel) WebhookPath() string {
+ if c.config.WebhookPath == "" {
+ return "/webhook/wecom-aibot"
+ }
+ return c.config.WebhookPath
+}
+
+// ServeHTTP implements http.Handler for the shared HTTP server
+func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ c.handleWebhook(w, r)
+}
+
+// HealthPath returns the health check endpoint path
+func (c *WeComAIBotChannel) HealthPath() string {
+ return c.WebhookPath() + "/health"
+}
+
+// HealthHandler handles health check requests
+func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ c.handleHealth(w, r)
+}
+
+// handleWebhook handles incoming webhook requests from WeCom AI Bot
+func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+
+ // Log all incoming requests for debugging
+ logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{
+ "method": r.Method,
+ "path": r.URL.Path,
+ "query": r.URL.RawQuery,
+ })
+
+ switch r.Method {
+ case http.MethodGet:
+ // URL verification
+ c.handleVerification(ctx, w, r)
+ case http.MethodPost:
+ // Message callback
+ c.handleMessageCallback(ctx, w, r)
+ default:
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+// handleVerification handles the URL verification request from WeCom
+func (c *WeComAIBotChannel) handleVerification(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+) {
+ msgSignature := r.URL.Query().Get("msg_signature")
+ timestamp := r.URL.Query().Get("timestamp")
+ nonce := r.URL.Query().Get("nonce")
+ echostr := r.URL.Query().Get("echostr")
+
+ logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{
+ "msg_signature": msgSignature,
+ "timestamp": timestamp,
+ "nonce": nonce,
+ })
+
+ // Verify signature
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ logger.ErrorC("wecom_aibot", "Signature verification failed")
+ http.Error(w, "Signature verification failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Decrypt echostr
+ // For WeCom AI Bot (智能机器人), receiveid should be empty string
+ decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Remove BOM and whitespace as per WeCom documentation
+ decrypted = strings.TrimPrefix(decrypted, "\ufeff")
+ decrypted = strings.TrimSpace(decrypted)
+
+ logger.InfoC("wecom_aibot", "URL verification successful")
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(decrypted))
+}
+
+// handleMessageCallback handles incoming messages from WeCom AI Bot
+func (c *WeComAIBotChannel) handleMessageCallback(
+ ctx context.Context,
+ w http.ResponseWriter,
+ r *http.Request,
+) {
+ msgSignature := r.URL.Query().Get("msg_signature")
+ timestamp := r.URL.Query().Get("timestamp")
+ nonce := r.URL.Query().Get("nonce")
+
+ // Read request body (limit to 4 MB to prevent memory exhaustion)
+ const maxBodySize = 4 << 20 // 4 MB
+ body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1))
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Failed to read body", http.StatusBadRequest)
+ return
+ }
+ if len(body) > maxBodySize {
+ http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
+ return
+ }
+
+ // Parse JSON body to get encrypted message
+ // Format: {"encrypt": "base64_encrypted_string"}
+ var encryptedMsg struct {
+ Encrypt string `json:"encrypt"`
+ }
+ if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{
+ "error": unmarshalErr,
+ "body": string(body),
+ })
+ http.Error(w, "Failed to parse JSON", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ logger.ErrorC("wecom_aibot", "Signature verification failed")
+ http.Error(w, "Signature verification failed", http.StatusUnauthorized)
+ return
+ }
+
+ // Decrypt message
+ // For WeCom AI Bot (智能机器人), receiveid is empty string
+ decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{
+ "error": err,
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Parse decrypted JSON message
+ var msg WeComAIBotMessage
+ if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{
+ "error": unmarshalErr,
+ "decrypted": decrypted,
+ })
+ http.Error(w, "Failed to parse message", http.StatusInternalServerError)
+ return
+ }
+
+ logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+
+ // Process the message and get streaming response
+ response := c.processMessage(ctx, msg, timestamp, nonce)
+
+ // Check if response is empty (e.g. due to unsupported message type)
+ if response == "" {
+ response = c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ // Return encrypted JSON response
+ w.Header().Set("Content-Type", "application/json; charset=utf-8")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(response))
+}
+
+// processMessage processes the received message and returns encrypted response
+func (c *WeComAIBotChannel) processMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.DebugCF("wecom_aibot", "Processing message", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+
+ switch msg.MsgType {
+ case "text":
+ return c.handleTextMessage(ctx, msg, timestamp, nonce)
+ case "stream":
+ return c.handleStreamMessage(ctx, msg, timestamp, nonce)
+ case "image":
+ return c.handleImageMessage(ctx, msg, timestamp, nonce)
+ case "mixed":
+ return c.handleMixedMessage(ctx, msg, timestamp, nonce)
+ case "event":
+ return c.handleEventMessage(ctx, msg, timestamp, nonce)
+ default:
+ logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{
+ "msgtype": msg.MsgType,
+ })
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: "Unsupported message type: " + msg.MsgType,
+ },
+ })
+ }
+}
+
+// handleTextMessage handles text messages by starting a new streaming task
+func (c *WeComAIBotChannel) handleTextMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ if msg.Text == nil {
+ logger.ErrorC("wecom_aibot", "text message missing text field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ content := msg.Text.Content
+ userID := msg.From.UserID
+ if userID == "" {
+ userID = "unknown"
+ }
+
+ // chatID: group chat uses chatid, single chat uses userid
+ chatID := msg.ChatID
+ if chatID == "" {
+ chatID = userID
+ }
+
+ streamID := c.generateStreamID()
+
+ // WeCom stops sending stream-refresh callbacks after 6 minutes.
+ // Set a slightly shorter deadline so we can send a timeout notice before it gives up.
+ deadline := time.Now().Add(30 * time.Second)
+
+ // Each task gets its own context derived from the channel lifetime context.
+ // Canceling taskCancel interrupts the agent goroutine when the task is removed.
+ taskCtx, taskCancel := context.WithCancel(c.ctx)
+
+ task := &streamTask{
+ StreamID: streamID,
+ ChatID: chatID,
+ ResponseURL: msg.ResponseURL,
+ Question: content,
+ CreatedTime: time.Now(),
+ Deadline: deadline,
+ Finished: false,
+ answerCh: make(chan string, 1),
+ ctx: taskCtx,
+ cancel: taskCancel,
+ }
+
+ c.taskMu.Lock()
+ c.streamTasks[streamID] = task
+ c.chatTasks[chatID] = append(c.chatTasks[chatID], task)
+ c.taskMu.Unlock()
+
+ // Publish to agent asynchronously; agent will call Send() with reply.
+ // Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed.
+ go func() {
+ sender := bus.SenderInfo{
+ Platform: "wecom_aibot",
+ PlatformID: userID,
+ CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID),
+ DisplayName: userID,
+ }
+ peerKind := "direct"
+ if msg.ChatType == "group" {
+ peerKind = "group"
+ }
+ peer := bus.Peer{Kind: peerKind, ID: chatID}
+ metadata := map[string]string{
+ "channel": "wecom_aibot",
+ "chat_type": msg.ChatType,
+ "msg_type": "text",
+ "msgid": msg.MsgID,
+ "aibotid": msg.AIBotID,
+ "stream_id": streamID,
+ "response_url": msg.ResponseURL,
+ }
+ c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID,
+ content, nil, metadata, sender)
+ }()
+
+ // Return first streaming response immediately (finish=false, content empty)
+ return c.getStreamResponse(task, timestamp, nonce)
+}
+
+// handleStreamMessage handles stream polling requests
+func (c *WeComAIBotChannel) handleStreamMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ if msg.Stream == nil {
+ logger.ErrorC("wecom_aibot", "Stream message missing stream field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ streamID := msg.Stream.ID
+
+ c.taskMu.RLock()
+ task, exists := c.streamTasks[streamID]
+ c.taskMu.RUnlock()
+
+ if !exists {
+ logger.DebugCF(
+ "wecom_aibot",
+ "Stream task not found (may be from previous session)",
+ map[string]any{
+ "stream_id": streamID,
+ },
+ )
+ return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: streamID,
+ Finish: true,
+ Content: "Task not found or already finished. Please resend your message to start a new session.",
+ },
+ })
+ }
+
+ // Get next response
+ return c.getStreamResponse(task, timestamp, nonce)
+}
+
+// handleImageMessage handles image messages
+func (c *WeComAIBotChannel) handleImageMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.WarnC("wecom_aibot", "Image message type not yet fully implemented")
+ if msg.Image == nil {
+ logger.ErrorC("wecom_aibot", "Image message missing image field")
+ return c.encryptEmptyResponse(timestamp, nonce)
+ }
+
+ imageURL := msg.Image.URL
+
+ // For now, just acknowledge receipt without echoing the image
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: fmt.Sprintf(
+ "Image received (URL: %s), but image messages are not yet supported",
+ imageURL,
+ ),
+ },
+ })
+}
+
+// handleMixedMessage handles mixed (text + image) messages
+func (c *WeComAIBotChannel) handleMixedMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented")
+ return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: c.generateStreamID(),
+ Finish: true,
+ Content: "Mixed message type is not yet supported",
+ },
+ })
+}
+
+// handleEventMessage handles event messages
+func (c *WeComAIBotChannel) handleEventMessage(
+ ctx context.Context,
+ msg WeComAIBotMessage,
+ timestamp, nonce string,
+) string {
+ eventType := ""
+ if msg.Event != nil {
+ eventType = msg.Event.EventType
+ }
+ logger.DebugCF("wecom_aibot", "Received event", map[string]any{
+ "event_type": eventType,
+ })
+
+ // Send welcome message when user opens the chat window
+ if eventType == "enter_chat" && c.config.WelcomeMessage != "" {
+ streamID := c.generateStreamID()
+ return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: streamID,
+ Finish: true,
+ Content: c.config.WelcomeMessage,
+ },
+ })
+ }
+
+ return c.encryptEmptyResponse(timestamp, nonce)
+}
+
+// getStreamResponse gets the next streaming response for a task.
+// - If agent replied: return finish=true with the real answer.
+// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url.
+// - Otherwise: return finish=false (empty), client will poll again.
+func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string {
+ var content string
+ var finish bool
+ var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending)
+
+ select {
+ case answer := <-task.answerCh:
+ // Agent replied before deadline — normal finish.
+ content = answer
+ finish = true
+ default:
+ if time.Now().After(task.Deadline) {
+ // Deadline reached: close the stream with a notice, then wait for agent via response_url.
+ content = "⏳ Processing, please wait. The results will be sent shortly."
+ finish = true
+ closeStreamOnly = true
+ logger.InfoCF(
+ "wecom_aibot",
+ "Stream deadline reached, switching to response_url mode",
+ map[string]any{
+ "stream_id": task.StreamID,
+ "chat_id": task.ChatID,
+ "response_url": task.ResponseURL != "",
+ },
+ )
+ }
+ // else: still waiting, return finish=false
+ }
+
+ if finish && !closeStreamOnly {
+ // Normal finish: remove from all maps.
+ c.removeTask(task)
+ } else if closeStreamOnly {
+ // Mark stream as closed and remove from streamTasks under a single lock
+ // to keep StreamClosed/StreamClosedAt consistent with map membership.
+ c.taskMu.Lock()
+ task.StreamClosed = true
+ task.StreamClosedAt = time.Now()
+ delete(c.streamTasks, task.StreamID)
+ c.taskMu.Unlock()
+ }
+
+ response := WeComAIBotStreamResponse{
+ MsgType: "stream",
+ Stream: WeComAIBotStreamInfo{
+ ID: task.StreamID,
+ Finish: finish,
+ Content: content,
+ },
+ }
+
+ return c.encryptResponse(task.StreamID, timestamp, nonce, response)
+}
+
+// removeTask removes a task from both streamTasks and chatTasks, marks it finished,
+// and cancels its context to interrupt the associated agent goroutine.
+func (c *WeComAIBotChannel) removeTask(task *streamTask) {
+ // Cancel first so the agent goroutine stops as soon as possible,
+ // before we acquire the write lock.
+ task.cancel()
+
+ c.taskMu.Lock()
+ task.Finished = true // written under c.taskMu, consistent with all readers
+ delete(c.streamTasks, task.StreamID)
+ queue := c.chatTasks[task.ChatID]
+ for i, t := range queue {
+ if t == task {
+ c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
+ break
+ }
+ }
+ if len(c.chatTasks[task.ChatID]) == 0 {
+ delete(c.chatTasks, task.ChatID)
+ }
+ c.taskMu.Unlock()
+}
+
+// sendViaResponseURL posts a markdown reply to the WeCom response_url.
+// response_url is valid for 1 hour and can only be used once per callback.
+// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary,
+// or channels.ErrSendFailed so the manager can apply the right retry policy.
+func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error {
+ payload := map[string]any{
+ "msgtype": "markdown",
+ "markdown": map[string]string{
+ "content": content,
+ },
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal payload: %w", err)
+ }
+
+ ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+
+ client := &http.Client{Timeout: 15 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode == http.StatusOK {
+ return nil
+ }
+
+ respBody, _ := io.ReadAll(resp.Body)
+ switch {
+ case resp.StatusCode == http.StatusTooManyRequests:
+ return fmt.Errorf("response_url rate limited (%d): %s: %w",
+ resp.StatusCode, respBody, channels.ErrRateLimit)
+ case resp.StatusCode >= 500:
+ return fmt.Errorf("response_url server error (%d): %s: %w",
+ resp.StatusCode, respBody, channels.ErrTemporary)
+ default:
+ return fmt.Errorf("response_url returned %d: %s: %w",
+ resp.StatusCode, respBody, channels.ErrSendFailed)
+ }
+}
+
+// encryptResponse encrypts a streaming response
+func (c *WeComAIBotChannel) encryptResponse(
+ streamID, timestamp, nonce string,
+ response WeComAIBotStreamResponse,
+) string {
+ // Marshal response to JSON
+ plaintext, err := json.Marshal(response)
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{
+ "stream_id": streamID,
+ "finish": response.Stream.Finish,
+ "preview": utils.Truncate(response.Stream.Content, 100),
+ })
+
+ // Encrypt message
+ encrypted, err := c.encryptMessage(string(plaintext), "")
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ // Generate signature
+ signature := computeSignature(c.config.Token, timestamp, nonce, encrypted)
+
+ // Build encrypted response
+ encryptedResp := WeComAIBotEncryptedResponse{
+ Encrypt: encrypted,
+ MsgSignature: signature,
+ Timestamp: timestamp,
+ Nonce: nonce,
+ }
+
+ respJSON, err := json.Marshal(encryptedResp)
+ if err != nil {
+ logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{
+ "error": err,
+ })
+ return ""
+ }
+
+ logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{
+ "stream_id": streamID,
+ })
+
+ return string(respJSON)
+}
+
+// encryptEmptyResponse returns a minimal valid encrypted response
+func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string {
+ // Construct a zero-value stream response and encrypt it so that
+ // WeCom always receives a syntactically valid encrypted JSON object.
+ emptyResp := WeComAIBotStreamResponse{}
+ return c.encryptResponse("", timestamp, nonce, emptyResp)
+}
+
+// encryptMessage encrypts a plain text message for WeCom AI Bot
+func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) {
+ aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey)
+ if err != nil {
+ return "", err
+ }
+
+ frame, err := packWeComFrame(plaintext, receiveid)
+ if err != nil {
+ return "", err
+ }
+
+ // PKCS7 padding then AES-CBC encrypt
+ paddedFrame := pkcs7Pad(frame, blockSize)
+ ciphertext, err := encryptAESCBC(aesKey, paddedFrame)
+ if err != nil {
+ return "", err
+ }
+
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+// generateStreamID generates a random stream ID
+func (c *WeComAIBotChannel) generateStreamID() string {
+ const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ b := make([]byte, 10)
+ for i := range b {
+ n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
+ b[i] = letters[n.Int64()]
+ }
+ return string(b)
+}
+
+// cleanupLoop periodically cleans up old streaming tasks
+func (c *WeComAIBotChannel) cleanupLoop() {
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ c.cleanupOldTasks()
+ case <-c.ctx.Done():
+ return
+ }
+ }
+}
+
+// cleanupOldTasks removes tasks that have exceeded their expected lifetime:
+// - Active tasks (in streamTasks): cleaned up after 1 hour (response_url validity window).
+// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod.
+// These tasks are waiting for the agent to call Send() via response_url. If the agent
+// crashes or times out without calling Send(), we must not let them accumulate indefinitely.
+// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour,
+// preventing chatTasks from filling up when many requests time out in quick succession.
+const (
+ streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes
+ taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity)
+)
+
+func (c *WeComAIBotChannel) cleanupOldTasks() {
+ c.taskMu.Lock()
+ defer c.taskMu.Unlock()
+
+ now := time.Now()
+ cutoff := now.Add(-taskMaxLifetime)
+ for id, task := range c.streamTasks {
+ if task.CreatedTime.Before(cutoff) {
+ delete(c.streamTasks, id)
+ task.cancel() // interrupt agent goroutine still waiting for LLM
+ queue := c.chatTasks[task.ChatID]
+ for i, t := range queue {
+ if t == task {
+ c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...)
+ break
+ }
+ }
+ if len(c.chatTasks[task.ChatID]) == 0 {
+ delete(c.chatTasks, task.ChatID)
+ }
+ logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{
+ "stream_id": id,
+ })
+ }
+ }
+ // Clean up StreamClosed tasks from chatTasks.
+ // Two expiry conditions are checked:
+ // 1. Absolute expiry: task was created more than taskMaxLifetime ago.
+ // 2. Grace expiry: stream closed more than streamClosedGracePeriod ago
+ // (agent had enough time to reply; it is not coming back).
+ for chatID, queue := range c.chatTasks {
+ filtered := queue[:0]
+ for i, t := range queue {
+ absoluteExpired := t.CreatedTime.Before(cutoff)
+ graceExpired := t.StreamClosed &&
+ !t.StreamClosedAt.IsZero() &&
+ t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod))
+ if t.Finished {
+ // Finished tasks should have been removed by removeTask().
+ // Finding one here (especially not at position 0) means an
+ // unexpected code path left it stranded, causing the queue to
+ // grow silently. Log a warning so it is visible, then drop it.
+ if i > 0 {
+ logger.WarnCF("wecom_aibot",
+ "Found stranded Finished task in the middle of chatTasks queue; "+
+ "this should not happen — removeTask() should have spliced it out",
+ map[string]any{
+ "chat_id": chatID,
+ "stream_id": t.StreamID,
+ "position": i,
+ })
+ }
+ // The task is already finished; its context was already canceled
+ // by removeTask(), so no further action is required.
+ continue
+ } else if !absoluteExpired && !graceExpired {
+ filtered = append(filtered, t)
+ } else {
+ t.cancel() // cancel any lingering agent goroutine
+ }
+ }
+ if len(filtered) == 0 {
+ delete(c.chatTasks, chatID)
+ } else {
+ c.chatTasks[chatID] = filtered
+ }
+ }
+}
+
+// handleHealth handles health check requests
+func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
+ status := "ok"
+ if !c.IsRunning() {
+ status = "not running"
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": status,
+ })
+}
diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go
new file mode 100644
index 000000000..6f0664187
--- /dev/null
+++ b/pkg/channels/wecom/aibot_test.go
@@ -0,0 +1,210 @@
+package wecom
+
+import (
+ "context"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestNewWeComAIBotChannel(t *testing.T) {
+ t.Run("success with valid config", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ WebhookPath: "/webhook/test",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, err := NewWeComAIBotChannel(cfg, messageBus)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ if ch == nil {
+ t.Fatal("Expected channel to be created")
+ }
+
+ if ch.Name() != "wecom_aibot" {
+ t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name())
+ }
+ })
+
+ t.Run("error with missing token", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ _, err := NewWeComAIBotChannel(cfg, messageBus)
+
+ if err == nil {
+ t.Fatal("Expected error for missing token, got nil")
+ }
+ })
+
+ t.Run("error with missing encoding key", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ }
+
+ messageBus := bus.NewMessageBus()
+ _, err := NewWeComAIBotChannel(cfg, messageBus)
+
+ if err == nil {
+ t.Fatal("Expected error for missing encoding key, got nil")
+ }
+ })
+}
+
+func TestWeComAIBotChannelStartStop(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, err := NewWeComAIBotChannel(cfg, messageBus)
+ if err != nil {
+ t.Fatalf("Failed to create channel: %v", err)
+ }
+
+ ctx := context.Background()
+
+ // Test Start
+ if err := ch.Start(ctx); err != nil {
+ t.Fatalf("Failed to start channel: %v", err)
+ }
+
+ if !ch.IsRunning() {
+ t.Error("Expected channel to be running")
+ }
+
+ // Test Stop
+ if err := ch.Stop(ctx); err != nil {
+ t.Fatalf("Failed to stop channel: %v", err)
+ }
+
+ if ch.IsRunning() {
+ t.Error("Expected channel to be stopped")
+ }
+}
+
+func TestWeComAIBotChannelWebhookPath(t *testing.T) {
+ t.Run("default path", func(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ expectedPath := "/webhook/wecom-aibot"
+ if ch.WebhookPath() != expectedPath {
+ t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath())
+ }
+ })
+
+ t.Run("custom path", func(t *testing.T) {
+ customPath := "/custom/webhook"
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ WebhookPath: customPath,
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ if ch.WebhookPath() != customPath {
+ t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath())
+ }
+ })
+}
+
+func TestGenerateStreamID(t *testing.T) {
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "testkey1234567890123456789012345678901234567",
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ // Generate multiple IDs and check they are unique
+ ids := make(map[string]bool)
+ for i := 0; i < 100; i++ {
+ id := ch.generateStreamID()
+
+ if len(id) != 10 {
+ t.Errorf("Expected stream ID length 10, got %d", len(id))
+ }
+
+ if ids[id] {
+ t.Errorf("Duplicate stream ID generated: %s", id)
+ }
+ ids[id] = true
+ }
+}
+
+func TestEncryptDecrypt(t *testing.T) {
+ // Use a valid 43-character base64 key (企业微信标准格式)
+ cfg := config.WeComAIBotConfig{
+ Enabled: true,
+ Token: "test_token",
+ EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters
+ }
+
+ messageBus := bus.NewMessageBus()
+ ch, _ := NewWeComAIBotChannel(cfg, messageBus)
+
+ plaintext := "Hello, World!"
+ receiveid := ""
+
+ // Encrypt
+ encrypted, err := ch.encryptMessage(plaintext, receiveid)
+ if err != nil {
+ t.Fatalf("Failed to encrypt message: %v", err)
+ }
+
+ if encrypted == "" {
+ t.Fatal("Encrypted message is empty")
+ }
+
+ // Decrypt
+ decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid)
+ if err != nil {
+ t.Fatalf("Failed to decrypt message: %v", err)
+ }
+
+ if decrypted != plaintext {
+ t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted)
+ }
+}
+
+func TestGenerateSignature(t *testing.T) {
+ token := "test_token"
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ encrypt := "encrypted_msg"
+
+ signature := computeSignature(token, timestamp, nonce, encrypt)
+
+ if signature == "" {
+ t.Error("Generated signature is empty")
+ }
+
+ // Verify signature using verifySignature function
+ if !verifySignature(token, signature, timestamp, nonce, encrypt) {
+ t.Error("Generated signature does not verify correctly")
+ }
+}
diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom/app.go
similarity index 65%
rename from pkg/channels/wecom_app.go
rename to pkg/channels/wecom/app.go
index 302603445..717815b9f 100644
--- a/pkg/channels/wecom_app.go
+++ b/pkg/channels/wecom/app.go
@@ -1,8 +1,4 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// WeCom App (企业微信自建应用) channel implementation
-// Supports receiving messages via webhook callback and sending messages proactively
-
-package channels
+package wecom
import (
"bytes"
@@ -11,14 +7,19 @@ import (
"encoding/xml"
"fmt"
"io"
+ "mime/multipart"
"net/http"
"net/url"
+ "os"
+ "path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -29,16 +30,15 @@ const (
// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
type WeComAppChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.WeComAppConfig
- server *http.Server
+ client *http.Client
accessToken string
tokenExpiry time.Time
tokenMu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
- processedMsgs map[string]bool // Message deduplication: msg_id -> processed
- msgMu sync.RWMutex
+ processedMsgs *MessageDeduplicator
}
// WeComXMLMessage represents the XML message structure from WeCom
@@ -123,12 +123,27 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
}
- base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(2048),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+ // Client timeout must be >= the configured ReplyTimeout so the
+ // per-request context deadline is always the effective limit.
+ clientTimeout := 30 * time.Second
+ if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
+ clientTimeout = d
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
return &WeComAppChannel{
BaseChannel: base,
config: cfg,
- processedMsgs: make(map[string]bool),
+ client: &http.Client{Timeout: clientTimeout},
+ ctx: ctx,
+ cancel: cancel,
+ processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -137,10 +152,14 @@ func (c *WeComAppChannel) Name() string {
return "wecom_app"
}
-// Start initializes the WeCom App channel with HTTP webhook server
+// Start initializes the WeCom App channel
func (c *WeComAppChannel) Start(ctx context.Context) error {
logger.InfoC("wecom_app", "Starting WeCom App channel...")
+ // Cancel the context created in the constructor to avoid a resource leak.
+ if c.cancel != nil {
+ c.cancel()
+ }
c.ctx, c.cancel = context.WithCancel(ctx)
// Get initial access token
@@ -153,37 +172,8 @@ func (c *WeComAppChannel) Start(ctx context.Context) error {
// Start token refresh goroutine
go c.tokenRefreshLoop()
- // Setup HTTP server for webhook
- mux := http.NewServeMux()
- webhookPath := c.config.WebhookPath
- if webhookPath == "" {
- webhookPath = "/webhook/wecom-app"
- }
- mux.HandleFunc(webhookPath, c.handleWebhook)
-
- // Health check endpoint
- mux.HandleFunc("/health/wecom-app", c.handleHealth)
-
- addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
- c.server = &http.Server{
- Addr: addr,
- Handler: mux,
- }
-
- c.setRunning(true)
- logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{
- "address": addr,
- "path": webhookPath,
- })
-
- // Start server in goroutine
- go func() {
- if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{
- "error": err.Error(),
- })
- }
- }()
+ c.SetRunning(true)
+ logger.InfoC("wecom_app", "WeCom App channel started")
return nil
}
@@ -196,13 +186,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error {
c.cancel()
}
- if c.server != nil {
- shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
- c.server.Shutdown(shutdownCtx)
- }
-
- c.setRunning(false)
+ c.SetRunning(false)
logger.InfoC("wecom_app", "WeCom App channel stopped")
return nil
}
@@ -210,7 +194,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error {
// Send sends a message to WeCom user proactively using access token
func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("wecom_app channel not running")
+ return channels.ErrNotRunning
}
accessToken := c.getAccessToken()
@@ -226,6 +210,222 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
}
+// SendMedia implements the channels.MediaSender interface.
+func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ accessToken := c.getAccessToken()
+ if accessToken == "" {
+ return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary)
+ }
+
+ store := c.GetMediaStore()
+ if store == nil {
+ return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
+ }
+
+ for _, part := range msg.Parts {
+ localPath, err := store.Resolve(part.Ref)
+ if err != nil {
+ logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{
+ "ref": part.Ref,
+ "error": err.Error(),
+ })
+ continue
+ }
+
+ // Map part type to WeCom media type
+ var mediaType string
+ switch part.Type {
+ case "image":
+ mediaType = "image"
+ case "audio":
+ mediaType = "voice"
+ case "video":
+ mediaType = "video"
+ default:
+ mediaType = "file"
+ }
+
+ // Upload media to get media_id
+ mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath)
+ if err != nil {
+ logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{
+ "type": mediaType,
+ "error": err.Error(),
+ })
+ // Fallback: send caption as text
+ if part.Caption != "" {
+ _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption)
+ }
+ continue
+ }
+
+ // Send media message using the media_id
+ if mediaType == "image" {
+ err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID)
+ } else {
+ // For non-image types, send as text fallback with caption
+ caption := part.Caption
+ if caption == "" {
+ caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename)
+ }
+ err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption)
+ }
+
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// uploadMedia uploads a local file to WeCom temporary media storage.
+func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) {
+ apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s",
+ wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType))
+
+ file, err := os.Open(localPath)
+ if err != nil {
+ return "", fmt.Errorf("failed to open file: %w", err)
+ }
+ defer file.Close()
+
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+
+ filename := filepath.Base(localPath)
+ formFile, err := writer.CreateFormFile("media", filename)
+ if err != nil {
+ return "", fmt.Errorf("failed to create form file: %w", err)
+ }
+
+ if _, err = io.Copy(formFile, file); err != nil {
+ return "", fmt.Errorf("failed to copy file content: %w", err)
+ }
+ writer.Close()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body)
+ if err != nil {
+ return "", fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+
+ resp, err := c.client.Do(req)
+ if err != nil {
+ return "", channels.ClassifyNetError(err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ respBody, _ := io.ReadAll(resp.Body)
+ return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody)))
+ }
+
+ var result struct {
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ MediaID string `json:"media_id"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return "", fmt.Errorf("failed to parse upload response: %w", err)
+ }
+
+ if result.ErrCode != 0 {
+ return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
+ }
+
+ return result.MediaID, nil
+}
+
+// sendWeComMessage marshals payload and POSTs it to the WeCom message API.
+func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error {
+ apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
+
+ jsonData, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal message: %w", err)
+ }
+
+ timeout := c.config.ReplyTimeout
+ if timeout <= 0 {
+ timeout = 5
+ }
+
+ reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.client.Do(req)
+ if err != nil {
+ return channels.ClassifyNetError(err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ respBody, _ := io.ReadAll(resp.Body)
+ return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody)))
+ }
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read response: %w", err)
+ }
+
+ var sendResp WeComSendMessageResponse
+ if err := json.Unmarshal(respBody, &sendResp); err != nil {
+ return fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ if sendResp.ErrCode != 0 {
+ return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
+ }
+
+ return nil
+}
+
+// sendImageMessage sends an image message using a media_id.
+func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error {
+ msg := WeComImageMessage{
+ ToUser: userID,
+ MsgType: "image",
+ AgentID: c.config.AgentID,
+ }
+ msg.Image.MediaID = mediaID
+ return c.sendWeComMessage(ctx, accessToken, msg)
+}
+
+// WebhookPath returns the path for registering on the shared HTTP server.
+func (c *WeComAppChannel) WebhookPath() string {
+ if c.config.WebhookPath != "" {
+ return c.config.WebhookPath
+ }
+ return "/webhook/wecom-app"
+}
+
+// ServeHTTP implements http.Handler for the shared HTTP server.
+func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ c.handleWebhook(w, r)
+}
+
+// HealthPath returns the health check endpoint path.
+func (c *WeComAppChannel) HealthPath() string {
+ return "/health/wecom-app"
+}
+
+// HealthHandler handles health check requests.
+func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ c.handleHealth(w, r)
+}
+
// handleWebhook handles incoming webhook requests from WeCom
func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -279,7 +479,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons
}
// Verify signature
- if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
"token": c.config.Token,
"msg_signature": msgSignature,
@@ -298,7 +498,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons
"encoding_aes_key": c.config.EncodingAESKey,
"corp_id": c.config.CorpID,
})
- decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID)
+ decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
"error": err.Error(),
@@ -357,7 +557,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp
}
// Verify signature
- if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.WarnC("wecom_app", "Message signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
@@ -365,7 +565,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp
// Decrypt message with CorpID verification
// For WeCom App (自建应用), receiveid should be corp_id
- decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID)
+ decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID)
if err != nil {
logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
"error": err.Error(),
@@ -384,8 +584,9 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp
return
}
- // Process the message with context
- go c.processMessage(ctx, msg)
+ // Process the message with the channel's long-lived context (not the HTTP
+ // request context, which is canceled as soon as we return the response).
+ go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom App requires response within configured timeout (default 5 seconds)
@@ -405,29 +606,21 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
// Message deduplication: Use msg_id to prevent duplicate processing
// As per WeCom documentation, use msg_id for deduplication
msgID := fmt.Sprintf("%d", msg.MsgId)
- c.msgMu.Lock()
- if c.processedMsgs[msgID] {
- c.msgMu.Unlock()
+ if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
- c.processedMsgs[msgID] = true
- c.msgMu.Unlock()
-
- // Clean up old messages periodically (keep last 1000)
- if len(c.processedMsgs) > 1000 {
- c.msgMu.Lock()
- c.processedMsgs = make(map[string]bool)
- c.msgMu.Unlock()
- }
senderID := msg.FromUserName
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
// Build metadata
// WeCom App only supports direct messages (private chat)
+ peer := bus.Peer{Kind: "direct", ID: senderID}
+ messageID := fmt.Sprintf("%d", msg.MsgId)
+
metadata := map[string]string{
"msg_type": msg.MsgType,
"msg_id": fmt.Sprintf("%d", msg.MsgId),
@@ -435,8 +628,6 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
"platform": "wecom_app",
"media_id": msg.MediaId,
"create_time": fmt.Sprintf("%d", msg.CreateTime),
- "peer_kind": "direct",
- "peer_id": senderID,
}
content := msg.Content
@@ -447,8 +638,15 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
"preview": utils.Truncate(content, 50),
})
+ // Build sender info
+ appSender := bus.SenderInfo{
+ Platform: "wecom",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("wecom", senderID),
+ }
+
// Handle the message through the base channel
- c.HandleMessage(senderID, chatID, content, nil, metadata)
+ c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender)
}
// tokenRefreshLoop periodically refreshes the access token
@@ -516,59 +714,15 @@ func (c *WeComAppChannel) getAccessToken() string {
return c.accessToken
}
-// sendTextMessage sends a text message to a user
+// sendTextMessage sends a text message to a user.
func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
- apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
-
msg := WeComTextMessage{
ToUser: userID,
MsgType: "text",
AgentID: c.config.AgentID,
}
msg.Text.Content = content
-
- jsonData, err := json.Marshal(msg)
- if err != nil {
- return fmt.Errorf("failed to marshal message: %w", err)
- }
-
- // Use configurable timeout (default 5 seconds)
- timeout := c.config.ReplyTimeout
- if timeout <= 0 {
- timeout = 5
- }
-
- reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
- if err != nil {
- return fmt.Errorf("failed to create request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
- resp, err := client.Do(req)
- if err != nil {
- return fmt.Errorf("failed to send message: %w", err)
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("failed to read response: %w", err)
- }
-
- var sendResp WeComSendMessageResponse
- if err := json.Unmarshal(body, &sendResp); err != nil {
- return fmt.Errorf("failed to parse response: %w", err)
- }
-
- if sendResp.ErrCode != 0 {
- return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
- }
-
- return nil
+ return c.sendWeComMessage(ctx, accessToken, msg)
}
// handleHealth handles health check requests
diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom/app_test.go
similarity index 92%
rename from pkg/channels/wecom_app_test.go
rename to pkg/channels/wecom/app_test.go
index abf15c52b..7f230494f 100644
--- a/pkg/channels/wecom_app_test.go
+++ b/pkg/channels/wecom/app_test.go
@@ -1,7 +1,4 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// WeCom App (企业微信自建应用) channel tests
-
-package channels
+package wecom
import (
"bytes"
@@ -46,7 +43,7 @@ func encryptTestMessageApp(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + corp_id
random := make([]byte, 0, 16)
- for i := 0; i < 16; i++ {
+ for i := range 16 {
random = append(random, byte(i+1))
}
@@ -197,7 +194,7 @@ func TestWeComAppVerifySignature(t *testing.T) {
msgEncrypt := "test_message"
expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt)
- if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
+ if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
t.Error("valid signature should pass verification")
}
})
@@ -207,7 +204,7 @@ func TestWeComAppVerifySignature(t *testing.T) {
nonce := "test_nonce"
msgEncrypt := "test_message"
- if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
+ if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
t.Error("invalid signature should fail verification")
}
})
@@ -221,7 +218,7 @@ func TestWeComAppVerifySignature(t *testing.T) {
}
chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus)
- if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
+ if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should skip verification and return true")
}
})
@@ -243,7 +240,7 @@ func TestWeComAppDecryptMessage(t *testing.T) {
plainText := "hello world"
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
- result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey)
+ result, err := decryptMessage(encoded, ch.config.EncodingAESKey)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -268,7 +265,7 @@ func TestWeComAppDecryptMessage(t *testing.T) {
t.Fatalf("failed to encrypt test message: %v", err)
}
- result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey)
+ result, err := decryptMessage(encrypted, ch.config.EncodingAESKey)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -286,7 +283,7 @@ func TestWeComAppDecryptMessage(t *testing.T) {
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
- _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
+ _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
@@ -301,7 +298,7 @@ func TestWeComAppDecryptMessage(t *testing.T) {
}
ch, _ := NewWeComAppChannel(cfg, msgBus)
- _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
+ _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
if err == nil {
t.Error("expected error for invalid AES key, got nil")
}
@@ -319,67 +316,13 @@ func TestWeComAppDecryptMessage(t *testing.T) {
// Encrypt a very short message that results in ciphertext less than block size
shortData := make([]byte, 8)
- _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey)
+ _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey)
if err == nil {
t.Error("expected error for short ciphertext, got nil")
}
})
}
-func TestWeComAppPKCS7Unpad(t *testing.T) {
- tests := []struct {
- name string
- input []byte
- expected []byte
- }{
- {
- name: "empty input",
- input: []byte{},
- expected: []byte{},
- },
- {
- name: "valid padding 3 bytes",
- input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
- expected: []byte("hello"),
- },
- {
- name: "valid padding 16 bytes (full block)",
- input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
- expected: []byte("123456789012345"),
- },
- {
- name: "invalid padding larger than data",
- input: []byte{20},
- expected: nil, // should return error
- },
- {
- name: "invalid padding zero",
- input: append([]byte("test"), byte(0)),
- expected: nil, // should return error
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := pkcs7UnpadWeCom(tt.input)
- if tt.expected == nil {
- // This case should return an error
- if err == nil {
- t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
- }
- return
- }
- if err != nil {
- t.Errorf("pkcs7Unpad() unexpected error: %v", err)
- return
- }
- if !bytes.Equal(result, tt.expected) {
- t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
- }
- })
- }
-}
-
func TestWeComAppHandleVerification(t *testing.T) {
msgBus := bus.NewMessageBus()
aesKey := generateTestAESKeyApp()
@@ -852,6 +795,28 @@ func TestWeComAppMessageStructures(t *testing.T) {
}
})
+ t.Run("WeComImageMessage structure", func(t *testing.T) {
+ msg := WeComImageMessage{
+ ToUser: "user123",
+ MsgType: "image",
+ AgentID: 1000002,
+ }
+ msg.Image.MediaID = "media_123456"
+
+ if msg.Image.MediaID != "media_123456" {
+ t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456")
+ }
+ if msg.ToUser != "user123" {
+ t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123")
+ }
+ if msg.MsgType != "image" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "image")
+ }
+ if msg.AgentID != 1000002 {
+ t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002)
+ }
+ })
+
t.Run("WeComAccessTokenResponse structure", func(t *testing.T) {
jsonData := `{
"errcode": 0,
diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom/bot.go
similarity index 62%
rename from pkg/channels/wecom.go
rename to pkg/channels/wecom/bot.go
index f8daf89de..9126a847d 100644
--- a/pkg/channels/wecom.go
+++ b/pkg/channels/wecom/bot.go
@@ -1,29 +1,20 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// WeCom Bot (企业微信智能机器人) channel implementation
-// Uses webhook callback mode for receiving messages and webhook API for sending replies
-
-package channels
+package wecom
import (
"bytes"
"context"
- "crypto/aes"
- "crypto/cipher"
- "crypto/sha1"
- "encoding/base64"
- "encoding/binary"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
- "sort"
"strings"
- "sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -31,13 +22,12 @@ import (
// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
type WeComBotChannel struct {
- *BaseChannel
+ *channels.BaseChannel
config config.WeComConfig
- server *http.Server
+ client *http.Client
ctx context.Context
cancel context.CancelFunc
- processedMsgs map[string]bool // Message deduplication: msg_id -> processed
- msgMu sync.RWMutex
+ processedMsgs *MessageDeduplicator
}
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
@@ -96,12 +86,27 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
return nil, fmt.Errorf("wecom token and webhook_url are required")
}
- base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom)
+ base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom,
+ channels.WithMaxMessageLength(2048),
+ channels.WithGroupTrigger(cfg.GroupTrigger),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
+ // Client timeout must be >= the configured ReplyTimeout so the
+ // per-request context deadline is always the effective limit.
+ clientTimeout := 30 * time.Second
+ if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout {
+ clientTimeout = d
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
return &WeComBotChannel{
BaseChannel: base,
config: cfg,
- processedMsgs: make(map[string]bool),
+ client: &http.Client{Timeout: clientTimeout},
+ ctx: ctx,
+ cancel: cancel,
+ processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
}, nil
}
@@ -110,43 +115,18 @@ func (c *WeComBotChannel) Name() string {
return "wecom"
}
-// Start initializes the WeCom Bot channel with HTTP webhook server
+// Start initializes the WeCom Bot channel
func (c *WeComBotChannel) Start(ctx context.Context) error {
logger.InfoC("wecom", "Starting WeCom Bot channel...")
+ // Cancel the context created in the constructor to avoid a resource leak.
+ if c.cancel != nil {
+ c.cancel()
+ }
c.ctx, c.cancel = context.WithCancel(ctx)
- // Setup HTTP server for webhook
- mux := http.NewServeMux()
- webhookPath := c.config.WebhookPath
- if webhookPath == "" {
- webhookPath = "/webhook/wecom"
- }
- mux.HandleFunc(webhookPath, c.handleWebhook)
-
- // Health check endpoint
- mux.HandleFunc("/health/wecom", c.handleHealth)
-
- addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
- c.server = &http.Server{
- Addr: addr,
- Handler: mux,
- }
-
- c.setRunning(true)
- logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{
- "address": addr,
- "path": webhookPath,
- })
-
- // Start server in goroutine
- go func() {
- if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.ErrorCF("wecom", "HTTP server error", map[string]any{
- "error": err.Error(),
- })
- }
- }()
+ c.SetRunning(true)
+ logger.InfoC("wecom", "WeCom Bot channel started")
return nil
}
@@ -159,13 +139,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error {
c.cancel()
}
- if c.server != nil {
- shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
- defer cancel()
- c.server.Shutdown(shutdownCtx)
- }
-
- c.setRunning(false)
+ c.SetRunning(false)
logger.InfoC("wecom", "WeCom Bot channel stopped")
return nil
}
@@ -175,7 +149,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error {
// For delayed responses, we use the webhook URL
func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
- return fmt.Errorf("wecom channel not running")
+ return channels.ErrNotRunning
}
logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
@@ -186,6 +160,29 @@ func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
}
+// WebhookPath returns the path for registering on the shared HTTP server.
+func (c *WeComBotChannel) WebhookPath() string {
+ if c.config.WebhookPath != "" {
+ return c.config.WebhookPath
+ }
+ return "/webhook/wecom"
+}
+
+// ServeHTTP implements http.Handler for the shared HTTP server.
+func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ c.handleWebhook(w, r)
+}
+
+// HealthPath returns the health check endpoint path.
+func (c *WeComBotChannel) HealthPath() string {
+ return "/health/wecom"
+}
+
+// HealthHandler handles health check requests.
+func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) {
+ c.handleHealth(w, r)
+}
+
// handleWebhook handles incoming webhook requests from WeCom
func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -219,7 +216,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons
}
// Verify signature
- if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
logger.WarnC("wecom", "Signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
@@ -228,7 +225,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons
// Decrypt echostr
// For AIBOT (智能机器人), receiveid should be empty string ""
// Reference: https://developer.work.weixin.qq.com/document/path/101033
- decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
+ decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
if err != nil {
logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
"error": err.Error(),
@@ -281,7 +278,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp
}
// Verify signature
- if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
logger.WarnC("wecom", "Message signature verification failed")
http.Error(w, "Invalid signature", http.StatusForbidden)
return
@@ -290,7 +287,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp
// Decrypt message
// For AIBOT (智能机器人), receiveid should be empty string ""
// Reference: https://developer.work.weixin.qq.com/document/path/101033
- decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
+ decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
if err != nil {
logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
"error": err.Error(),
@@ -309,8 +306,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp
return
}
- // Process the message asynchronously with context
- go c.processMessage(ctx, msg)
+ // Process the message with the channel's long-lived context (not the HTTP
+ // request context, which is canceled as soon as we return the response).
+ go c.processMessage(c.ctx, msg)
// Return success response immediately
// WeCom Bot requires response within configured timeout (default 5 seconds)
@@ -330,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
// Message deduplication: Use msg_id to prevent duplicate processing
msgID := msg.MsgID
- c.msgMu.Lock()
- if c.processedMsgs[msgID] {
- c.msgMu.Unlock()
+ if !c.processedMsgs.MarkMessageProcessed(msgID) {
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
"msg_id": msgID,
})
return
}
- c.processedMsgs[msgID] = true
- c.msgMu.Unlock()
-
- // Clean up old messages periodically (keep last 1000)
- if len(c.processedMsgs) > 1000 {
- c.msgMu.Lock()
- c.processedMsgs = make(map[string]bool)
- c.msgMu.Unlock()
- }
senderID := msg.From.UserID
@@ -387,12 +374,21 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
}
// Build metadata
+ peer := bus.Peer{Kind: peerKind, ID: peerID}
+
+ // In group chats, apply unified group trigger filtering
+ if isGroupChat {
+ respond, cleaned := c.ShouldRespondInGroup(false, content)
+ if !respond {
+ return
+ }
+ content = cleaned
+ }
+
metadata := map[string]string{
"msg_type": msg.MsgType,
"msg_id": msg.MsgID,
"platform": "wecom",
- "peer_kind": peerKind,
- "peer_id": peerID,
"response_url": msg.ResponseURL,
}
if isGroupChat {
@@ -408,8 +404,19 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
"preview": utils.Truncate(content, 50),
})
+ // Build sender info
+ sender := bus.SenderInfo{
+ Platform: "wecom",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("wecom", senderID),
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
// Handle the message through the base channel
- c.HandleMessage(senderID, chatID, content, nil, metadata)
+ c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender)
}
// sendWebhookReply sends a reply using the webhook URL
@@ -439,13 +446,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content
}
req.Header.Set("Content-Type", "application/json")
- client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
- resp, err := client.Do(req)
+ resp, err := c.client.Do(req)
if err != nil {
- return fmt.Errorf("failed to send webhook reply: %w", err)
+ return channels.ClassifyNetError(err)
}
defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body)))
+ }
+
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
@@ -477,129 +488,3 @@ func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
-
-// WeCom common utilities for both WeCom Bot and WeCom App
-// The following functions were moved from wecom_common.go
-
-// WeComVerifySignature verifies the message signature for WeCom
-// This is a common function used by both WeCom Bot and WeCom App
-func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
- if token == "" {
- return true // Skip verification if token is not set
- }
-
- // Sort parameters
- params := []string{token, timestamp, nonce, msgEncrypt}
- sort.Strings(params)
-
- // Concatenate
- str := strings.Join(params, "")
-
- // SHA1 hash
- hash := sha1.Sum([]byte(str))
- expectedSignature := fmt.Sprintf("%x", hash)
-
- return expectedSignature == msgSignature
-}
-
-// WeComDecryptMessage decrypts the encrypted message using AES
-// This is a common function used by both WeCom Bot and WeCom App
-// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
-func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
- return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
-}
-
-// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
-// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
-func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
- if encodingAESKey == "" {
- // No encryption, return as is (base64 decode)
- decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
- if err != nil {
- return "", err
- }
- return string(decoded), nil
- }
-
- // Decode AES key (base64)
- aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
- if err != nil {
- return "", fmt.Errorf("failed to decode AES key: %w", err)
- }
-
- // Decode encrypted message
- cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
- if err != nil {
- return "", fmt.Errorf("failed to decode message: %w", err)
- }
-
- // AES decrypt
- block, err := aes.NewCipher(aesKey)
- if err != nil {
- return "", fmt.Errorf("failed to create cipher: %w", err)
- }
-
- if len(cipherText) < aes.BlockSize {
- return "", fmt.Errorf("ciphertext too short")
- }
-
- // IV is the first 16 bytes of AESKey
- iv := aesKey[:aes.BlockSize]
- mode := cipher.NewCBCDecrypter(block, iv)
- plainText := make([]byte, len(cipherText))
- mode.CryptBlocks(plainText, cipherText)
-
- // Remove PKCS7 padding
- plainText, err = pkcs7UnpadWeCom(plainText)
- if err != nil {
- return "", fmt.Errorf("failed to unpad: %w", err)
- }
-
- // Parse message structure
- // Format: random(16) + msg_len(4) + msg + receiveid
- if len(plainText) < 20 {
- return "", fmt.Errorf("decrypted message too short")
- }
-
- msgLen := binary.BigEndian.Uint32(plainText[16:20])
- if int(msgLen) > len(plainText)-20 {
- return "", fmt.Errorf("invalid message length")
- }
-
- msg := plainText[20 : 20+msgLen]
-
- // Verify receiveid if provided
- if receiveid != "" && len(plainText) > 20+int(msgLen) {
- actualReceiveID := string(plainText[20+msgLen:])
- if actualReceiveID != receiveid {
- return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
- }
- }
-
- return string(msg), nil
-}
-
-// pkcs7UnpadWeCom removes PKCS7 padding with validation
-// WeCom uses block size of 32 (not standard AES block size of 16)
-const wecomBlockSize = 32
-
-func pkcs7UnpadWeCom(data []byte) ([]byte, error) {
- if len(data) == 0 {
- return data, nil
- }
- padding := int(data[len(data)-1])
- // WeCom uses 32-byte block size for PKCS7 padding
- if padding == 0 || padding > wecomBlockSize {
- return nil, fmt.Errorf("invalid padding size: %d", padding)
- }
- if padding > len(data) {
- return nil, fmt.Errorf("padding size larger than data")
- }
- // Verify all padding bytes
- for i := 0; i < padding; i++ {
- if data[len(data)-1-i] != byte(padding) {
- return nil, fmt.Errorf("invalid padding byte at position %d", i)
- }
- }
- return data[:len(data)-padding], nil
-}
diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom/bot_test.go
similarity index 91%
rename from pkg/channels/wecom_test.go
rename to pkg/channels/wecom/bot_test.go
index 8afa7e8c3..c053578b1 100644
--- a/pkg/channels/wecom_test.go
+++ b/pkg/channels/wecom/bot_test.go
@@ -1,7 +1,4 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// WeCom Bot (企业微信智能机器人) channel tests
-
-package channels
+package wecom
import (
"bytes"
@@ -45,7 +42,7 @@ func encryptTestMessage(message, aesKey string) (string, error) {
// Prepare message: random(16) + msg_len(4) + msg + receiveid
random := make([]byte, 0, 16)
- for i := 0; i < 16; i++ {
+ for i := range 16 {
random = append(random, byte(i))
}
@@ -177,7 +174,7 @@ func TestWeComBotVerifySignature(t *testing.T) {
msgEncrypt := "test_message"
expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
- if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
+ if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
t.Error("valid signature should pass verification")
}
})
@@ -187,7 +184,7 @@ func TestWeComBotVerifySignature(t *testing.T) {
nonce := "test_nonce"
msgEncrypt := "test_message"
- if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
+ if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
t.Error("invalid signature should fail verification")
}
})
@@ -202,7 +199,7 @@ func TestWeComBotVerifySignature(t *testing.T) {
config: cfgEmpty,
}
- if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
+ if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
t.Error("empty token should skip verification and return true")
}
})
@@ -223,7 +220,7 @@ func TestWeComBotDecryptMessage(t *testing.T) {
plainText := "hello world"
encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
- result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey)
+ result, err := decryptMessage(encoded, ch.config.EncodingAESKey)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -247,7 +244,7 @@ func TestWeComBotDecryptMessage(t *testing.T) {
t.Fatalf("failed to encrypt test message: %v", err)
}
- result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey)
+ result, err := decryptMessage(encrypted, ch.config.EncodingAESKey)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -264,7 +261,7 @@ func TestWeComBotDecryptMessage(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
- _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
+ _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
if err == nil {
t.Error("expected error for invalid base64, got nil")
}
@@ -278,7 +275,7 @@ func TestWeComBotDecryptMessage(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
- _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
+ _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
if err == nil {
t.Error("expected error for invalid AES key, got nil")
}
@@ -320,20 +317,20 @@ func TestWeComBotPKCS7Unpad(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- result, err := pkcs7UnpadWeCom(tt.input)
+ result, err := pkcs7Unpad(tt.input)
if tt.expected == nil {
// This case should return an error
if err == nil {
- t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result)
+ t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
}
return
}
if err != nil {
- t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err)
+ t.Errorf("pkcs7Unpad() unexpected error: %v", err)
return
}
if !bytes.Equal(result, tt.expected) {
- t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected)
+ t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
}
})
}
@@ -415,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
}
ch, _ := NewWeComBotChannel(cfg, msgBus)
- t.Run("valid direct message callback", func(t *testing.T) {
- // Create JSON message for direct chat (single)
- jsonMsg := `{
- "msgid": "test_msg_id_123",
- "aibotid": "test_aibot_id",
- "chattype": "single",
- "from": {"userid": "user123"},
- "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
- "msgtype": "text",
- "text": {"content": "Hello World"}
- }`
-
- // Encrypt message
+ runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder {
+ t.Helper()
encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
-
- // Create encrypted XML wrapper
encryptedWrapper := struct {
XMLName xml.Name `xml:"xml"`
Encrypt string `xml:"Encrypt"`
@@ -438,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
Encrypt: encrypted,
}
wrapperData, _ := xml.Marshal(encryptedWrapper)
-
timestamp := "1234567890"
nonce := "test_nonce"
signature := generateSignature("test_token", timestamp, nonce, encrypted)
-
req := httptest.NewRequest(
http.MethodPost,
"/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
bytes.NewReader(wrapperData),
)
w := httptest.NewRecorder()
-
ch.handleMessageCallback(context.Background(), w, req)
+ return w
+ }
+ t.Run("valid direct message callback", func(t *testing.T) {
+ w := runBotMessageCallback(t, `{
+ "msgid": "test_msg_id_123",
+ "aibotid": "test_aibot_id",
+ "chattype": "single",
+ "from": {"userid": "user123"},
+ "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ "msgtype": "text",
+ "text": {"content": "Hello World"}
+ }`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
@@ -461,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
})
t.Run("valid group message callback", func(t *testing.T) {
- // Create JSON message for group chat
- jsonMsg := `{
+ w := runBotMessageCallback(t, `{
"msgid": "test_msg_id_456",
"aibotid": "test_aibot_id",
"chatid": "group_chat_id_123",
@@ -471,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) {
"response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
"msgtype": "text",
"text": {"content": "Hello Group"}
- }`
-
- // Encrypt message
- encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
-
- // Create encrypted XML wrapper
- encryptedWrapper := struct {
- XMLName xml.Name `xml:"xml"`
- Encrypt string `xml:"Encrypt"`
- }{
- Encrypt: encrypted,
- }
- wrapperData, _ := xml.Marshal(encryptedWrapper)
-
- timestamp := "1234567890"
- nonce := "test_nonce"
- signature := generateSignature("test_token", timestamp, nonce, encrypted)
-
- req := httptest.NewRequest(
- http.MethodPost,
- "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
- bytes.NewReader(wrapperData),
- )
- w := httptest.NewRecorder()
-
- ch.handleMessageCallback(context.Background(), w, req)
-
+ }`)
if w.Code != http.StatusOK {
t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
}
diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go
new file mode 100644
index 000000000..6510e6f81
--- /dev/null
+++ b/pkg/channels/wecom/common.go
@@ -0,0 +1,199 @@
+package wecom
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "fmt"
+ "math/big"
+ "sort"
+ "strings"
+)
+
+// blockSize is the PKCS7 block size used by WeCom (32)
+const blockSize = 32
+
+// computeSignature computes the WeCom message signature from the given parameters.
+// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest.
+func computeSignature(token, timestamp, nonce, encrypt string) string {
+ params := []string{token, timestamp, nonce, encrypt}
+ sort.Strings(params)
+ str := strings.Join(params, "")
+ hash := sha1.Sum([]byte(str))
+ return fmt.Sprintf("%x", hash)
+}
+
+// verifySignature verifies the message signature for WeCom
+// This is a common function used by both WeCom Bot and WeCom App
+func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
+ if token == "" {
+ return true // Skip verification if token is not set
+ }
+ return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature
+}
+
+// decryptMessage decrypts the encrypted message using AES
+// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
+func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
+ return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
+}
+
+// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
+// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
+func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
+ if encodingAESKey == "" {
+ // No encryption, return as is (base64 decode)
+ decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
+ if err != nil {
+ return "", err
+ }
+ return string(decoded), nil
+ }
+
+ aesKey, err := decodeWeComAESKey(encodingAESKey)
+ if err != nil {
+ return "", err
+ }
+
+ cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
+ if err != nil {
+ return "", fmt.Errorf("failed to decode message: %w", err)
+ }
+
+ plainText, err := decryptAESCBC(aesKey, cipherText)
+ if err != nil {
+ return "", err
+ }
+
+ return unpackWeComFrame(plainText, receiveid)
+}
+
+// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is
+// appended automatically) and validates that the result is exactly 32 bytes.
+// It is the single place that handles this repeated pattern in both encrypt and decrypt paths.
+func decodeWeComAESKey(encodingAESKey string) ([]byte, error) {
+ aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode AES key: %w", err)
+ }
+ if len(aesKey) != 32 {
+ return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey))
+ }
+ return aesKey, nil
+}
+
+// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring
+// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the
+// plaintext to a multiple of aes.BlockSize before calling.
+func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) {
+ block, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cipher: %w", err)
+ }
+ iv := aesKey[:aes.BlockSize]
+ ciphertext := make([]byte, len(plaintext))
+ cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)
+ return ciphertext, nil
+}
+
+// packWeComFrame builds the WeCom wire format:
+//
+// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid
+func packWeComFrame(msg, receiveid string) ([]byte, error) {
+ randomBytes := make([]byte, 16)
+ for i := range 16 {
+ n, err := rand.Int(rand.Reader, big.NewInt(10))
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate random: %w", err)
+ }
+ randomBytes[i] = byte('0' + n.Int64())
+ }
+ msgBytes := []byte(msg)
+ msgLenBytes := make([]byte, 4)
+ binary.BigEndian.PutUint32(msgLenBytes, uint32(len(msgBytes)))
+ var buf bytes.Buffer
+ buf.Write(randomBytes)
+ buf.Write(msgLenBytes)
+ buf.Write(msgBytes)
+ buf.WriteString(receiveid)
+ return buf.Bytes(), nil
+}
+
+// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame.
+// If receiveid is non-empty it verifies the frame's trailing receiveid field.
+func unpackWeComFrame(data []byte, receiveid string) (string, error) {
+ if len(data) < 20 {
+ return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data))
+ }
+ msgLen := binary.BigEndian.Uint32(data[16:20])
+ if int(msgLen) > len(data)-20 {
+ return "", fmt.Errorf("invalid message length: %d", msgLen)
+ }
+ msg := data[20 : 20+msgLen]
+ if receiveid != "" && len(data) > 20+int(msgLen) {
+ actualReceiveID := string(data[20+msgLen:])
+ if actualReceiveID != receiveid {
+ return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
+ }
+ }
+ return string(msg), nil
+}
+
+// decryptAESCBC decrypts ciphertext using AES-CBC with the given key.
+// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext.
+func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) {
+ if len(ciphertext) == 0 {
+ return nil, fmt.Errorf("ciphertext is empty")
+ }
+ if len(ciphertext)%aes.BlockSize != 0 {
+ return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext))
+ }
+ block, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create cipher: %w", err)
+ }
+ iv := aesKey[:aes.BlockSize]
+ plaintext := make([]byte, len(ciphertext))
+ cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext)
+ plaintext, err = pkcs7Unpad(plaintext)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unpad: %w", err)
+ }
+ return plaintext, nil
+}
+
+// pkcs7Pad adds PKCS7 padding
+func pkcs7Pad(data []byte, blockSize int) []byte {
+ padding := blockSize - (len(data) % blockSize)
+ if padding == 0 {
+ padding = blockSize
+ }
+ padText := bytes.Repeat([]byte{byte(padding)}, padding)
+ return append(data, padText...)
+}
+
+// pkcs7Unpad removes PKCS7 padding with validation
+func pkcs7Unpad(data []byte) ([]byte, error) {
+ if len(data) == 0 {
+ return data, nil
+ }
+ padding := int(data[len(data)-1])
+ // WeCom uses 32-byte block size for PKCS7 padding
+ if padding == 0 || padding > blockSize {
+ return nil, fmt.Errorf("invalid padding size: %d", padding)
+ }
+ if padding > len(data) {
+ return nil, fmt.Errorf("padding size larger than data")
+ }
+ // Verify all padding bytes
+ for i := range padding {
+ if data[len(data)-1-i] != byte(padding) {
+ return nil, fmt.Errorf("invalid padding byte at position %d", i)
+ }
+ }
+ return data[:len(data)-padding], nil
+}
diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go
new file mode 100644
index 000000000..865be668e
--- /dev/null
+++ b/pkg/channels/wecom/dedupe.go
@@ -0,0 +1,54 @@
+package wecom
+
+import "sync"
+
+const wecomMaxProcessedMessages = 1000
+
+// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
+// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
+// messages without causing "amnesia cliffs" when the limit is reached.
+type MessageDeduplicator struct {
+ mu sync.Mutex
+ msgs map[string]bool
+ ring []string
+ idx int
+ max int
+}
+
+// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
+func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
+ if maxEntries <= 0 {
+ maxEntries = wecomMaxProcessedMessages
+ }
+ return &MessageDeduplicator{
+ msgs: make(map[string]bool, maxEntries),
+ ring: make([]string, maxEntries),
+ max: maxEntries,
+ }
+}
+
+// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
+func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // 1. Check for duplicate
+ if d.msgs[msgID] {
+ return false
+ }
+
+ // 2. Evict the oldest message at our current ring position (if any)
+ oldestID := d.ring[d.idx]
+ if oldestID != "" {
+ delete(d.msgs, oldestID)
+ }
+
+ // 3. Store the new message
+ d.msgs[msgID] = true
+ d.ring[d.idx] = msgID
+
+ // 4. Advance the circle queue index
+ d.idx = (d.idx + 1) % d.max
+
+ return true
+}
diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go
new file mode 100644
index 000000000..10dff4cfe
--- /dev/null
+++ b/pkg/channels/wecom/dedupe_test.go
@@ -0,0 +1,83 @@
+package wecom
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
+ d := NewMessageDeduplicator(wecomMaxProcessedMessages)
+
+ if ok := d.MarkMessageProcessed("msg-1"); !ok {
+ t.Fatalf("first message should be accepted")
+ }
+
+ if ok := d.MarkMessageProcessed("msg-1"); ok {
+ t.Fatalf("duplicate message should be rejected")
+ }
+}
+
+func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
+ d := NewMessageDeduplicator(wecomMaxProcessedMessages)
+
+ const goroutines = 64
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ results := make(chan bool, goroutines)
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ results <- d.MarkMessageProcessed("msg-concurrent")
+ }()
+ }
+
+ wg.Wait()
+ close(results)
+
+ successes := 0
+ for ok := range results {
+ if ok {
+ successes++
+ }
+ }
+
+ if successes != 1 {
+ t.Fatalf("expected exactly 1 successful mark, got %d", successes)
+ }
+}
+
+func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
+ // Create a deduplicator with a very small capacity to test eviction easily.
+ capacity := 3
+ d := NewMessageDeduplicator(capacity)
+
+ // Fill the queue.
+ d.MarkMessageProcessed("msg-1")
+ d.MarkMessageProcessed("msg-2")
+ d.MarkMessageProcessed("msg-3")
+
+ // At this point, the queue is full. msg-1 is the oldest.
+ if len(d.msgs) != 3 {
+ t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
+ }
+
+ // This should evict msg-1 and add msg-4.
+ if ok := d.MarkMessageProcessed("msg-4"); !ok {
+ t.Fatalf("msg-4 should be accepted")
+ }
+
+ if len(d.msgs) != 3 {
+ t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
+ }
+
+ // msg-1 should now be forgotten (evicted).
+ if ok := d.MarkMessageProcessed("msg-1"); !ok {
+ t.Fatalf("msg-1 should be accepted again because it was evicted")
+ }
+
+ // msg-2 should have been evicted when we added msg-1 back.
+ if ok := d.MarkMessageProcessed("msg-2"); !ok {
+ t.Fatalf("msg-2 should be accepted again because it was evicted")
+ }
+}
diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go
new file mode 100644
index 000000000..bc5a70fa3
--- /dev/null
+++ b/pkg/channels/wecom/init.go
@@ -0,0 +1,19 @@
+package wecom
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewWeComBotChannel(cfg.Channels.WeCom, b)
+ })
+ channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewWeComAppChannel(cfg.Channels.WeComApp, b)
+ })
+ channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b)
+ })
+}
diff --git a/pkg/channels/whatsapp/init.go b/pkg/channels/whatsapp/init.go
new file mode 100644
index 000000000..d9c2669c3
--- /dev/null
+++ b/pkg/channels/whatsapp/init.go
@@ -0,0 +1,13 @@
+package whatsapp
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewWhatsAppChannel(cfg.Channels.WhatsApp, b)
+ })
+}
diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go
similarity index 54%
rename from pkg/channels/whatsapp.go
rename to pkg/channels/whatsapp/whatsapp.go
index 2dc4017ac..70b3e02bf 100644
--- a/pkg/channels/whatsapp.go
+++ b/pkg/channels/whatsapp/whatsapp.go
@@ -1,31 +1,42 @@
-package channels
+package whatsapp
import (
"context"
"encoding/json"
"fmt"
- "log"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type WhatsAppChannel struct {
- *BaseChannel
+ *channels.BaseChannel
conn *websocket.Conn
config config.WhatsAppConfig
url string
+ ctx context.Context
+ cancel context.CancelFunc
mu sync.Mutex
connected bool
}
func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) {
- base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom)
+ base := channels.NewBaseChannel(
+ "whatsapp",
+ cfg,
+ bus,
+ cfg.AllowFrom,
+ channels.WithMaxMessageLength(65536),
+ channels.WithReasoningChannelID(cfg.ReasoningChannelID),
+ )
return &WhatsAppChannel{
BaseChannel: base,
@@ -36,7 +47,11 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA
}
func (c *WhatsAppChannel) Start(ctx context.Context) error {
- log.Printf("Starting WhatsApp channel connecting to %s...", c.url)
+ logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{
+ "bridge_url": c.url,
+ })
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
dialer := websocket.DefaultDialer
dialer.HandshakeTimeout = 10 * time.Second
@@ -46,6 +61,7 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error {
resp.Body.Close()
}
if err != nil {
+ c.cancel()
return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err)
}
@@ -54,39 +70,57 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error {
c.connected = true
c.mu.Unlock()
- c.setRunning(true)
- log.Println("WhatsApp channel connected")
+ c.SetRunning(true)
+ logger.InfoC("whatsapp", "WhatsApp channel connected")
- go c.listen(ctx)
+ go c.listen()
return nil
}
func (c *WhatsAppChannel) Stop(ctx context.Context) error {
- log.Println("Stopping WhatsApp channel...")
+ logger.InfoC("whatsapp", "Stopping WhatsApp channel...")
+
+ // Cancel context first to signal listen goroutine to exit
+ if c.cancel != nil {
+ c.cancel()
+ }
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
- log.Printf("Error closing WhatsApp connection: %v", err)
+ logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{
+ "error": err.Error(),
+ })
}
c.conn = nil
}
c.connected = false
- c.setRunning(false)
+ c.SetRunning(false)
return nil
}
func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+
+ // Check ctx before acquiring lock
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
- return fmt.Errorf("whatsapp connection not established")
+ return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary)
}
payload := map[string]any{
@@ -100,17 +134,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("failed to marshal message: %w", err)
}
+ _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
- return fmt.Errorf("failed to send message: %w", err)
+ _ = c.conn.SetWriteDeadline(time.Time{})
+ return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary)
}
+ _ = c.conn.SetWriteDeadline(time.Time{})
return nil
}
-func (c *WhatsAppChannel) listen(ctx context.Context) {
+func (c *WhatsAppChannel) listen() {
for {
select {
- case <-ctx.Done():
+ case <-c.ctx.Done():
return
default:
c.mu.Lock()
@@ -124,14 +161,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) {
_, message, err := conn.ReadMessage()
if err != nil {
- log.Printf("WhatsApp read error: %v", err)
+ logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{
+ "error": err.Error(),
+ })
time.Sleep(2 * time.Second)
continue
}
var msg map[string]any
if err := json.Unmarshal(message, &msg); err != nil {
- log.Printf("Failed to unmarshal WhatsApp message: %v", err)
+ logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{
+ "error": err.Error(),
+ })
continue
}
@@ -174,22 +215,38 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
}
metadata := make(map[string]string)
- if messageID, ok := msg["id"].(string); ok {
- metadata["message_id"] = messageID
+ var messageID string
+ if mid, ok := msg["id"].(string); ok {
+ messageID = mid
}
if userName, ok := msg["from_name"].(string); ok {
metadata["user_name"] = userName
}
+ var peer bus.Peer
if chatID == senderID {
- metadata["peer_kind"] = "direct"
- metadata["peer_id"] = senderID
+ peer = bus.Peer{Kind: "direct", ID: senderID}
} else {
- metadata["peer_kind"] = "group"
- metadata["peer_id"] = chatID
+ peer = bus.Peer{Kind: "group", ID: chatID}
}
- log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
+ logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
+ "sender": senderID,
+ "preview": utils.Truncate(content, 50),
+ })
- c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
+ sender := bus.SenderInfo{
+ Platform: "whatsapp",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("whatsapp", senderID),
+ }
+ if display, ok := metadata["user_name"]; ok {
+ sender.DisplayName = display
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
+ c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
}
diff --git a/pkg/channels/whatsapp_native/init.go b/pkg/channels/whatsapp_native/init.go
new file mode 100644
index 000000000..df13e8539
--- /dev/null
+++ b/pkg/channels/whatsapp_native/init.go
@@ -0,0 +1,20 @@
+package whatsapp
+
+import (
+ "path/filepath"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("whatsapp_native", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ waCfg := cfg.Channels.WhatsApp
+ storePath := waCfg.SessionStorePath
+ if storePath == "" {
+ storePath = filepath.Join(cfg.WorkspacePath(), "whatsapp")
+ }
+ return NewWhatsAppNativeChannel(waCfg, b, storePath)
+ })
+}
diff --git a/pkg/channels/whatsapp_native/whatsapp_native.go b/pkg/channels/whatsapp_native/whatsapp_native.go
new file mode 100644
index 000000000..188a7c8fa
--- /dev/null
+++ b/pkg/channels/whatsapp_native/whatsapp_native.go
@@ -0,0 +1,448 @@
+//go:build whatsapp_native
+
+// PicoClaw - Ultra-lightweight personal AI agent
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package whatsapp
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mdp/qrterminal/v3"
+ "go.mau.fi/whatsmeow"
+ "go.mau.fi/whatsmeow/proto/waE2E"
+ "go.mau.fi/whatsmeow/store/sqlstore"
+ "go.mau.fi/whatsmeow/types"
+ "go.mau.fi/whatsmeow/types/events"
+ waLog "go.mau.fi/whatsmeow/util/log"
+ "google.golang.org/protobuf/proto"
+ _ "modernc.org/sqlite"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+const (
+ sqliteDriver = "sqlite"
+ whatsappDBName = "store.db"
+
+ reconnectInitial = 5 * time.Second
+ reconnectMax = 5 * time.Minute
+ reconnectMultiplier = 2.0
+)
+
+// WhatsAppNativeChannel implements the WhatsApp channel using whatsmeow (in-process, no external bridge).
+type WhatsAppNativeChannel struct {
+ *channels.BaseChannel
+ config config.WhatsAppConfig
+ storePath string
+ client *whatsmeow.Client
+ container *sqlstore.Container
+ mu sync.Mutex
+ runCtx context.Context
+ runCancel context.CancelFunc
+ reconnectMu sync.Mutex
+ reconnecting bool
+ stopping atomic.Bool // set once Stop begins; prevents new wg.Add calls
+ wg sync.WaitGroup // tracks background goroutines (QR handler, reconnect)
+}
+
+// NewWhatsAppNativeChannel creates a WhatsApp channel that uses whatsmeow for connection.
+// storePath is the directory for the SQLite session store (e.g. workspace/whatsapp).
+func NewWhatsAppNativeChannel(
+ cfg config.WhatsAppConfig,
+ bus *bus.MessageBus,
+ storePath string,
+) (channels.Channel, error) {
+ base := channels.NewBaseChannel("whatsapp_native", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536))
+ if storePath == "" {
+ storePath = "whatsapp"
+ }
+ c := &WhatsAppNativeChannel{
+ BaseChannel: base,
+ config: cfg,
+ storePath: storePath,
+ }
+ return c, nil
+}
+
+func (c *WhatsAppNativeChannel) Start(ctx context.Context) error {
+ logger.InfoCF("whatsapp", "Starting WhatsApp native channel (whatsmeow)", map[string]any{"store": c.storePath})
+
+ // Reset lifecycle state from any previous Stop() so a restarted channel
+ // behaves correctly. Use reconnectMu to be consistent with eventHandler
+ // and Stop() which coordinate under the same lock.
+ c.reconnectMu.Lock()
+ c.stopping.Store(false)
+ c.reconnecting = false
+ c.reconnectMu.Unlock()
+
+ if err := os.MkdirAll(c.storePath, 0o700); err != nil {
+ return fmt.Errorf("create session store dir: %w", err)
+ }
+
+ dbPath := filepath.Join(c.storePath, whatsappDBName)
+ connStr := "file:" + dbPath + "?_foreign_keys=on"
+
+ db, err := sql.Open(sqliteDriver, connStr)
+ if err != nil {
+ return fmt.Errorf("open whatsapp store: %w", err)
+ }
+ db.SetMaxOpenConns(1)
+ db.SetMaxIdleConns(1)
+ if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
+ _ = db.Close()
+ return fmt.Errorf("enable foreign keys: %w", err)
+ }
+
+ waLogger := waLog.Stdout("WhatsApp", "WARN", true)
+ container := sqlstore.NewWithDB(db, sqliteDriver, waLogger)
+ if err = container.Upgrade(ctx); err != nil {
+ _ = db.Close()
+ return fmt.Errorf("open whatsapp store: %w", err)
+ }
+
+ deviceStore, err := container.GetFirstDevice(ctx)
+ if err != nil {
+ _ = container.Close()
+ return fmt.Errorf("get device store: %w", err)
+ }
+
+ client := whatsmeow.NewClient(deviceStore, waLogger)
+
+ // Create runCtx/runCancel BEFORE registering event handler and starting
+ // goroutines so that Stop() can cancel them at any time, including during
+ // the QR-login flow.
+ c.runCtx, c.runCancel = context.WithCancel(ctx)
+
+ client.AddEventHandler(c.eventHandler)
+
+ c.mu.Lock()
+ c.container = container
+ c.client = client
+ c.mu.Unlock()
+
+ // cleanupOnError clears struct references and releases resources when
+ // Start() fails after fields are already assigned. This prevents
+ // Stop() from operating on stale references (double-close, disconnect
+ // of a partially-initialized client, or stray event handler callbacks).
+ startOK := false
+ defer func() {
+ if startOK {
+ return
+ }
+ c.runCancel()
+ client.Disconnect()
+ c.mu.Lock()
+ c.client = nil
+ c.container = nil
+ c.mu.Unlock()
+ _ = container.Close()
+ }()
+
+ if client.Store.ID == nil {
+ qrChan, err := client.GetQRChannel(c.runCtx)
+ if err != nil {
+ return fmt.Errorf("get QR channel: %w", err)
+ }
+ if err := client.Connect(); err != nil {
+ return fmt.Errorf("connect: %w", err)
+ }
+ // Handle QR events in a background goroutine so Start() returns
+ // promptly. The goroutine is tracked via c.wg and respects
+ // c.runCtx for cancellation.
+ // Guard wg.Add with reconnectMu + stopping check (same protocol
+ // as eventHandler) so a concurrent Stop() cannot enter wg.Wait()
+ // while we call wg.Add(1).
+ c.reconnectMu.Lock()
+ if c.stopping.Load() {
+ c.reconnectMu.Unlock()
+ return fmt.Errorf("channel stopped during QR setup")
+ }
+ c.wg.Add(1)
+ c.reconnectMu.Unlock()
+ go func() {
+ defer c.wg.Done()
+ for {
+ select {
+ case <-c.runCtx.Done():
+ return
+ case evt, ok := <-qrChan:
+ if !ok {
+ return
+ }
+ if evt.Event == "code" {
+ logger.InfoCF("whatsapp", "Scan this QR code with WhatsApp (Linked Devices):", nil)
+ qrterminal.GenerateWithConfig(evt.Code, qrterminal.Config{
+ Level: qrterminal.L,
+ Writer: os.Stdout,
+ HalfBlocks: true,
+ })
+ } else {
+ logger.InfoCF("whatsapp", "WhatsApp login event", map[string]any{"event": evt.Event})
+ }
+ }
+ }
+ }()
+ } else {
+ if err := client.Connect(); err != nil {
+ return fmt.Errorf("connect: %w", err)
+ }
+ }
+
+ startOK = true
+ c.SetRunning(true)
+ logger.InfoC("whatsapp", "WhatsApp native channel connected")
+ return nil
+}
+
+func (c *WhatsAppNativeChannel) Stop(ctx context.Context) error {
+ logger.InfoC("whatsapp", "Stopping WhatsApp native channel")
+
+ // Mark as stopping under reconnectMu so the flag is visible to
+ // eventHandler atomically with respect to its wg.Add(1) call.
+ // This closes the TOCTOU window where eventHandler could check
+ // stopping (false), then Stop sets it true + enters wg.Wait,
+ // then eventHandler calls wg.Add(1) — causing a panic.
+ c.reconnectMu.Lock()
+ c.stopping.Store(true)
+ c.reconnectMu.Unlock()
+
+ if c.runCancel != nil {
+ c.runCancel()
+ }
+
+ // Disconnect the client first so any blocking Connect()/reconnect loops
+ // can be interrupted before we wait on the goroutines.
+ c.mu.Lock()
+ client := c.client
+ container := c.container
+ c.mu.Unlock()
+
+ if client != nil {
+ client.Disconnect()
+ }
+
+ // Wait for background goroutines (QR handler, reconnect) to finish in a
+ // context-aware way so Stop can be bounded by ctx.
+ done := make(chan struct{})
+ go func() {
+ c.wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // All goroutines have finished.
+ case <-ctx.Done():
+ // Context canceled or timed out; log and proceed with best-effort cleanup.
+ logger.WarnC("whatsapp", fmt.Sprintf("Stop context canceled before all goroutines finished: %v", ctx.Err()))
+ }
+
+ // Now it is safe to clear and close resources.
+ c.mu.Lock()
+ c.client = nil
+ c.container = nil
+ c.mu.Unlock()
+
+ if container != nil {
+ _ = container.Close()
+ }
+ c.SetRunning(false)
+ return nil
+}
+
+func (c *WhatsAppNativeChannel) eventHandler(evt any) {
+ switch evt.(type) {
+ case *events.Message:
+ c.handleIncoming(evt.(*events.Message))
+ case *events.Disconnected:
+ logger.InfoCF("whatsapp", "WhatsApp disconnected, will attempt reconnection", nil)
+ c.reconnectMu.Lock()
+ if c.reconnecting {
+ c.reconnectMu.Unlock()
+ return
+ }
+ // Check stopping while holding the lock so the check and wg.Add
+ // are atomic with respect to Stop() setting the flag + calling
+ // wg.Wait(). This prevents the TOCTOU race.
+ if c.stopping.Load() {
+ c.reconnectMu.Unlock()
+ return
+ }
+ c.reconnecting = true
+ c.wg.Add(1)
+ c.reconnectMu.Unlock()
+ go func() {
+ defer c.wg.Done()
+ c.reconnectWithBackoff()
+ }()
+ }
+}
+
+func (c *WhatsAppNativeChannel) reconnectWithBackoff() {
+ defer func() {
+ c.reconnectMu.Lock()
+ c.reconnecting = false
+ c.reconnectMu.Unlock()
+ }()
+
+ backoff := reconnectInitial
+ for {
+ select {
+ case <-c.runCtx.Done():
+ return
+ default:
+ }
+
+ c.mu.Lock()
+ client := c.client
+ c.mu.Unlock()
+ if client == nil {
+ return
+ }
+
+ logger.InfoCF("whatsapp", "WhatsApp reconnecting", map[string]any{"backoff": backoff.String()})
+ err := client.Connect()
+ if err == nil {
+ logger.InfoC("whatsapp", "WhatsApp reconnected")
+ return
+ }
+
+ logger.WarnCF("whatsapp", "WhatsApp reconnect failed", map[string]any{"error": err.Error()})
+
+ select {
+ case <-c.runCtx.Done():
+ return
+ case <-time.After(backoff):
+ if backoff < reconnectMax {
+ next := time.Duration(float64(backoff) * reconnectMultiplier)
+ if next > reconnectMax {
+ next = reconnectMax
+ }
+ backoff = next
+ }
+ }
+ }
+}
+
+func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
+ if evt.Message == nil {
+ return
+ }
+ senderID := evt.Info.Sender.String()
+ chatID := evt.Info.Chat.String()
+ content := evt.Message.GetConversation()
+ if content == "" && evt.Message.ExtendedTextMessage != nil {
+ content = evt.Message.ExtendedTextMessage.GetText()
+ }
+ content = utils.SanitizeMessageContent(content)
+
+ if content == "" {
+ return
+ }
+
+ var mediaPaths []string
+
+ metadata := make(map[string]string)
+ metadata["message_id"] = evt.Info.ID
+ if evt.Info.PushName != "" {
+ metadata["user_name"] = evt.Info.PushName
+ }
+ if evt.Info.Chat.Server == types.GroupServer {
+ metadata["peer_kind"] = "group"
+ metadata["peer_id"] = chatID
+ } else {
+ metadata["peer_kind"] = "direct"
+ metadata["peer_id"] = senderID
+ }
+
+ peerKind := "direct"
+ if evt.Info.Chat.Server == types.GroupServer {
+ peerKind = "group"
+ }
+ peer := bus.Peer{Kind: peerKind, ID: chatID}
+ messageID := evt.Info.ID
+ sender := bus.SenderInfo{
+ Platform: "whatsapp",
+ PlatformID: senderID,
+ CanonicalID: identity.BuildCanonicalID("whatsapp", senderID),
+ DisplayName: evt.Info.PushName,
+ }
+
+ if !c.IsAllowedSender(sender) {
+ return
+ }
+
+ logger.DebugCF(
+ "whatsapp",
+ "WhatsApp message received",
+ map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
+ )
+ c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
+}
+
+func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return channels.ErrNotRunning
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
+ c.mu.Lock()
+ client := c.client
+ c.mu.Unlock()
+
+ if client == nil || !client.IsConnected() {
+ return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary)
+ }
+
+ // Detect unpaired state: the client is connected (to WhatsApp servers)
+ // but has not completed QR-login yet, so sending would fail.
+ if client.Store.ID == nil {
+ return fmt.Errorf("whatsapp not yet paired (QR login pending): %w", channels.ErrTemporary)
+ }
+
+ to, err := parseJID(msg.ChatID)
+ if err != nil {
+ return fmt.Errorf("invalid chat id %q: %w", msg.ChatID, err)
+ }
+
+ waMsg := &waE2E.Message{
+ Conversation: proto.String(msg.Content),
+ }
+
+ if _, err = client.SendMessage(ctx, to, waMsg); err != nil {
+ return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary)
+ }
+ return nil
+}
+
+// parseJID converts a chat ID (phone number or JID string) to types.JID.
+func parseJID(s string) (types.JID, error) {
+ s = strings.TrimSpace(s)
+ if s == "" {
+ return types.JID{}, fmt.Errorf("empty chat id")
+ }
+ if strings.Contains(s, "@") {
+ return types.ParseJID(s)
+ }
+ return types.NewJID(s, types.DefaultUserServer), nil
+}
diff --git a/pkg/channels/whatsapp_native/whatsapp_native_stub.go b/pkg/channels/whatsapp_native/whatsapp_native_stub.go
new file mode 100644
index 000000000..984af23e7
--- /dev/null
+++ b/pkg/channels/whatsapp_native/whatsapp_native_stub.go
@@ -0,0 +1,21 @@
+//go:build !whatsapp_native
+
+package whatsapp
+
+import (
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// NewWhatsAppNativeChannel returns an error when the binary was not built with -tags whatsapp_native.
+// Build with: go build -tags whatsapp_native ./cmd/...
+func NewWhatsAppNativeChannel(
+ cfg config.WhatsAppConfig,
+ bus *bus.MessageBus,
+ storePath string,
+) (channels.Channel, error) {
+ return nil, fmt.Errorf("whatsapp native not compiled in; build with -tags whatsapp_native")
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index f7c78136b..6cabddafc 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -4,10 +4,11 @@ import (
"encoding/json"
"fmt"
"os"
- "path/filepath"
"sync/atomic"
"github.com/caarlos0/env/v11"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
// rrCounter is a global counter for round-robin load balancing across models.
@@ -167,17 +168,30 @@ type SessionConfig struct {
}
type AgentDefaults struct {
- Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
- RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
- Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
- ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
- Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
- ModelFallbacks []string `json:"model_fallbacks,omitempty"`
- ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
- ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
- MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
- Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
- MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+ Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
+ RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
+ AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"`
+ Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
+ ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
+ Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
+ ModelFallbacks []string `json:"model_fallbacks,omitempty"`
+ ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
+ ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
+ MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
+ Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
+ MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+ SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
+ SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
+ MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
+}
+
+const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
+
+func (d *AgentDefaults) GetMaxMediaSize() int {
+ if d.MaxMediaSize > 0 {
+ return d.MaxMediaSize
+ }
+ return DefaultMaxMediaSize
}
// GetModelName returns the effective model name for the agent defaults.
@@ -190,120 +204,201 @@ func (d *AgentDefaults) GetModelName() string {
}
type ChannelsConfig struct {
- WhatsApp WhatsAppConfig `json:"whatsapp"`
- Telegram TelegramConfig `json:"telegram"`
- Feishu FeishuConfig `json:"feishu"`
- Discord DiscordConfig `json:"discord"`
- MaixCam MaixCamConfig `json:"maixcam"`
- QQ QQConfig `json:"qq"`
- DingTalk DingTalkConfig `json:"dingtalk"`
- Slack SlackConfig `json:"slack"`
- LINE LINEConfig `json:"line"`
- OneBot OneBotConfig `json:"onebot"`
- WeCom WeComConfig `json:"wecom"`
- WeComApp WeComAppConfig `json:"wecom_app"`
+ WhatsApp WhatsAppConfig `json:"whatsapp"`
+ Telegram TelegramConfig `json:"telegram"`
+ Feishu FeishuConfig `json:"feishu"`
+ Discord DiscordConfig `json:"discord"`
+ MaixCam MaixCamConfig `json:"maixcam"`
+ QQ QQConfig `json:"qq"`
+ DingTalk DingTalkConfig `json:"dingtalk"`
+ Slack SlackConfig `json:"slack"`
+ LINE LINEConfig `json:"line"`
+ OneBot OneBotConfig `json:"onebot"`
+ WeCom WeComConfig `json:"wecom"`
+ WeComApp WeComAppConfig `json:"wecom_app"`
+ WeComAIBot WeComAIBotConfig `json:"wecom_aibot"`
+ Pico PicoConfig `json:"pico"`
+}
+
+// GroupTriggerConfig controls when the bot responds in group chats.
+type GroupTriggerConfig struct {
+ MentionOnly bool `json:"mention_only,omitempty"`
+ Prefixes []string `json:"prefixes,omitempty"`
+}
+
+// TypingConfig controls typing indicator behavior (Phase 10).
+type TypingConfig struct {
+ Enabled bool `json:"enabled,omitempty"`
+}
+
+// PlaceholderConfig controls placeholder message behavior (Phase 10).
+type PlaceholderConfig struct {
+ Enabled bool `json:"enabled,omitempty"`
+ Text string `json:"text,omitempty"`
}
type WhatsAppConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
- BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
+ BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
+ UseNative bool `json:"use_native" env:"PICOCLAW_CHANNELS_WHATSAPP_USE_NATIVE"`
+ SessionStorePath string `json:"session_store_path" env:"PICOCLAW_CHANNELS_WHATSAPP_SESSION_STORE_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WHATSAPP_REASONING_CHANNEL_ID"`
}
type TelegramConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
- Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
+ BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"`
+ Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Typing TypingConfig `json:"typing,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"`
}
type FeishuConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
- AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
- AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
- EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
- VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
+ AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
+ AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
+ EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
+ VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
}
type DiscordConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
- MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
+ Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
+ MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Typing TypingConfig `json:"typing,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DISCORD_REASONING_CHANNEL_ID"`
}
type MaixCamConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
- Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
- Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
+ Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
+ Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MAIXCAM_REASONING_CHANNEL_ID"`
}
type QQConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
- AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
- AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
+ AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
+ AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"`
}
type DingTalkConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
- ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
- ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
+ ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
+ ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DINGTALK_REASONING_CHANNEL_ID"`
}
type SlackConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
- BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
- AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
+ BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
+ AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Typing TypingConfig `json:"typing,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"`
}
type LINEConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
- ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
- ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"`
- WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
- WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
- WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
+ ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
+ ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Typing TypingConfig `json:"typing,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_LINE_REASONING_CHANNEL_ID"`
}
type OneBotConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
- WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
- AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
- ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
- GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
+ WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
+ AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
+ ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
+ GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ Typing TypingConfig `json:"typing,omitempty"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_ONEBOT_REASONING_CHANNEL_ID"`
}
type WeComConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
- EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
- WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
- WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
- WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
- WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
- ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
+ WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"`
}
type WeComAppConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
- CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
- CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
- AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
- EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
- WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
- WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
- WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
- ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
+ CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
+ CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
+ AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
+ GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"`
+}
+
+type WeComAIBotConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"`
+ MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps
+ WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome
+ ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"`
+}
+
+type PicoConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"`
+ AllowTokenQuery bool `json:"allow_token_query,omitempty"`
+ AllowOrigins []string `json:"allow_origins,omitempty"`
+ PingInterval int `json:"ping_interval,omitempty"`
+ ReadTimeout int `json:"read_timeout,omitempty"`
+ WriteTimeout int `json:"write_timeout,omitempty"`
+ MaxConnections int `json:"max_connections,omitempty"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"`
+ Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
}
type HeartbeatConfig struct {
@@ -319,6 +414,7 @@ type DevicesConfig struct {
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
+ LiteLLM ProviderConfig `json:"litellm"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
@@ -342,6 +438,7 @@ type ProvidersConfig struct {
func (p ProvidersConfig) IsEmpty() bool {
return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" &&
p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" &&
+ p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" &&
p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" &&
p.Groq.APIKey == "" && p.Groq.APIBase == "" &&
p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" &&
@@ -371,11 +468,12 @@ func (p ProvidersConfig) MarshalJSON() ([]byte, error) {
}
type ProviderConfig struct {
- APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
- APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
- Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
- AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
- ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc`
+ APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
+ APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
+ Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
+ RequestTimeout int `json:"request_timeout,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_REQUEST_TIMEOUT"`
+ AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
+ ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc`
}
type OpenAIProviderConfig struct {
@@ -406,6 +504,7 @@ type ModelConfig struct {
// Optional optimizations
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
+ RequestTimeout int `json:"request_timeout,omitempty"`
}
// Validate checks if the ModelConfig has all required fields.
@@ -454,15 +553,27 @@ type SearXNGConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"`
}
+type GLMSearchConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"`
+ APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"`
+ BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"`
+ // SearchEngine specifies the search backend: "search_std" (default),
+ // "search_pro", "search_pro_sogou", or "search_pro_quark".
+ SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"`
+ MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"`
+}
+
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
SearXNG SearXNGConfig `json:"searxng"`
+ GLMSearch GLMSearchConfig `json:"glm_search"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
- Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+ Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+ FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
}
type CronToolsConfig struct {
@@ -470,15 +581,26 @@ type CronToolsConfig struct {
}
type ExecConfig struct {
- EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
- CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
+ EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
+ CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
+ CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"`
+}
+
+type MediaCleanupConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"`
+ MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"`
+ Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"`
}
type ToolsConfig struct {
- Web WebToolsConfig `json:"web"`
- Cron CronToolsConfig `json:"cron"`
- Exec ExecConfig `json:"exec"`
- Skills SkillsToolsConfig `json:"skills"`
+ AllowReadPaths []string `json:"allow_read_paths" env:"PICOCLAW_TOOLS_ALLOW_READ_PATHS"`
+ AllowWritePaths []string `json:"allow_write_paths" env:"PICOCLAW_TOOLS_ALLOW_WRITE_PATHS"`
+ Web WebToolsConfig `json:"web"`
+ Cron CronToolsConfig `json:"cron"`
+ Exec ExecConfig `json:"exec"`
+ Skills SkillsToolsConfig `json:"skills"`
+ MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
+ MCP MCPConfig `json:"mcp"`
}
type SkillsToolsConfig struct {
@@ -508,6 +630,34 @@ type ClawHubRegistryConfig struct {
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
+// MCPServerConfig defines configuration for a single MCP server
+type MCPServerConfig struct {
+ // Enabled indicates whether this MCP server is active
+ Enabled bool `json:"enabled"`
+ // Command is the executable to run (e.g., "npx", "python", "/path/to/server")
+ Command string `json:"command"`
+ // Args are the arguments to pass to the command
+ Args []string `json:"args,omitempty"`
+ // Env are environment variables to set for the server process (stdio only)
+ Env map[string]string `json:"env,omitempty"`
+ // EnvFile is the path to a file containing environment variables (stdio only)
+ EnvFile string `json:"env_file,omitempty"`
+ // Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
+ Type string `json:"type,omitempty"`
+ // URL is used for SSE/HTTP transport
+ URL string `json:"url,omitempty"`
+ // Headers are HTTP headers to send with requests (sse/http only)
+ Headers map[string]string `json:"headers,omitempty"`
+}
+
+// MCPConfig defines configuration for all MCP servers
+type MCPConfig struct {
+ // Enabled globally enables/disables MCP integration
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
+ // Servers is a map of server name to server configuration
+ Servers map[string]MCPServerConfig `json:"servers,omitempty"`
+}
+
func LoadConfig(path string) (*Config, error) {
cfg := DefaultConfig()
@@ -541,6 +691,9 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
+ // Migrate legacy channel config fields to new unified structures
+ cfg.migrateChannelConfigs()
+
// Auto-migrate: if only legacy providers config exists, convert to model_list
if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {
cfg.ModelList = ConvertProvidersToModelList(cfg)
@@ -554,18 +707,27 @@ func LoadConfig(path string) (*Config, error) {
return cfg, nil
}
+func (c *Config) migrateChannelConfigs() {
+ // Discord: mention_only -> group_trigger.mention_only
+ if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
+ c.Channels.Discord.GroupTrigger.MentionOnly = true
+ }
+
+ // OneBot: group_trigger_prefix -> group_trigger.prefixes
+ if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 &&
+ len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 {
+ c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix
+ }
+}
+
func SaveConfig(path string, cfg *Config) error {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return err
- }
-
- return os.WriteFile(path, data, 0o600)
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ return fileutil.WriteFileAtomic(path, data, 0o600)
}
func (c *Config) WorkspacePath() string {
@@ -663,25 +825,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig {
// HasProvidersConfig checks if any provider in the old providers config has configuration.
func (c *Config) HasProvidersConfig() bool {
- v := c.Providers
- return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" ||
- v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" ||
- v.OpenRouter.APIKey != "" || v.OpenRouter.APIBase != "" ||
- v.Groq.APIKey != "" || v.Groq.APIBase != "" ||
- v.Zhipu.APIKey != "" || v.Zhipu.APIBase != "" ||
- v.VLLM.APIKey != "" || v.VLLM.APIBase != "" ||
- v.Gemini.APIKey != "" || v.Gemini.APIBase != "" ||
- v.Nvidia.APIKey != "" || v.Nvidia.APIBase != "" ||
- v.Ollama.APIKey != "" || v.Ollama.APIBase != "" ||
- v.Moonshot.APIKey != "" || v.Moonshot.APIBase != "" ||
- v.ShengSuanYun.APIKey != "" || v.ShengSuanYun.APIBase != "" ||
- v.DeepSeek.APIKey != "" || v.DeepSeek.APIBase != "" ||
- v.Cerebras.APIKey != "" || v.Cerebras.APIBase != "" ||
- v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
- v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
- v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
- v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
- v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
+ return !c.Providers.IsEmpty()
}
// ValidateModelList validates all ModelConfig entries in the model_list.
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 223ac798d..10ebc7c90 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"runtime"
+ "strings"
"testing"
)
@@ -210,8 +211,8 @@ func TestDefaultConfig_WorkspacePath(t *testing.T) {
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
- if cfg.Agents.Defaults.Model == "" {
- t.Error("Model should not be empty")
+ if cfg.Agents.Defaults.Model != "" {
+ t.Error("Model should be empty")
}
}
@@ -324,6 +325,25 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
}
}
+func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) {
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "config.json")
+
+ cfg := DefaultConfig()
+ if err := SaveConfig(path, cfg); err != nil {
+ t.Fatalf("SaveConfig failed: %v", err)
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("ReadFile failed: %v", err)
+ }
+
+ if !strings.Contains(string(data), `"model": ""`) {
+ t.Fatalf("saved config should include empty legacy model field, got: %s", string(data))
+ }
+}
+
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
@@ -331,8 +351,8 @@ func TestConfig_Complete(t *testing.T) {
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
- if cfg.Agents.Defaults.Model == "" {
- t.Error("Model should not be empty")
+ if cfg.Agents.Defaults.Model != "" {
+ t.Error("Model should be empty")
}
if cfg.Agents.Defaults.Temperature != nil {
t.Error("Temperature should be nil when not provided")
@@ -413,3 +433,49 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
t.Fatalf("Tools.Web.Proxy = %q, want %q", cfg.Tools.Web.Proxy, "http://127.0.0.1:7890")
}
}
+
+// TestDefaultConfig_DMScope verifies the default dm_scope value
+// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
+func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
+ cfg := DefaultConfig()
+
+ if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 {
+ t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold)
+ }
+ if cfg.Agents.Defaults.SummarizeTokenPercent != 75 {
+ t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent)
+ }
+}
+
+func TestDefaultConfig_DMScope(t *testing.T) {
+ cfg := DefaultConfig()
+
+ if cfg.Session.DMScope != "per-channel-peer" {
+ t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope)
+ }
+}
+
+func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
+ // Unset to ensure we test the default
+ t.Setenv("PICOCLAW_HOME", "")
+ // Set a known home for consistent test results
+ t.Setenv("HOME", "/tmp/home")
+
+ cfg := DefaultConfig()
+ want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
+
+ if cfg.Agents.Defaults.Workspace != want {
+ t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
+ }
+}
+
+func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
+ t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
+
+ cfg := DefaultConfig()
+ want := "/custom/picoclaw/home/workspace"
+
+ if cfg.Agents.Defaults.Workspace != want {
+ t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
+ }
+}
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 47c51fc9d..19d877dea 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -5,34 +5,59 @@
package config
+import (
+ "os"
+ "path/filepath"
+)
+
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
+ // Determine the base path for the workspace.
+ // Priority: $PICOCLAW_HOME > ~/.picoclaw
+ var homePath string
+ if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
+ homePath = picoclawHome
+ } else {
+ userHome, _ := os.UserHomeDir()
+ homePath = filepath.Join(userHome, ".picoclaw")
+ }
+ workspacePath := filepath.Join(homePath, "workspace")
+
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
- Workspace: "~/.picoclaw/workspace",
- RestrictToWorkspace: true,
- Provider: "",
- Model: "glm-4.7",
- MaxTokens: 8192,
- Temperature: nil, // nil means use provider default
- MaxToolIterations: 20,
+ Workspace: workspacePath,
+ RestrictToWorkspace: true,
+ Provider: "",
+ Model: "",
+ MaxTokens: 32768,
+ Temperature: nil, // nil means use provider default
+ MaxToolIterations: 50,
+ SummarizeMessageThreshold: 20,
+ SummarizeTokenPercent: 75,
},
},
Bindings: []AgentBinding{},
Session: SessionConfig{
- DMScope: "main",
+ DMScope: "per-channel-peer",
},
Channels: ChannelsConfig{
WhatsApp: WhatsAppConfig{
- Enabled: false,
- BridgeURL: "ws://localhost:3001",
- AllowFrom: FlexibleStringSlice{},
+ Enabled: false,
+ BridgeURL: "ws://localhost:3001",
+ UseNative: false,
+ SessionStorePath: "",
+ AllowFrom: FlexibleStringSlice{},
},
Telegram: TelegramConfig{
Enabled: false,
Token: "",
AllowFrom: FlexibleStringSlice{},
+ Typing: TypingConfig{Enabled: true},
+ Placeholder: PlaceholderConfig{
+ Enabled: true,
+ Text: "Thinking... 💭",
+ },
},
Feishu: FeishuConfig{
Enabled: false,
@@ -80,6 +105,7 @@ func DefaultConfig() *Config {
WebhookPort: 18791,
WebhookPath: "/webhook/line",
AllowFrom: FlexibleStringSlice{},
+ GroupTrigger: GroupTriggerConfig{MentionOnly: true},
},
OneBot: OneBotConfig{
Enabled: false,
@@ -113,6 +139,25 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
ReplyTimeout: 5,
},
+ WeComAIBot: WeComAIBotConfig{
+ Enabled: false,
+ Token: "",
+ EncodingAESKey: "",
+ WebhookPath: "/webhook/wecom-aibot",
+ AllowFrom: FlexibleStringSlice{},
+ ReplyTimeout: 5,
+ MaxSteps: 10,
+ WelcomeMessage: "Hello! I'm your AI assistant. How can I help you today?",
+ },
+ Pico: PicoConfig{
+ Enabled: false,
+ Token: "",
+ PingInterval: 30,
+ ReadTimeout: 60,
+ WriteTimeout: 10,
+ MaxConnections: 100,
+ AllowFrom: FlexibleStringSlice{},
+ },
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
@@ -276,8 +321,14 @@ func DefaultConfig() *Config {
Port: 18790,
},
Tools: ToolsConfig{
+ MediaCleanup: MediaCleanupConfig{
+ Enabled: true,
+ MaxAge: 30,
+ Interval: 5,
+ },
Web: WebToolsConfig{
- Proxy: "",
+ Proxy: "",
+ FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default
Brave: BraveConfig{
Enabled: false,
APIKey: "",
@@ -297,6 +348,13 @@ func DefaultConfig() *Config {
BaseURL: "",
MaxResults: 5,
},
+ GLMSearch: GLMSearchConfig{
+ Enabled: false,
+ APIKey: "",
+ BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search",
+ SearchEngine: "search_std",
+ MaxResults: 5,
+ },
},
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5,
@@ -317,6 +375,10 @@ func DefaultConfig() *Config {
TTLSeconds: 300,
},
},
+ MCP: MCPConfig{
+ Enabled: false,
+ Servers: map[string]MCPServerConfig{},
+ },
},
Heartbeat: HeartbeatConfig{
Enabled: true,
diff --git a/pkg/config/migration.go b/pkg/config/migration.go
index 70e1de438..772f714fd 100644
--- a/pkg/config/migration.go
+++ b/pkg/config/migration.go
@@ -60,12 +60,13 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "openai",
- Model: "openai/gpt-5.2",
- APIKey: p.OpenAI.APIKey,
- APIBase: p.OpenAI.APIBase,
- Proxy: p.OpenAI.Proxy,
- AuthMethod: p.OpenAI.AuthMethod,
+ ModelName: "openai",
+ Model: "openai/gpt-5.2",
+ APIKey: p.OpenAI.APIKey,
+ APIBase: p.OpenAI.APIBase,
+ Proxy: p.OpenAI.Proxy,
+ RequestTimeout: p.OpenAI.RequestTimeout,
+ AuthMethod: p.OpenAI.AuthMethod,
}, true
},
},
@@ -77,12 +78,30 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "anthropic",
- Model: "anthropic/claude-sonnet-4.6",
- APIKey: p.Anthropic.APIKey,
- APIBase: p.Anthropic.APIBase,
- Proxy: p.Anthropic.Proxy,
- AuthMethod: p.Anthropic.AuthMethod,
+ ModelName: "anthropic",
+ Model: "anthropic/claude-sonnet-4.6",
+ APIKey: p.Anthropic.APIKey,
+ APIBase: p.Anthropic.APIBase,
+ Proxy: p.Anthropic.Proxy,
+ RequestTimeout: p.Anthropic.RequestTimeout,
+ AuthMethod: p.Anthropic.AuthMethod,
+ }, true
+ },
+ },
+ {
+ providerNames: []string{"litellm"},
+ protocol: "litellm",
+ buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
+ if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" {
+ return ModelConfig{}, false
+ }
+ return ModelConfig{
+ ModelName: "litellm",
+ Model: "litellm/auto",
+ APIKey: p.LiteLLM.APIKey,
+ APIBase: p.LiteLLM.APIBase,
+ Proxy: p.LiteLLM.Proxy,
+ RequestTimeout: p.LiteLLM.RequestTimeout,
}, true
},
},
@@ -94,11 +113,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "openrouter",
- Model: "openrouter/auto",
- APIKey: p.OpenRouter.APIKey,
- APIBase: p.OpenRouter.APIBase,
- Proxy: p.OpenRouter.Proxy,
+ ModelName: "openrouter",
+ Model: "openrouter/auto",
+ APIKey: p.OpenRouter.APIKey,
+ APIBase: p.OpenRouter.APIBase,
+ Proxy: p.OpenRouter.Proxy,
+ RequestTimeout: p.OpenRouter.RequestTimeout,
}, true
},
},
@@ -110,11 +130,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "groq",
- Model: "groq/llama-3.1-70b-versatile",
- APIKey: p.Groq.APIKey,
- APIBase: p.Groq.APIBase,
- Proxy: p.Groq.Proxy,
+ ModelName: "groq",
+ Model: "groq/llama-3.1-70b-versatile",
+ APIKey: p.Groq.APIKey,
+ APIBase: p.Groq.APIBase,
+ Proxy: p.Groq.Proxy,
+ RequestTimeout: p.Groq.RequestTimeout,
}, true
},
},
@@ -126,11 +147,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "zhipu",
- Model: "zhipu/glm-4",
- APIKey: p.Zhipu.APIKey,
- APIBase: p.Zhipu.APIBase,
- Proxy: p.Zhipu.Proxy,
+ ModelName: "zhipu",
+ Model: "zhipu/glm-4",
+ APIKey: p.Zhipu.APIKey,
+ APIBase: p.Zhipu.APIBase,
+ Proxy: p.Zhipu.Proxy,
+ RequestTimeout: p.Zhipu.RequestTimeout,
}, true
},
},
@@ -142,11 +164,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "vllm",
- Model: "vllm/auto",
- APIKey: p.VLLM.APIKey,
- APIBase: p.VLLM.APIBase,
- Proxy: p.VLLM.Proxy,
+ ModelName: "vllm",
+ Model: "vllm/auto",
+ APIKey: p.VLLM.APIKey,
+ APIBase: p.VLLM.APIBase,
+ Proxy: p.VLLM.Proxy,
+ RequestTimeout: p.VLLM.RequestTimeout,
}, true
},
},
@@ -158,11 +181,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "gemini",
- Model: "gemini/gemini-pro",
- APIKey: p.Gemini.APIKey,
- APIBase: p.Gemini.APIBase,
- Proxy: p.Gemini.Proxy,
+ ModelName: "gemini",
+ Model: "gemini/gemini-pro",
+ APIKey: p.Gemini.APIKey,
+ APIBase: p.Gemini.APIBase,
+ Proxy: p.Gemini.Proxy,
+ RequestTimeout: p.Gemini.RequestTimeout,
}, true
},
},
@@ -174,11 +198,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "nvidia",
- Model: "nvidia/meta/llama-3.1-8b-instruct",
- APIKey: p.Nvidia.APIKey,
- APIBase: p.Nvidia.APIBase,
- Proxy: p.Nvidia.Proxy,
+ ModelName: "nvidia",
+ Model: "nvidia/meta/llama-3.1-8b-instruct",
+ APIKey: p.Nvidia.APIKey,
+ APIBase: p.Nvidia.APIBase,
+ Proxy: p.Nvidia.Proxy,
+ RequestTimeout: p.Nvidia.RequestTimeout,
}, true
},
},
@@ -190,11 +215,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "ollama",
- Model: "ollama/llama3",
- APIKey: p.Ollama.APIKey,
- APIBase: p.Ollama.APIBase,
- Proxy: p.Ollama.Proxy,
+ ModelName: "ollama",
+ Model: "ollama/llama3",
+ APIKey: p.Ollama.APIKey,
+ APIBase: p.Ollama.APIBase,
+ Proxy: p.Ollama.Proxy,
+ RequestTimeout: p.Ollama.RequestTimeout,
}, true
},
},
@@ -206,11 +232,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "moonshot",
- Model: "moonshot/kimi",
- APIKey: p.Moonshot.APIKey,
- APIBase: p.Moonshot.APIBase,
- Proxy: p.Moonshot.Proxy,
+ ModelName: "moonshot",
+ Model: "moonshot/kimi",
+ APIKey: p.Moonshot.APIKey,
+ APIBase: p.Moonshot.APIBase,
+ Proxy: p.Moonshot.Proxy,
+ RequestTimeout: p.Moonshot.RequestTimeout,
}, true
},
},
@@ -222,11 +249,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "shengsuanyun",
- Model: "shengsuanyun/auto",
- APIKey: p.ShengSuanYun.APIKey,
- APIBase: p.ShengSuanYun.APIBase,
- Proxy: p.ShengSuanYun.Proxy,
+ ModelName: "shengsuanyun",
+ Model: "shengsuanyun/auto",
+ APIKey: p.ShengSuanYun.APIKey,
+ APIBase: p.ShengSuanYun.APIBase,
+ Proxy: p.ShengSuanYun.Proxy,
+ RequestTimeout: p.ShengSuanYun.RequestTimeout,
}, true
},
},
@@ -238,11 +266,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "deepseek",
- Model: "deepseek/deepseek-chat",
- APIKey: p.DeepSeek.APIKey,
- APIBase: p.DeepSeek.APIBase,
- Proxy: p.DeepSeek.Proxy,
+ ModelName: "deepseek",
+ Model: "deepseek/deepseek-chat",
+ APIKey: p.DeepSeek.APIKey,
+ APIBase: p.DeepSeek.APIBase,
+ Proxy: p.DeepSeek.Proxy,
+ RequestTimeout: p.DeepSeek.RequestTimeout,
}, true
},
},
@@ -254,11 +283,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "cerebras",
- Model: "cerebras/llama-3.3-70b",
- APIKey: p.Cerebras.APIKey,
- APIBase: p.Cerebras.APIBase,
- Proxy: p.Cerebras.Proxy,
+ ModelName: "cerebras",
+ Model: "cerebras/llama-3.3-70b",
+ APIKey: p.Cerebras.APIKey,
+ APIBase: p.Cerebras.APIBase,
+ Proxy: p.Cerebras.Proxy,
+ RequestTimeout: p.Cerebras.RequestTimeout,
}, true
},
},
@@ -270,11 +300,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "volcengine",
- Model: "volcengine/doubao-pro",
- APIKey: p.VolcEngine.APIKey,
- APIBase: p.VolcEngine.APIBase,
- Proxy: p.VolcEngine.Proxy,
+ ModelName: "volcengine",
+ Model: "volcengine/doubao-pro",
+ APIKey: p.VolcEngine.APIKey,
+ APIBase: p.VolcEngine.APIBase,
+ Proxy: p.VolcEngine.Proxy,
+ RequestTimeout: p.VolcEngine.RequestTimeout,
}, true
},
},
@@ -316,11 +347,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "qwen",
- Model: "qwen/qwen-max",
- APIKey: p.Qwen.APIKey,
- APIBase: p.Qwen.APIBase,
- Proxy: p.Qwen.Proxy,
+ ModelName: "qwen",
+ Model: "qwen/qwen-max",
+ APIKey: p.Qwen.APIKey,
+ APIBase: p.Qwen.APIBase,
+ Proxy: p.Qwen.Proxy,
+ RequestTimeout: p.Qwen.RequestTimeout,
}, true
},
},
@@ -332,11 +364,12 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
return ModelConfig{}, false
}
return ModelConfig{
- ModelName: "mistral",
- Model: "mistral/mistral-small-latest",
- APIKey: p.Mistral.APIKey,
- APIBase: p.Mistral.APIBase,
- Proxy: p.Mistral.Proxy,
+ ModelName: "mistral",
+ Model: "mistral/mistral-small-latest",
+ APIKey: p.Mistral.APIKey,
+ APIBase: p.Mistral.APIBase,
+ Proxy: p.Mistral.Proxy,
+ RequestTimeout: p.Mistral.RequestTimeout,
}, true
},
},
diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go
index 42165cb71..e24e9fa1d 100644
--- a/pkg/config/migration_test.go
+++ b/pkg/config/migration_test.go
@@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
}
}
+func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
+ cfg := &Config{
+ Providers: ProvidersConfig{
+ LiteLLM: ProviderConfig{
+ APIKey: "litellm-key",
+ APIBase: "http://localhost:4000/v1",
+ },
+ },
+ }
+
+ result := ConvertProvidersToModelList(cfg)
+
+ if len(result) != 1 {
+ t.Fatalf("len(result) = %d, want 1", len(result))
+ }
+
+ if result[0].ModelName != "litellm" {
+ t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm")
+ }
+ if result[0].Model != "litellm/auto" {
+ t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto")
+ }
+ if result[0].APIBase != "http://localhost:4000/v1" {
+ t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1")
+ }
+}
+
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
@@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
+ LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
Anthropic: ProviderConfig{APIKey: "key2"},
OpenRouter: ProviderConfig{APIKey: "key3"},
Groq: ProviderConfig{APIKey: "key4"},
@@ -137,9 +165,9 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
result := ConvertProvidersToModelList(cfg)
- // All 18 providers should be converted
- if len(result) != 18 {
- t.Errorf("len(result) = %d, want 18", len(result))
+ // All 19 providers should be converted
+ if len(result) != 19 {
+ t.Errorf("len(result) = %d, want 19", len(result))
}
}
@@ -166,6 +194,27 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) {
}
}
+func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) {
+ cfg := &Config{
+ Providers: ProvidersConfig{
+ Ollama: ProviderConfig{
+ APIKey: "ollama-key",
+ RequestTimeout: 300,
+ },
+ },
+ }
+
+ result := ConvertProvidersToModelList(cfg)
+
+ if len(result) != 1 {
+ t.Fatalf("len(result) = %d, want 1", len(result))
+ }
+
+ if result[0].RequestTimeout != 300 {
+ t.Errorf("RequestTimeout = %d, want %d", result[0].RequestTimeout, 300)
+ }
+}
+
func TestConvertProvidersToModelList_AuthMethod(t *testing.T) {
cfg := &Config{
Providers: ProvidersConfig{
diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go
index 99eea2782..da6e506f8 100644
--- a/pkg/config/model_config_test.go
+++ b/pkg/config/model_config_test.go
@@ -64,7 +64,7 @@ func TestGetModelConfig_RoundRobin(t *testing.T) {
// Test round-robin distribution
results := make(map[string]int)
- for i := 0; i < 30; i++ {
+ for range 30 {
result, err := cfg.GetModelConfig("lb-model")
if err != nil {
t.Fatalf("GetModelConfig() error = %v", err)
@@ -94,17 +94,15 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
var wg sync.WaitGroup
errors := make(chan error, goroutines*iterations)
- for i := 0; i < goroutines; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for j := 0; j < iterations; j++ {
+ for range goroutines {
+ wg.Go(func() {
+ for range iterations {
_, err := cfg.GetModelConfig("concurrent-model")
if err != nil {
errors <- err
}
}
- }()
+ })
}
wg.Wait()
@@ -365,3 +363,38 @@ func TestConfig_ValidateModelList(t *testing.T) {
})
}
}
+
+func TestModelConfig_RequestTimeoutParsing(t *testing.T) {
+ jsonData := `{
+ "model_name": "slow-local",
+ "model": "openai/local-model",
+ "api_base": "http://localhost:11434/v1",
+ "request_timeout": 300
+ }`
+
+ var cfg ModelConfig
+ if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+
+ if cfg.RequestTimeout != 300 {
+ t.Fatalf("RequestTimeout = %d, want 300", cfg.RequestTimeout)
+ }
+}
+
+func TestModelConfig_RequestTimeoutDefaultZeroValue(t *testing.T) {
+ jsonData := `{
+ "model_name": "default-timeout",
+ "model": "openai/gpt-4o",
+ "api_key": "test-key"
+ }`
+
+ var cfg ModelConfig
+ if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+
+ if cfg.RequestTimeout != 0 {
+ t.Fatalf("RequestTimeout = %d, want 0", cfg.RequestTimeout)
+ }
+}
diff --git a/pkg/cron/service.go b/pkg/cron/service.go
index e699a44b5..6962041c1 100644
--- a/pkg/cron/service.go
+++ b/pkg/cron/service.go
@@ -7,11 +7,12 @@ import (
"fmt"
"log"
"os"
- "path/filepath"
"sync"
"time"
"github.com/adhocore/gronx"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
type CronSchedule struct {
@@ -330,17 +331,13 @@ func (cs *CronService) loadStore() error {
}
func (cs *CronService) saveStoreUnsafe() error {
- dir := filepath.Dir(cs.storePath)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return err
- }
-
data, err := json.MarshalIndent(cs.store, "", " ")
if err != nil {
return err
}
- return os.WriteFile(cs.storePath, data, 0o600)
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ return fileutil.WriteFileAtomic(cs.storePath, data, 0o600)
}
func (cs *CronService) AddJob(
diff --git a/pkg/devices/service.go b/pkg/devices/service.go
index 1541d3c57..1bafe6085 100644
--- a/pkg/devices/service.go
+++ b/pkg/devices/service.go
@@ -4,6 +4,7 @@ import (
"context"
"strings"
"sync"
+ "time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
@@ -127,7 +128,9 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
}
msg := ev.FormatMessage()
- msgBus.PublishOutbound(bus.OutboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: msg,
diff --git a/pkg/fileutil/file.go b/pkg/fileutil/file.go
new file mode 100644
index 000000000..7ca872374
--- /dev/null
+++ b/pkg/fileutil/file.go
@@ -0,0 +1,119 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+// Package fileutil provides file manipulation utilities.
+package fileutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+)
+
+// WriteFileAtomic atomically writes data to a file using a temp file + rename pattern.
+//
+// This guarantees that the target file is either:
+// - Completely written with the new data
+// - Unchanged (if any step fails before rename)
+//
+// The function:
+// 1. Creates a temp file in the same directory (original untouched)
+// 2. Writes data to temp file
+// 3. Syncs data to disk (critical for SD cards/flash storage)
+// 4. Sets file permissions
+// 5. Syncs directory metadata (ensures rename is durable)
+// 6. Atomically renames temp file to target path
+//
+// Safety guarantees:
+// - Original file is NEVER modified until successful rename
+// - Temp file is always cleaned up on error
+// - Data is flushed to physical storage before rename
+// - Directory entry is synced to prevent orphaned inodes
+//
+// Parameters:
+// - path: Target file path
+// - data: Data to write
+// - perm: File permission mode (e.g., 0o600 for secure, 0o644 for readable)
+//
+// Returns:
+// - Error if any step fails, nil on success
+//
+// Example:
+//
+// // Secure config file (owner read/write only)
+// err := utils.WriteFileAtomic("config.json", data, 0o600)
+//
+// // Public readable file
+// err := utils.WriteFileAtomic("public.txt", data, 0o644)
+func WriteFileAtomic(path string, data []byte, perm os.FileMode) error {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return fmt.Errorf("failed to create directory: %w", err)
+ }
+
+ // Create temp file in the same directory (ensures atomic rename works)
+ // Using a hidden prefix (.tmp-) to avoid issues with some tools
+ tmpFile, err := os.OpenFile(
+ filepath.Join(dir, fmt.Sprintf(".tmp-%d-%d", os.Getpid(), time.Now().UnixNano())),
+ os.O_WRONLY|os.O_CREATE|os.O_EXCL,
+ perm,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create temp file: %w", err)
+ }
+
+ tmpPath := tmpFile.Name()
+ cleanup := true
+
+ defer func() {
+ if cleanup {
+ tmpFile.Close()
+ _ = os.Remove(tmpPath)
+ }
+ }()
+
+ // Write data to temp file
+ // Note: Original file is untouched at this point
+ if _, err := tmpFile.Write(data); err != nil {
+ return fmt.Errorf("failed to write temp file: %w", err)
+ }
+
+ // CRITICAL: Force sync to storage medium before any other operations.
+ // This ensures data is physically written to disk, not just cached.
+ // Essential for SD cards, eMMC, and other flash storage on edge devices.
+ if err := tmpFile.Sync(); err != nil {
+ return fmt.Errorf("failed to sync temp file: %w", err)
+ }
+
+ // Set file permissions before closing
+ if err := tmpFile.Chmod(perm); err != nil {
+ return fmt.Errorf("failed to set permissions: %w", err)
+ }
+
+ // Close file before rename (required on Windows)
+ if err := tmpFile.Close(); err != nil {
+ return fmt.Errorf("failed to close temp file: %w", err)
+ }
+
+ // Atomic rename: temp file becomes the target
+ // On POSIX: rename() is atomic
+ // On Windows: Rename() is atomic for files
+ if err := os.Rename(tmpPath, path); err != nil {
+ return fmt.Errorf("failed to rename temp file: %w", err)
+ }
+
+ // Sync directory to ensure rename is durable
+ // This prevents the renamed file from disappearing after a crash
+ if dirFile, err := os.Open(dir); err == nil {
+ _ = dirFile.Sync()
+ dirFile.Close()
+ }
+
+ // Success: skip cleanup (file was renamed, no temp to remove)
+ cleanup = false
+ return nil
+}
diff --git a/pkg/health/server.go b/pkg/health/server.go
index 77b36034d..5609ebdf6 100644
--- a/pkg/health/server.go
+++ b/pkg/health/server.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "maps"
"net/http"
"sync"
"time"
@@ -122,9 +123,7 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
s.mu.RLock()
ready := s.ready
checks := make(map[string]Check)
- for k, v := range s.checks {
- checks[k] = v
- }
+ maps.Copy(checks, s.checks)
s.mu.RUnlock()
if !ready {
@@ -156,6 +155,13 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
})
}
+// RegisterOnMux registers /health and /ready handlers onto the given mux.
+// This allows the health endpoints to be served by a shared HTTP server.
+func (s *Server) RegisterOnMux(mux *http.ServeMux) {
+ mux.HandleFunc("/health", s.healthHandler)
+ mux.HandleFunc("/ready", s.readyHandler)
+}
+
func statusString(ok bool) string {
if ok {
return "ok"
diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go
index e05a9fdbf..09c93fc6b 100644
--- a/pkg/heartbeat/service.go
+++ b/pkg/heartbeat/service.go
@@ -7,6 +7,7 @@
package heartbeat
import (
+ "context"
"fmt"
"os"
"path/filepath"
@@ -16,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
+ "github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -275,7 +277,7 @@ This file contains tasks for the heartbeat service to check periodically.
Add your heartbeat tasks below this line:
`
- if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0o644); err != nil {
+ if err := fileutil.WriteFileAtomic(heartbeatPath, []byte(defaultContent), 0o644); err != nil {
hs.logErrorf("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfof("Created default HEARTBEAT.md template")
@@ -307,7 +309,9 @@ func (hs *HeartbeatService) sendResponse(response string) {
return
}
- msgBus.PublishOutbound(bus.OutboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go
index a7aef8c3a..3b7eeeefb 100644
--- a/pkg/heartbeat/service_test.go
+++ b/pkg/heartbeat/service_test.go
@@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
}
}
-func TestExecuteHeartbeat_Error(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- hs := NewHeartbeatService(tmpDir, 30, true)
- hs.stopChan = make(chan struct{}) // Enable for testing
-
- hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
- return &tools.ToolResult{
- ForLLM: "Heartbeat failed: connection error",
- ForUser: "",
- Silent: false,
- IsError: true,
- Async: false,
- }
- })
-
- // Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
-
- hs.executeHeartbeat()
-
- // Check log file for error message
- logFile := filepath.Join(tmpDir, "heartbeat.log")
- data, err := os.ReadFile(logFile)
- if err != nil {
- t.Fatalf("Failed to read log file: %v", err)
+func TestExecuteHeartbeat_ResultLogging(t *testing.T) {
+ tests := []struct {
+ name string
+ result *tools.ToolResult
+ wantLog string
+ }{
+ {
+ name: "error result",
+ result: &tools.ToolResult{
+ ForLLM: "Heartbeat failed: connection error",
+ ForUser: "",
+ Silent: false,
+ IsError: true,
+ Async: false,
+ },
+ wantLog: "error message",
+ },
+ {
+ name: "silent result",
+ result: &tools.ToolResult{
+ ForLLM: "Heartbeat completed successfully",
+ ForUser: "",
+ Silent: true,
+ IsError: false,
+ Async: false,
+ },
+ wantLog: "completion message",
+ },
}
- logContent := string(data)
- if logContent == "" {
- t.Error("Expected log file to contain error message")
- }
-}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
-func TestExecuteHeartbeat_Silent(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
- if err != nil {
- t.Fatalf("Failed to create temp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
+ hs := NewHeartbeatService(tmpDir, 30, true)
+ hs.stopChan = make(chan struct{}) // Enable for testing
- hs := NewHeartbeatService(tmpDir, 30, true)
- hs.stopChan = make(chan struct{}) // Enable for testing
+ hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
+ return tt.result
+ })
- hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
- return &tools.ToolResult{
- ForLLM: "Heartbeat completed successfully",
- ForUser: "",
- Silent: true,
- IsError: false,
- Async: false,
- }
- })
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
+ hs.executeHeartbeat()
- // Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
-
- hs.executeHeartbeat()
-
- // Check log file for completion message
- logFile := filepath.Join(tmpDir, "heartbeat.log")
- data, err := os.ReadFile(logFile)
- if err != nil {
- t.Fatalf("Failed to read log file: %v", err)
- }
-
- logContent := string(data)
- if logContent == "" {
- t.Error("Expected log file to contain completion message")
+ logFile := filepath.Join(tmpDir, "heartbeat.log")
+ data, err := os.ReadFile(logFile)
+ if err != nil {
+ t.Fatalf("Failed to read log file: %v", err)
+ }
+ if string(data) == "" {
+ t.Errorf("Expected log file to contain %s", tt.wantLog)
+ }
+ })
}
}
diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go
new file mode 100644
index 000000000..6bc09c210
--- /dev/null
+++ b/pkg/identity/identity.go
@@ -0,0 +1,107 @@
+// Package identity provides unified user identity utilities for PicoClaw.
+// It introduces a canonical "platform:id" format and matching logic
+// that is backward-compatible with all legacy allow-list formats.
+package identity
+
+import (
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+)
+
+// BuildCanonicalID constructs a canonical "platform:id" identifier.
+// Both platform and platformID are lowercased and trimmed.
+func BuildCanonicalID(platform, platformID string) string {
+ p := strings.ToLower(strings.TrimSpace(platform))
+ id := strings.TrimSpace(platformID)
+ if p == "" || id == "" {
+ return ""
+ }
+ return p + ":" + id
+}
+
+// ParseCanonicalID splits a canonical ID ("platform:id") into its parts.
+// Returns ok=false if the input does not contain a colon separator.
+func ParseCanonicalID(canonical string) (platform, id string, ok bool) {
+ canonical = strings.TrimSpace(canonical)
+ idx := strings.Index(canonical, ":")
+ if idx <= 0 || idx == len(canonical)-1 {
+ return "", "", false
+ }
+ return canonical[:idx], canonical[idx+1:], true
+}
+
+// MatchAllowed checks whether the given sender matches a single allow-list entry.
+// It is backward-compatible with all legacy formats:
+//
+// - "123456" → matches sender.PlatformID
+// - "@alice" → matches sender.Username
+// - "123456|alice" → matches PlatformID or Username
+// - "telegram:123456" → exact match on sender.CanonicalID
+func MatchAllowed(sender bus.SenderInfo, allowed string) bool {
+ allowed = strings.TrimSpace(allowed)
+ if allowed == "" {
+ return false
+ }
+
+ // Try canonical match first: "platform:id" format
+ if platform, id, ok := ParseCanonicalID(allowed); ok {
+ // Only treat as canonical if the platform portion looks like a known platform name
+ // (not a pure-numeric string, which could be a compound ID)
+ if !isNumeric(platform) {
+ candidate := BuildCanonicalID(platform, id)
+ if candidate != "" && sender.CanonicalID != "" {
+ return strings.EqualFold(sender.CanonicalID, candidate)
+ }
+ // If sender has no canonical ID, try matching platform + platformID
+ return strings.EqualFold(platform, sender.Platform) &&
+ sender.PlatformID == id
+ }
+ }
+
+ // Strip leading "@" for username matching
+ trimmed := strings.TrimPrefix(allowed, "@")
+
+ // Split compound "id|username" format
+ allowedID := trimmed
+ allowedUser := ""
+ if idx := strings.Index(trimmed, "|"); idx > 0 {
+ allowedID = trimmed[:idx]
+ allowedUser = trimmed[idx+1:]
+ }
+
+ // Match against PlatformID
+ if sender.PlatformID != "" && sender.PlatformID == allowedID {
+ return true
+ }
+
+ // Match against Username
+ if sender.Username != "" {
+ if sender.Username == trimmed || sender.Username == allowedUser {
+ return true
+ }
+ }
+
+ // Match compound sender format against allowed parts
+ if allowedUser != "" && sender.PlatformID != "" && sender.PlatformID == allowedID {
+ return true
+ }
+ if allowedUser != "" && sender.Username != "" && sender.Username == allowedUser {
+ return true
+ }
+
+ return false
+}
+
+// isNumeric returns true if s consists entirely of digits.
+func isNumeric(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, r := range s {
+ if r < '0' || r > '9' {
+ return false
+ }
+ }
+ return true
+}
diff --git a/pkg/identity/identity_test.go b/pkg/identity/identity_test.go
new file mode 100644
index 000000000..3d24bd794
--- /dev/null
+++ b/pkg/identity/identity_test.go
@@ -0,0 +1,229 @@
+package identity
+
+import (
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+)
+
+func TestBuildCanonicalID(t *testing.T) {
+ tests := []struct {
+ platform string
+ platformID string
+ want string
+ }{
+ {"telegram", "123456", "telegram:123456"},
+ {"Discord", "98765432", "discord:98765432"},
+ {"SLACK", "U123ABC", "slack:U123ABC"},
+ {"", "123", ""},
+ {"telegram", "", ""},
+ {" telegram ", " 123 ", "telegram:123"},
+ }
+
+ for _, tt := range tests {
+ got := BuildCanonicalID(tt.platform, tt.platformID)
+ if got != tt.want {
+ t.Errorf("BuildCanonicalID(%q, %q) = %q, want %q",
+ tt.platform, tt.platformID, got, tt.want)
+ }
+ }
+}
+
+func TestParseCanonicalID(t *testing.T) {
+ tests := []struct {
+ input string
+ wantPlatform string
+ wantID string
+ wantOk bool
+ }{
+ {"telegram:123456", "telegram", "123456", true},
+ {"discord:98765432", "discord", "98765432", true},
+ {"slack:U123ABC", "slack", "U123ABC", true},
+ {"nocolon", "", "", false},
+ {"", "", "", false},
+ {":missing", "", "", false},
+ {"missing:", "", "", false},
+ }
+
+ for _, tt := range tests {
+ platform, id, ok := ParseCanonicalID(tt.input)
+ if ok != tt.wantOk || platform != tt.wantPlatform || id != tt.wantID {
+ t.Errorf("ParseCanonicalID(%q) = (%q, %q, %v), want (%q, %q, %v)",
+ tt.input, platform, id, ok,
+ tt.wantPlatform, tt.wantID, tt.wantOk)
+ }
+ }
+}
+
+func TestMatchAllowed(t *testing.T) {
+ telegramSender := bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "123456",
+ CanonicalID: "telegram:123456",
+ Username: "alice",
+ DisplayName: "Alice Smith",
+ }
+
+ discordSender := bus.SenderInfo{
+ Platform: "discord",
+ PlatformID: "98765432",
+ CanonicalID: "discord:98765432",
+ Username: "bob",
+ DisplayName: "bob#1234",
+ }
+
+ noCanonicalSender := bus.SenderInfo{
+ Platform: "telegram",
+ PlatformID: "999",
+ Username: "carol",
+ }
+
+ tests := []struct {
+ name string
+ sender bus.SenderInfo
+ allowed string
+ want bool
+ }{
+ // Pure numeric ID matching
+ {
+ name: "numeric ID matches PlatformID",
+ sender: telegramSender,
+ allowed: "123456",
+ want: true,
+ },
+ {
+ name: "numeric ID does not match",
+ sender: telegramSender,
+ allowed: "654321",
+ want: false,
+ },
+ // Username matching
+ {
+ name: "@username matches Username",
+ sender: telegramSender,
+ allowed: "@alice",
+ want: true,
+ },
+ {
+ name: "@username does not match",
+ sender: telegramSender,
+ allowed: "@bob",
+ want: false,
+ },
+ // Compound format "id|username"
+ {
+ name: "compound matches by ID",
+ sender: telegramSender,
+ allowed: "123456|alice",
+ want: true,
+ },
+ {
+ name: "compound matches by username",
+ sender: telegramSender,
+ allowed: "999|alice",
+ want: true,
+ },
+ {
+ name: "compound does not match",
+ sender: telegramSender,
+ allowed: "654321|bob",
+ want: false,
+ },
+ // Canonical format "platform:id"
+ {
+ name: "canonical matches exactly",
+ sender: telegramSender,
+ allowed: "telegram:123456",
+ want: true,
+ },
+ {
+ name: "canonical case-insensitive platform",
+ sender: telegramSender,
+ allowed: "Telegram:123456",
+ want: true,
+ },
+ {
+ name: "canonical wrong platform",
+ sender: telegramSender,
+ allowed: "discord:123456",
+ want: false,
+ },
+ {
+ name: "canonical wrong ID",
+ sender: telegramSender,
+ allowed: "telegram:654321",
+ want: false,
+ },
+ // Cross-platform canonical
+ {
+ name: "discord canonical match",
+ sender: discordSender,
+ allowed: "discord:98765432",
+ want: true,
+ },
+ {
+ name: "telegram canonical does not match discord sender",
+ sender: discordSender,
+ allowed: "telegram:98765432",
+ want: false,
+ },
+ // Sender without canonical ID
+ {
+ name: "canonical match falls back to platform+platformID",
+ sender: noCanonicalSender,
+ allowed: "telegram:999",
+ want: true,
+ },
+ {
+ name: "platform mismatch on fallback",
+ sender: noCanonicalSender,
+ allowed: "discord:999",
+ want: false,
+ },
+ // Empty allowed string
+ {
+ name: "empty allowed never matches",
+ sender: telegramSender,
+ allowed: "",
+ want: false,
+ },
+ // Whitespace handling
+ {
+ name: "trimmed allowed matches",
+ sender: telegramSender,
+ allowed: " 123456 ",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := MatchAllowed(tt.sender, tt.allowed)
+ if got != tt.want {
+ t.Errorf("MatchAllowed(%+v, %q) = %v, want %v",
+ tt.sender, tt.allowed, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsNumeric(t *testing.T) {
+ tests := []struct {
+ input string
+ want bool
+ }{
+ {"123456", true},
+ {"0", true},
+ {"", false},
+ {"abc", false},
+ {"12a34", false},
+ {"telegram", false},
+ }
+
+ for _, tt := range tests {
+ got := isNumeric(tt.input)
+ if got != tt.want {
+ t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go
new file mode 100644
index 000000000..7b63cc979
--- /dev/null
+++ b/pkg/mcp/manager.go
@@ -0,0 +1,532 @@
+package mcp
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// headerTransport is an http.RoundTripper that adds custom headers to requests
+type headerTransport struct {
+ base http.RoundTripper
+ headers map[string]string
+}
+
+func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Clone the request to avoid modifying the original
+ req = req.Clone(req.Context())
+
+ // Add custom headers
+ for key, value := range t.headers {
+ req.Header.Set(key, value)
+ }
+
+ // Use the base transport
+ base := t.base
+ if base == nil {
+ base = http.DefaultTransport
+ }
+ return base.RoundTrip(req)
+}
+
+// loadEnvFile loads environment variables from a file in .env format
+// Each line should be in the format: KEY=value
+// Lines starting with # are comments
+// Empty lines are ignored
+func loadEnvFile(path string) (map[string]string, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open env file: %w", err)
+ }
+ defer file.Close()
+
+ envVars := make(map[string]string)
+ scanner := bufio.NewScanner(file)
+ lineNum := 0
+
+ for scanner.Scan() {
+ lineNum++
+ line := strings.TrimSpace(scanner.Text())
+
+ // Skip empty lines and comments
+ if line == "" || strings.HasPrefix(line, "#") {
+ continue
+ }
+
+ // Parse KEY=value
+ parts := strings.SplitN(line, "=", 2)
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
+ }
+
+ key := strings.TrimSpace(parts[0])
+ value := strings.TrimSpace(parts[1])
+
+ if key == "" {
+ return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
+ }
+
+ // Remove surrounding quotes if present
+ if len(value) >= 2 {
+ if (value[0] == '"' && value[len(value)-1] == '"') ||
+ (value[0] == '\'' && value[len(value)-1] == '\'') {
+ value = value[1 : len(value)-1]
+ }
+ }
+
+ envVars[key] = value
+ }
+
+ if err := scanner.Err(); err != nil {
+ return nil, fmt.Errorf("error reading env file: %w", err)
+ }
+
+ return envVars, nil
+}
+
+// ServerConnection represents a connection to an MCP server
+type ServerConnection struct {
+ Name string
+ Client *mcp.Client
+ Session *mcp.ClientSession
+ Tools []*mcp.Tool
+}
+
+// Manager manages multiple MCP server connections
+type Manager struct {
+ servers map[string]*ServerConnection
+ mu sync.RWMutex
+ closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
+ wg sync.WaitGroup // tracks in-flight CallTool calls
+}
+
+// NewManager creates a new MCP manager
+func NewManager() *Manager {
+ return &Manager{
+ servers: make(map[string]*ServerConnection),
+ }
+}
+
+// LoadFromConfig loads MCP servers from configuration
+func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
+ return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
+}
+
+// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
+// This is the minimal dependency version that doesn't require the full Config object.
+func (m *Manager) LoadFromMCPConfig(
+ ctx context.Context,
+ mcpCfg config.MCPConfig,
+ workspacePath string,
+) error {
+ if !mcpCfg.Enabled {
+ logger.InfoCF("mcp", "MCP integration is disabled", nil)
+ return nil
+ }
+
+ if len(mcpCfg.Servers) == 0 {
+ logger.InfoCF("mcp", "No MCP servers configured", nil)
+ return nil
+ }
+
+ logger.InfoCF("mcp", "Initializing MCP servers",
+ map[string]any{
+ "count": len(mcpCfg.Servers),
+ })
+
+ var wg sync.WaitGroup
+ errs := make(chan error, len(mcpCfg.Servers))
+ enabledCount := 0
+
+ for name, serverCfg := range mcpCfg.Servers {
+ if !serverCfg.Enabled {
+ logger.DebugCF("mcp", "Skipping disabled server",
+ map[string]any{
+ "server": name,
+ })
+ continue
+ }
+
+ enabledCount++
+ wg.Add(1)
+ go func(name string, serverCfg config.MCPServerConfig, workspace string) {
+ defer wg.Done()
+
+ // Resolve relative envFile paths relative to workspace
+ if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
+ if workspace == "" {
+ err := fmt.Errorf(
+ "workspace path is empty while resolving relative envFile %q for server %s",
+ serverCfg.EnvFile,
+ name,
+ )
+ logger.ErrorCF("mcp", "Invalid MCP server configuration",
+ map[string]any{
+ "server": name,
+ "env_file": serverCfg.EnvFile,
+ "error": err.Error(),
+ })
+ errs <- err
+ return
+ }
+ serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
+ }
+
+ if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
+ logger.ErrorCF("mcp", "Failed to connect to MCP server",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
+ }
+ }(name, serverCfg, workspacePath)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ // Collect errors
+ var allErrors []error
+ for err := range errs {
+ allErrors = append(allErrors, err)
+ }
+
+ connectedCount := len(m.GetServers())
+
+ // If all enabled servers failed to connect, return aggregated error
+ if enabledCount > 0 && connectedCount == 0 {
+ logger.ErrorCF("mcp", "All MCP servers failed to connect",
+ map[string]any{
+ "failed": len(allErrors),
+ "total": enabledCount,
+ })
+ return errors.Join(allErrors...)
+ }
+
+ if len(allErrors) > 0 {
+ logger.WarnCF("mcp", "Some MCP servers failed to connect",
+ map[string]any{
+ "failed": len(allErrors),
+ "connected": connectedCount,
+ "total": enabledCount,
+ })
+ // Don't fail completely if some servers successfully connected
+ }
+
+ logger.InfoCF("mcp", "MCP server initialization complete",
+ map[string]any{
+ "connected": connectedCount,
+ "total": enabledCount,
+ })
+
+ return nil
+}
+
+// ConnectServer connects to a single MCP server
+func (m *Manager) ConnectServer(
+ ctx context.Context,
+ name string,
+ cfg config.MCPServerConfig,
+) error {
+ logger.InfoCF("mcp", "Connecting to MCP server",
+ map[string]any{
+ "server": name,
+ "command": cfg.Command,
+ "args_count": len(cfg.Args),
+ })
+
+ // Create client
+ client := mcp.NewClient(&mcp.Implementation{
+ Name: "picoclaw",
+ Version: "1.0.0",
+ }, nil)
+
+ // Create transport based on configuration
+ // Auto-detect transport type if not explicitly specified
+ var transport mcp.Transport
+ transportType := cfg.Type
+
+ // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
+ if transportType == "" {
+ if cfg.URL != "" {
+ transportType = "sse"
+ } else if cfg.Command != "" {
+ transportType = "stdio"
+ } else {
+ return fmt.Errorf("either URL or command must be provided")
+ }
+ }
+
+ switch transportType {
+ case "sse", "http":
+ if cfg.URL == "" {
+ return fmt.Errorf("URL is required for SSE/HTTP transport")
+ }
+ logger.DebugCF("mcp", "Using SSE/HTTP transport",
+ map[string]any{
+ "server": name,
+ "url": cfg.URL,
+ })
+
+ sseTransport := &mcp.StreamableClientTransport{
+ Endpoint: cfg.URL,
+ }
+
+ // Add custom headers if provided
+ if len(cfg.Headers) > 0 {
+ // Create a custom HTTP client with header-injecting transport
+ sseTransport.HTTPClient = &http.Client{
+ Transport: &headerTransport{
+ base: http.DefaultTransport,
+ headers: cfg.Headers,
+ },
+ }
+ logger.DebugCF("mcp", "Added custom HTTP headers",
+ map[string]any{
+ "server": name,
+ "header_count": len(cfg.Headers),
+ })
+ }
+
+ transport = sseTransport
+ case "stdio":
+ if cfg.Command == "" {
+ return fmt.Errorf("command is required for stdio transport")
+ }
+ logger.DebugCF("mcp", "Using stdio transport",
+ map[string]any{
+ "server": name,
+ "command": cfg.Command,
+ })
+ // Create command with context
+ cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
+
+ // Build environment variables with proper override semantics
+ // Use a map to ensure config variables override file variables
+ envMap := make(map[string]string)
+
+ // Start with parent process environment
+ for _, e := range cmd.Environ() {
+ if idx := strings.Index(e, "="); idx > 0 {
+ envMap[e[:idx]] = e[idx+1:]
+ }
+ }
+
+ // Load environment variables from file if specified
+ if cfg.EnvFile != "" {
+ envVars, err := loadEnvFile(cfg.EnvFile)
+ if err != nil {
+ return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
+ }
+ for k, v := range envVars {
+ envMap[k] = v
+ }
+ logger.DebugCF("mcp", "Loaded environment variables from file",
+ map[string]any{
+ "server": name,
+ "envFile": cfg.EnvFile,
+ "var_count": len(envVars),
+ })
+ }
+
+ // Environment variables from config override those from file
+ for k, v := range cfg.Env {
+ envMap[k] = v
+ }
+
+ // Convert map to slice
+ env := make([]string, 0, len(envMap))
+ for k, v := range envMap {
+ env = append(env, fmt.Sprintf("%s=%s", k, v))
+ }
+ cmd.Env = env
+
+ transport = &mcp.CommandTransport{Command: cmd}
+ default:
+ return fmt.Errorf(
+ "unsupported transport type: %s (supported: stdio, sse, http)",
+ transportType,
+ )
+ }
+
+ // Connect to server
+ session, err := client.Connect(ctx, transport, nil)
+ if err != nil {
+ return fmt.Errorf("failed to connect: %w", err)
+ }
+
+ // Get server info
+ initResult := session.InitializeResult()
+ logger.InfoCF("mcp", "Connected to MCP server",
+ map[string]any{
+ "server": name,
+ "serverName": initResult.ServerInfo.Name,
+ "serverVersion": initResult.ServerInfo.Version,
+ "protocol": initResult.ProtocolVersion,
+ })
+
+ // List available tools if supported
+ var tools []*mcp.Tool
+ if initResult.Capabilities.Tools != nil {
+ for tool, err := range session.Tools(ctx, nil) {
+ if err != nil {
+ logger.WarnCF("mcp", "Error listing tool",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ continue
+ }
+ tools = append(tools, tool)
+ }
+
+ logger.InfoCF("mcp", "Listed tools from MCP server",
+ map[string]any{
+ "server": name,
+ "toolCount": len(tools),
+ })
+ }
+
+ // Store connection
+ m.mu.Lock()
+ m.servers[name] = &ServerConnection{
+ Name: name,
+ Client: client,
+ Session: session,
+ Tools: tools,
+ }
+ m.mu.Unlock()
+
+ return nil
+}
+
+// GetServers returns all connected servers
+func (m *Manager) GetServers() map[string]*ServerConnection {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ result := make(map[string]*ServerConnection, len(m.servers))
+ for k, v := range m.servers {
+ result[k] = v
+ }
+ return result
+}
+
+// GetServer returns a specific server connection
+func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ conn, ok := m.servers[name]
+ return conn, ok
+}
+
+// CallTool calls a tool on a specific server
+func (m *Manager) CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+) (*mcp.CallToolResult, error) {
+ // Check if closed before acquiring lock (fast path)
+ if m.closed.Load() {
+ return nil, fmt.Errorf("manager is closed")
+ }
+
+ m.mu.RLock()
+ // Double-check after acquiring lock to prevent TOCTOU race
+ if m.closed.Load() {
+ m.mu.RUnlock()
+ return nil, fmt.Errorf("manager is closed")
+ }
+ conn, ok := m.servers[serverName]
+ if ok {
+ m.wg.Add(1) // Add to WaitGroup while holding the lock
+ }
+ m.mu.RUnlock()
+
+ if !ok {
+ return nil, fmt.Errorf("server %s not found", serverName)
+ }
+ defer m.wg.Done()
+
+ params := &mcp.CallToolParams{
+ Name: toolName,
+ Arguments: arguments,
+ }
+
+ result, err := conn.Session.CallTool(ctx, params)
+ if err != nil {
+ return nil, fmt.Errorf("failed to call tool: %w", err)
+ }
+
+ return result, nil
+}
+
+// Close closes all server connections
+func (m *Manager) Close() error {
+ // Use Swap to atomically set closed=true and get the previous value
+ // This prevents TOCTOU race with CallTool's closed check
+ if m.closed.Swap(true) {
+ return nil // already closed
+ }
+
+ // Wait for all in-flight CallTool calls to finish before closing sessions
+ // After closed=true is set, no new CallTool can start (they check closed first)
+ m.wg.Wait()
+
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ logger.InfoCF("mcp", "Closing all MCP server connections",
+ map[string]any{
+ "count": len(m.servers),
+ })
+
+ var errs []error
+ for name, conn := range m.servers {
+ if err := conn.Session.Close(); err != nil {
+ logger.ErrorCF("mcp", "Failed to close server connection",
+ map[string]any{
+ "server": name,
+ "error": err.Error(),
+ })
+ errs = append(errs, fmt.Errorf("server %s: %w", name, err))
+ }
+ }
+
+ m.servers = make(map[string]*ServerConnection)
+
+ if len(errs) > 0 {
+ return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
+ }
+
+ return nil
+}
+
+// GetAllTools returns all tools from all connected servers
+func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ result := make(map[string][]*mcp.Tool)
+ for name, conn := range m.servers {
+ if len(conn.Tools) > 0 {
+ result[name] = conn.Tools
+ }
+ }
+ return result
+}
diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go
new file mode 100644
index 000000000..8ce81d09e
--- /dev/null
+++ b/pkg/mcp/manager_test.go
@@ -0,0 +1,298 @@
+package mcp
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestLoadEnvFile(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ expected map[string]string
+ expectErr bool
+ }{
+ {
+ name: "basic env file",
+ content: `API_KEY=secret123
+DATABASE_URL=postgres://localhost/db
+PORT=8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with comments and empty lines",
+ content: `# This is a comment
+API_KEY=secret123
+
+# Another comment
+DATABASE_URL=postgres://localhost/db
+
+PORT=8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with quoted values",
+ content: `API_KEY="secret with spaces"
+NAME='single quoted'
+PLAIN=no-quotes`,
+ expected: map[string]string{
+ "API_KEY": "secret with spaces",
+ "NAME": "single quoted",
+ "PLAIN": "no-quotes",
+ },
+ expectErr: false,
+ },
+ {
+ name: "with spaces around equals",
+ content: `API_KEY = secret123
+DATABASE_URL= postgres://localhost/db
+PORT =8080`,
+ expected: map[string]string{
+ "API_KEY": "secret123",
+ "DATABASE_URL": "postgres://localhost/db",
+ "PORT": "8080",
+ },
+ expectErr: false,
+ },
+ {
+ name: "invalid format - no equals",
+ content: `INVALID_LINE`,
+ expectErr: true,
+ },
+ {
+ name: "empty file",
+ content: ``,
+ expected: map[string]string{},
+ expectErr: false,
+ },
+ {
+ name: "only comments",
+ content: `# Comment 1
+# Comment 2`,
+ expected: map[string]string{},
+ expectErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ envFile := filepath.Join(tmpDir, ".env")
+
+ if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ result, err := loadEnvFile(envFile)
+
+ if tt.expectErr {
+ if err == nil {
+ t.Errorf("Expected error but got none")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ return
+ }
+
+ if len(result) != len(tt.expected) {
+ t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
+ }
+
+ for key, expectedValue := range tt.expected {
+ if actualValue, ok := result[key]; !ok {
+ t.Errorf("Expected key %s not found", key)
+ } else if actualValue != expectedValue {
+ t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
+ }
+ }
+ })
+ }
+}
+
+func TestLoadEnvFileNotFound(t *testing.T) {
+ _, err := loadEnvFile("/nonexistent/file.env")
+ if err == nil {
+ t.Error("Expected error for nonexistent file")
+ }
+}
+
+func TestEnvFilePriority(t *testing.T) {
+ // Create a temporary .env file
+ tmpDir := t.TempDir()
+ envFile := filepath.Join(tmpDir, ".env")
+
+ envContent := `API_KEY=from_file
+DATABASE_URL=from_file
+SHARED_VAR=from_file`
+
+ if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
+ t.Fatalf("Failed to create .env file: %v", err)
+ }
+
+ // Load envFile
+ envVars, err := loadEnvFile(envFile)
+ if err != nil {
+ t.Fatalf("Failed to load env file: %v", err)
+ }
+
+ // Verify envFile variables
+ if envVars["API_KEY"] != "from_file" {
+ t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
+ }
+
+ // Simulate config.Env overriding envFile
+ configEnv := map[string]string{
+ "SHARED_VAR": "from_config",
+ "NEW_VAR": "from_config",
+ }
+
+ // Merge: envFile first, then config overrides
+ merged := make(map[string]string)
+ for k, v := range envVars {
+ merged[k] = v
+ }
+ for k, v := range configEnv {
+ merged[k] = v
+ }
+
+ // Verify priority: config.Env should override envFile
+ if merged["SHARED_VAR"] != "from_config" {
+ t.Errorf(
+ "Expected SHARED_VAR=from_config (config should override file), got %s",
+ merged["SHARED_VAR"],
+ )
+ }
+ if merged["API_KEY"] != "from_file" {
+ t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
+ }
+ if merged["NEW_VAR"] != "from_config" {
+ t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
+ }
+}
+
+func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
+ mgr := NewManager()
+
+ mcpCfg := config.MCPConfig{
+ Enabled: true,
+ Servers: map[string]config.MCPServerConfig{
+ "test-server": {
+ Enabled: true,
+ Command: "echo",
+ Args: []string{"ok"},
+ EnvFile: ".env",
+ },
+ },
+ }
+
+ err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
+ if err == nil {
+ t.Fatal("expected error for relative env_file with empty workspace path, got nil")
+ }
+
+ if !strings.Contains(err.Error(), "workspace path is empty") {
+ t.Fatalf("expected workspace path validation error, got: %v", err)
+ }
+}
+
+func TestNewManager_InitialState(t *testing.T) {
+ mgr := NewManager()
+ if mgr == nil {
+ t.Fatal("expected manager instance, got nil")
+ }
+ if len(mgr.GetServers()) != 0 {
+ t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
+ }
+}
+
+func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
+ mgr := NewManager()
+
+ err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
+ if err != nil {
+ t.Fatalf("expected nil error when MCP disabled, got: %v", err)
+ }
+
+ err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
+ if err != nil {
+ t.Fatalf("expected nil error when no servers configured, got: %v", err)
+ }
+}
+
+func TestGetServers_ReturnsCopy(t *testing.T) {
+ mgr := NewManager()
+ mgr.servers["s1"] = &ServerConnection{Name: "s1"}
+
+ servers := mgr.GetServers()
+ delete(servers, "s1")
+
+ if _, ok := mgr.GetServer("s1"); !ok {
+ t.Fatal("expected internal manager state to remain unchanged")
+ }
+}
+
+func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
+ mgr := NewManager()
+ mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
+ mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
+
+ all := mgr.GetAllTools()
+ if _, ok := all["empty"]; ok {
+ t.Fatal("expected server without tools to be excluded")
+ }
+ if _, ok := all["with-tools"]; !ok {
+ t.Fatal("expected server with tools to be included")
+ }
+}
+
+func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
+ t.Run("manager closed", func(t *testing.T) {
+ mgr := NewManager()
+ mgr.closed.Store(true)
+
+ _, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
+ if err == nil || !strings.Contains(err.Error(), "manager is closed") {
+ t.Fatalf("expected manager closed error, got: %v", err)
+ }
+ })
+
+ t.Run("server missing", func(t *testing.T) {
+ mgr := NewManager()
+
+ _, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
+ if err == nil || !strings.Contains(err.Error(), "not found") {
+ t.Fatalf("expected server not found error, got: %v", err)
+ }
+ })
+}
+
+func TestClose_IdempotentOnEmptyManager(t *testing.T) {
+ mgr := NewManager()
+
+ if err := mgr.Close(); err != nil {
+ t.Fatalf("first close should succeed, got: %v", err)
+ }
+ if err := mgr.Close(); err != nil {
+ t.Fatalf("second close should be idempotent, got: %v", err)
+ }
+}
diff --git a/pkg/media/store.go b/pkg/media/store.go
new file mode 100644
index 000000000..30220986c
--- /dev/null
+++ b/pkg/media/store.go
@@ -0,0 +1,271 @@
+package media
+
+import (
+ "fmt"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// MediaMeta holds metadata about a stored media file.
+type MediaMeta struct {
+ Filename string
+ ContentType string
+ Source string // "telegram", "discord", "tool:image-gen", etc.
+}
+
+// MediaStore manages the lifecycle of media files associated with processing scopes.
+type MediaStore interface {
+ // Store registers an existing local file under the given scope.
+ // Returns a ref identifier (e.g. "media://").
+ // Store does not move or copy the file; it only records the mapping.
+ Store(localPath string, meta MediaMeta, scope string) (ref string, err error)
+
+ // Resolve returns the local file path for a given ref.
+ Resolve(ref string) (localPath string, err error)
+
+ // ResolveWithMeta returns the local file path and metadata for a given ref.
+ ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error)
+
+ // ReleaseAll deletes all files registered under the given scope
+ // and removes the mapping entries. File-not-exist errors are ignored.
+ ReleaseAll(scope string) error
+}
+
+// mediaEntry holds the path and metadata for a stored media file.
+type mediaEntry struct {
+ path string
+ meta MediaMeta
+ storedAt time.Time
+}
+
+// MediaCleanerConfig configures the background TTL cleanup.
+type MediaCleanerConfig struct {
+ Enabled bool
+ MaxAge time.Duration
+ Interval time.Duration
+}
+
+// FileMediaStore is a pure in-memory implementation of MediaStore.
+// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/).
+type FileMediaStore struct {
+ mu sync.RWMutex
+ refs map[string]mediaEntry
+ scopeToRefs map[string]map[string]struct{}
+ refToScope map[string]string
+
+ cleanerCfg MediaCleanerConfig
+ stop chan struct{}
+ startOnce sync.Once
+ stopOnce sync.Once
+ nowFunc func() time.Time // for testing
+}
+
+// NewFileMediaStore creates a new FileMediaStore without background cleanup.
+func NewFileMediaStore() *FileMediaStore {
+ return &FileMediaStore{
+ refs: make(map[string]mediaEntry),
+ scopeToRefs: make(map[string]map[string]struct{}),
+ refToScope: make(map[string]string),
+ nowFunc: time.Now,
+ }
+}
+
+// NewFileMediaStoreWithCleanup creates a FileMediaStore with TTL-based background cleanup.
+func NewFileMediaStoreWithCleanup(cfg MediaCleanerConfig) *FileMediaStore {
+ return &FileMediaStore{
+ refs: make(map[string]mediaEntry),
+ scopeToRefs: make(map[string]map[string]struct{}),
+ refToScope: make(map[string]string),
+ cleanerCfg: cfg,
+ stop: make(chan struct{}),
+ nowFunc: time.Now,
+ }
+}
+
+// Store registers a local file under the given scope. The file must exist.
+func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) {
+ if _, err := os.Stat(localPath); err != nil {
+ return "", fmt.Errorf("media store: %s: %w", localPath, err)
+ }
+
+ ref := "media://" + uuid.New().String()
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.refs[ref] = mediaEntry{path: localPath, meta: meta, storedAt: s.nowFunc()}
+ if s.scopeToRefs[scope] == nil {
+ s.scopeToRefs[scope] = make(map[string]struct{})
+ }
+ s.scopeToRefs[scope][ref] = struct{}{}
+ s.refToScope[ref] = scope
+
+ return ref, nil
+}
+
+// Resolve returns the local path for the given ref.
+func (s *FileMediaStore) Resolve(ref string) (string, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ entry, ok := s.refs[ref]
+ if !ok {
+ return "", fmt.Errorf("media store: unknown ref: %s", ref)
+ }
+ return entry.path, nil
+}
+
+// ResolveWithMeta returns the local path and metadata for the given ref.
+func (s *FileMediaStore) ResolveWithMeta(ref string) (string, MediaMeta, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ entry, ok := s.refs[ref]
+ if !ok {
+ return "", MediaMeta{}, fmt.Errorf("media store: unknown ref: %s", ref)
+ }
+ return entry.path, entry.meta, nil
+}
+
+// ReleaseAll removes all files under the given scope and cleans up mappings.
+// Phase 1 (under lock): remove entries from maps.
+// Phase 2 (no lock): delete files from disk.
+func (s *FileMediaStore) ReleaseAll(scope string) error {
+ // Phase 1: collect paths and remove from maps under lock
+ var paths []string
+
+ s.mu.Lock()
+ refs, ok := s.scopeToRefs[scope]
+ if !ok {
+ s.mu.Unlock()
+ return nil
+ }
+
+ for ref := range refs {
+ if entry, exists := s.refs[ref]; exists {
+ paths = append(paths, entry.path)
+ }
+ delete(s.refs, ref)
+ delete(s.refToScope, ref)
+ }
+ delete(s.scopeToRefs, scope)
+ s.mu.Unlock()
+
+ // Phase 2: delete files without holding the lock
+ for _, p := range paths {
+ if err := os.Remove(p); err != nil && !os.IsNotExist(err) {
+ logger.WarnCF("media", "release: failed to remove file", map[string]any{
+ "path": p,
+ "error": err.Error(),
+ })
+ }
+ }
+
+ return nil
+}
+
+// CleanExpired removes all entries older than MaxAge.
+// Phase 1 (under lock): identify expired entries and remove from maps.
+// Phase 2 (no lock): delete files from disk to minimize lock contention.
+func (s *FileMediaStore) CleanExpired() int {
+ if s.cleanerCfg.MaxAge <= 0 {
+ return 0
+ }
+
+ // Phase 1: collect expired entries under lock
+ type expiredEntry struct {
+ ref string
+ path string
+ }
+
+ s.mu.Lock()
+ cutoff := s.nowFunc().Add(-s.cleanerCfg.MaxAge)
+ var expired []expiredEntry
+
+ for ref, entry := range s.refs {
+ if entry.storedAt.Before(cutoff) {
+ expired = append(expired, expiredEntry{ref: ref, path: entry.path})
+
+ if scope, ok := s.refToScope[ref]; ok {
+ if scopeRefs, ok := s.scopeToRefs[scope]; ok {
+ delete(scopeRefs, ref)
+ if len(scopeRefs) == 0 {
+ delete(s.scopeToRefs, scope)
+ }
+ }
+ }
+
+ delete(s.refs, ref)
+ delete(s.refToScope, ref)
+ }
+ }
+ s.mu.Unlock()
+
+ // Phase 2: delete files without holding the lock
+ for _, e := range expired {
+ if err := os.Remove(e.path); err != nil && !os.IsNotExist(err) {
+ logger.WarnCF("media", "cleanup: failed to remove file", map[string]any{
+ "path": e.path,
+ "error": err.Error(),
+ })
+ }
+ }
+
+ return len(expired)
+}
+
+// Start begins the background cleanup goroutine if cleanup is enabled.
+// Safe to call multiple times; only the first call starts the goroutine.
+func (s *FileMediaStore) Start() {
+ if !s.cleanerCfg.Enabled || s.stop == nil {
+ return
+ }
+ if s.cleanerCfg.Interval <= 0 || s.cleanerCfg.MaxAge <= 0 {
+ logger.WarnCF("media", "cleanup: skipped due to invalid config", map[string]any{
+ "interval": s.cleanerCfg.Interval.String(),
+ "max_age": s.cleanerCfg.MaxAge.String(),
+ })
+ return
+ }
+
+ s.startOnce.Do(func() {
+ logger.InfoCF("media", "cleanup enabled", map[string]any{
+ "interval": s.cleanerCfg.Interval.String(),
+ "max_age": s.cleanerCfg.MaxAge.String(),
+ })
+
+ go func() {
+ ticker := time.NewTicker(s.cleanerCfg.Interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ if n := s.CleanExpired(); n > 0 {
+ logger.InfoCF("media", "cleanup: removed expired entries", map[string]any{
+ "count": n,
+ })
+ }
+ case <-s.stop:
+ return
+ }
+ }
+ }()
+ })
+}
+
+// Stop terminates the background cleanup goroutine.
+// Safe to call multiple times; only the first call closes the channel.
+func (s *FileMediaStore) Stop() {
+ if s.stop == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ close(s.stop)
+ })
+}
diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go
new file mode 100644
index 000000000..1dcfdf350
--- /dev/null
+++ b/pkg/media/store_test.go
@@ -0,0 +1,530 @@
+package media
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+func createTempFile(t *testing.T, dir, name string) string {
+ t.Helper()
+ path := filepath.Join(dir, name)
+ if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil {
+ t.Fatalf("failed to create temp file: %v", err)
+ }
+ return path
+}
+
+func TestStoreAndResolve(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ path := createTempFile(t, dir, "photo.jpg")
+
+ ref, err := store.Store(path, MediaMeta{Filename: "photo.jpg", Source: "telegram"}, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ if !strings.HasPrefix(ref, "media://") {
+ t.Errorf("ref should start with media://, got %q", ref)
+ }
+
+ resolved, err := store.Resolve(ref)
+ if err != nil {
+ t.Fatalf("Resolve failed: %v", err)
+ }
+ if resolved != path {
+ t.Errorf("Resolve returned %q, want %q", resolved, path)
+ }
+}
+
+func TestReleaseAll(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ paths := make([]string, 3)
+ refs := make([]string, 3)
+ for i := range 3 {
+ paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg")
+ var err error
+ refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+ }
+
+ if err := store.ReleaseAll("scope1"); err != nil {
+ t.Fatalf("ReleaseAll failed: %v", err)
+ }
+
+ // Files should be deleted
+ for _, p := range paths {
+ if _, err := os.Stat(p); !os.IsNotExist(err) {
+ t.Errorf("file %q should have been deleted", p)
+ }
+ }
+
+ // Refs should be unresolvable
+ for _, ref := range refs {
+ if _, err := store.Resolve(ref); err == nil {
+ t.Errorf("Resolve(%q) should fail after ReleaseAll", ref)
+ }
+ }
+}
+
+func TestMultiScopeIsolation(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ pathA := createTempFile(t, dir, "fileA.jpg")
+ pathB := createTempFile(t, dir, "fileB.jpg")
+
+ refA, _ := store.Store(pathA, MediaMeta{Source: "test"}, "scopeA")
+ refB, _ := store.Store(pathB, MediaMeta{Source: "test"}, "scopeB")
+
+ // Release only scopeA
+ if err := store.ReleaseAll("scopeA"); err != nil {
+ t.Fatalf("ReleaseAll(scopeA) failed: %v", err)
+ }
+
+ // scopeA file should be gone
+ if _, err := os.Stat(pathA); !os.IsNotExist(err) {
+ t.Error("file A should have been deleted")
+ }
+ if _, err := store.Resolve(refA); err == nil {
+ t.Error("refA should be unresolvable after release")
+ }
+
+ // scopeB file should still exist
+ if _, err := os.Stat(pathB); err != nil {
+ t.Error("file B should still exist")
+ }
+ resolved, err := store.Resolve(refB)
+ if err != nil {
+ t.Fatalf("refB should still resolve: %v", err)
+ }
+ if resolved != pathB {
+ t.Errorf("resolved %q, want %q", resolved, pathB)
+ }
+}
+
+func TestReleaseAllIdempotent(t *testing.T) {
+ store := NewFileMediaStore()
+
+ // ReleaseAll on non-existent scope should not error
+ if err := store.ReleaseAll("nonexistent"); err != nil {
+ t.Fatalf("ReleaseAll on empty scope should not error: %v", err)
+ }
+
+ // Create and release, then release again
+ dir := t.TempDir()
+ path := createTempFile(t, dir, "file.jpg")
+ _, _ = store.Store(path, MediaMeta{Source: "test"}, "scope1")
+
+ if err := store.ReleaseAll("scope1"); err != nil {
+ t.Fatalf("first ReleaseAll failed: %v", err)
+ }
+ if err := store.ReleaseAll("scope1"); err != nil {
+ t.Fatalf("second ReleaseAll should not error: %v", err)
+ }
+}
+
+func TestReleaseAllCleansMappingsIfRefsMissing(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ path := createTempFile(t, dir, "file.jpg")
+ ref, err := store.Store(path, MediaMeta{Source: "test"}, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ // Simulate internal inconsistency: scopeToRefs/refToScope contains ref but refs map doesn't.
+ store.mu.Lock()
+ delete(store.refs, ref)
+ store.mu.Unlock()
+
+ if err := store.ReleaseAll("scope1"); err != nil {
+ t.Fatalf("ReleaseAll failed: %v", err)
+ }
+
+ // ReleaseAll should still clean mappings (even if it can't delete the file without the path).
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ if _, ok := store.refToScope[ref]; ok {
+ t.Error("refToScope should not contain ref after ReleaseAll")
+ }
+ if _, ok := store.scopeToRefs["scope1"]; ok {
+ t.Error("scopeToRefs should not contain scope1 after ReleaseAll")
+ }
+}
+
+func TestStoreNonexistentFile(t *testing.T) {
+ store := NewFileMediaStore()
+
+ _, err := store.Store("/nonexistent/path/file.jpg", MediaMeta{Source: "test"}, "scope1")
+ if err == nil {
+ t.Error("Store should fail for nonexistent file")
+ }
+ // Error message should include the underlying os error, not just "file does not exist"
+ if !strings.Contains(err.Error(), "no such file or directory") &&
+ !strings.Contains(err.Error(), "cannot find") {
+ t.Errorf("Error should contain OS error detail, got: %v", err)
+ }
+}
+
+func TestResolveWithMeta(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ path := createTempFile(t, dir, "image.png")
+ meta := MediaMeta{
+ Filename: "image.png",
+ ContentType: "image/png",
+ Source: "telegram",
+ }
+
+ ref, err := store.Store(path, meta, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ resolvedPath, resolvedMeta, err := store.ResolveWithMeta(ref)
+ if err != nil {
+ t.Fatalf("ResolveWithMeta failed: %v", err)
+ }
+ if resolvedPath != path {
+ t.Errorf("ResolveWithMeta path = %q, want %q", resolvedPath, path)
+ }
+ if resolvedMeta.Filename != meta.Filename {
+ t.Errorf("ResolveWithMeta Filename = %q, want %q", resolvedMeta.Filename, meta.Filename)
+ }
+ if resolvedMeta.ContentType != meta.ContentType {
+ t.Errorf("ResolveWithMeta ContentType = %q, want %q", resolvedMeta.ContentType, meta.ContentType)
+ }
+ if resolvedMeta.Source != meta.Source {
+ t.Errorf("ResolveWithMeta Source = %q, want %q", resolvedMeta.Source, meta.Source)
+ }
+
+ // Unknown ref should fail
+ _, _, err = store.ResolveWithMeta("media://nonexistent")
+ if err == nil {
+ t.Error("ResolveWithMeta should fail for unknown ref")
+ }
+}
+
+func TestConcurrentSafety(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ const goroutines = 20
+ const filesPerGoroutine = 5
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for g := range goroutines {
+ go func(gIdx int) {
+ defer wg.Done()
+ scope := strings.Repeat("s", gIdx+1)
+
+ for i := range filesPerGoroutine {
+ path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp")
+ ref, err := store.Store(path, MediaMeta{Source: "test"}, scope)
+ if err != nil {
+ t.Errorf("Store failed: %v", err)
+ return
+ }
+
+ if _, err := store.Resolve(ref); err != nil {
+ t.Errorf("Resolve failed: %v", err)
+ }
+ }
+
+ if err := store.ReleaseAll(scope); err != nil {
+ t.Errorf("ReleaseAll failed: %v", err)
+ }
+ }(g)
+ }
+
+ wg.Wait()
+}
+
+// --- TTL cleanup tests ---
+
+func newTestStoreWithCleanup(maxAge time.Duration) *FileMediaStore {
+ s := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: true,
+ MaxAge: maxAge,
+ Interval: time.Hour, // won't tick in tests
+ })
+ return s
+}
+
+func TestCleanExpiredRemovesOldEntries(t *testing.T) {
+ dir := t.TempDir()
+ now := time.Now()
+ store := newTestStoreWithCleanup(10 * time.Minute)
+ store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) }
+
+ path := createTempFile(t, dir, "old.jpg")
+ ref, err := store.Store(path, MediaMeta{Source: "test"}, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ // Advance clock to present
+ store.nowFunc = func() time.Time { return now }
+ removed := store.CleanExpired()
+
+ if removed != 1 {
+ t.Errorf("expected 1 removed, got %d", removed)
+ }
+ if _, err := store.Resolve(ref); err == nil {
+ t.Error("expired ref should be unresolvable")
+ }
+ if _, err := os.Stat(path); !os.IsNotExist(err) {
+ t.Error("expired file should be deleted")
+ }
+}
+
+func TestCleanExpiredKeepsNonExpired(t *testing.T) {
+ dir := t.TempDir()
+ now := time.Now()
+ store := newTestStoreWithCleanup(10 * time.Minute)
+ store.nowFunc = func() time.Time { return now }
+
+ path := createTempFile(t, dir, "fresh.jpg")
+ ref, err := store.Store(path, MediaMeta{Source: "test"}, "scope1")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ removed := store.CleanExpired()
+ if removed != 0 {
+ t.Errorf("expected 0 removed, got %d", removed)
+ }
+
+ if _, err := store.Resolve(ref); err != nil {
+ t.Errorf("fresh ref should still resolve: %v", err)
+ }
+ if _, err := os.Stat(path); err != nil {
+ t.Error("fresh file should still exist")
+ }
+}
+
+func TestCleanExpiredMixedAges(t *testing.T) {
+ dir := t.TempDir()
+ now := time.Now()
+ store := newTestStoreWithCleanup(10 * time.Minute)
+
+ // Store old entry
+ store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) }
+ oldPath := createTempFile(t, dir, "old.jpg")
+ oldRef, _ := store.Store(oldPath, MediaMeta{Source: "test"}, "scope1")
+
+ // Store fresh entry
+ store.nowFunc = func() time.Time { return now }
+ freshPath := createTempFile(t, dir, "fresh.jpg")
+ freshRef, _ := store.Store(freshPath, MediaMeta{Source: "test"}, "scope1")
+
+ removed := store.CleanExpired()
+ if removed != 1 {
+ t.Errorf("expected 1 removed, got %d", removed)
+ }
+
+ if _, err := store.Resolve(oldRef); err == nil {
+ t.Error("old ref should be gone")
+ }
+ if _, err := store.Resolve(freshRef); err != nil {
+ t.Errorf("fresh ref should still resolve: %v", err)
+ }
+}
+
+func TestCleanExpiredCleansEmptyScopes(t *testing.T) {
+ dir := t.TempDir()
+ now := time.Now()
+ store := newTestStoreWithCleanup(10 * time.Minute)
+
+ // Store old entry as the only one in scope
+ store.nowFunc = func() time.Time { return now.Add(-20 * time.Minute) }
+ path := createTempFile(t, dir, "only.jpg")
+ store.Store(path, MediaMeta{Source: "test"}, "lonely_scope")
+
+ store.nowFunc = func() time.Time { return now }
+ store.CleanExpired()
+
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ if _, ok := store.scopeToRefs["lonely_scope"]; ok {
+ t.Error("empty scope should be cleaned up")
+ }
+}
+
+func TestStartStopLifecycle(t *testing.T) {
+ store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: true,
+ MaxAge: time.Minute,
+ Interval: 50 * time.Millisecond,
+ })
+
+ // Start and stop should not panic
+ store.Start()
+ // Double start should not spawn a second goroutine
+ store.Start()
+ time.Sleep(100 * time.Millisecond)
+ store.Stop()
+
+ // Double stop should not panic
+ store.Stop()
+}
+
+func TestCleanExpiredZeroMaxAge(t *testing.T) {
+ store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: true,
+ MaxAge: 0,
+ Interval: time.Hour,
+ })
+
+ dir := t.TempDir()
+ path := createTempFile(t, dir, "file.jpg")
+ ref, _ := store.Store(path, MediaMeta{Source: "test"}, "scope1")
+
+ // Zero MaxAge should be a no-op
+ removed := store.CleanExpired()
+ if removed != 0 {
+ t.Errorf("expected 0 removed with zero MaxAge, got %d", removed)
+ }
+ if _, err := store.Resolve(ref); err != nil {
+ t.Errorf("ref should still resolve: %v", err)
+ }
+}
+
+func TestStartDisabledIsNoop(t *testing.T) {
+ store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: false,
+ MaxAge: time.Minute,
+ Interval: time.Minute,
+ })
+ // Should not start any goroutine or panic
+ store.Start()
+ store.Stop()
+}
+
+func TestStartZeroIntervalNoPanic(t *testing.T) {
+ store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: true,
+ MaxAge: time.Minute,
+ Interval: 0,
+ })
+ // Zero interval should not panic (time.NewTicker panics on <= 0)
+ store.Start()
+ store.Stop()
+}
+
+func TestStartZeroMaxAgeNoPanic(t *testing.T) {
+ store := NewFileMediaStoreWithCleanup(MediaCleanerConfig{
+ Enabled: true,
+ MaxAge: 0,
+ Interval: time.Minute,
+ })
+ store.Start()
+ store.Stop()
+}
+
+func TestConcurrentCleanupSafety(t *testing.T) {
+ dir := t.TempDir()
+ store := newTestStoreWithCleanup(50 * time.Millisecond)
+ store.nowFunc = time.Now
+
+ const workers = 10
+ const ops = 20
+ var wg sync.WaitGroup
+ wg.Add(workers * 4)
+
+ // Store workers
+ for w := range workers {
+ go func(wIdx int) {
+ defer wg.Done()
+ scope := fmt.Sprintf("scope-%d", wIdx)
+ for i := range ops {
+ p := createTempFile(t, dir, fmt.Sprintf("w%d-f%d.tmp", wIdx, i))
+ store.Store(p, MediaMeta{Source: "test"}, scope)
+ }
+ }(w)
+ }
+
+ // Resolve workers
+ for range workers {
+ go func() {
+ defer wg.Done()
+ for range ops {
+ store.Resolve("media://nonexistent")
+ }
+ }()
+ }
+
+ // ReleaseAll workers
+ for w := range workers {
+ go func(wIdx int) {
+ defer wg.Done()
+ for range ops {
+ store.ReleaseAll(fmt.Sprintf("scope-%d", wIdx))
+ }
+ }(w)
+ }
+
+ // CleanExpired workers
+ for range workers {
+ go func() {
+ defer wg.Done()
+ for range ops {
+ store.CleanExpired()
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+func TestRefToScopeConsistency(t *testing.T) {
+ dir := t.TempDir()
+ store := NewFileMediaStore()
+
+ // Store entries in two scopes
+ ref1, _ := store.Store(createTempFile(t, dir, "a.jpg"), MediaMeta{Source: "test"}, "s1")
+ ref2, _ := store.Store(createTempFile(t, dir, "b.jpg"), MediaMeta{Source: "test"}, "s1")
+ ref3, _ := store.Store(createTempFile(t, dir, "c.jpg"), MediaMeta{Source: "test"}, "s2")
+
+ store.mu.RLock()
+ checkRef := func(ref, expectedScope string) {
+ t.Helper()
+ if scope, ok := store.refToScope[ref]; !ok || scope != expectedScope {
+ t.Errorf("refToScope[%s] = %q, want %q", ref, scope, expectedScope)
+ }
+ }
+ checkRef(ref1, "s1")
+ checkRef(ref2, "s1")
+ checkRef(ref3, "s2")
+ store.mu.RUnlock()
+
+ // Release s1 and verify refToScope is cleaned
+ store.ReleaseAll("s1")
+
+ store.mu.RLock()
+ defer store.mu.RUnlock()
+ if _, ok := store.refToScope[ref1]; ok {
+ t.Error("refToScope should not contain ref1 after ReleaseAll")
+ }
+ if _, ok := store.refToScope[ref2]; ok {
+ t.Error("refToScope should not contain ref2 after ReleaseAll")
+ }
+ if _, ok := store.refToScope[ref3]; !ok {
+ t.Error("refToScope should still contain ref3")
+ }
+}
diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go
new file mode 100644
index 000000000..e12e2c5ab
--- /dev/null
+++ b/pkg/memory/jsonl.go
@@ -0,0 +1,460 @@
+package memory
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "hash/fnv"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+const (
+ // numLockShards is the fixed number of mutexes used to serialize
+ // per-session access. Using a sharded array instead of a map keeps
+ // memory bounded regardless of how many sessions are created over
+ // the lifetime of the process — important for a long-running daemon.
+ numLockShards = 64
+
+ // maxLineSize is the maximum size of a single JSON line in a .jsonl
+ // file. Tool results (read_file, web search, etc.) can be large, so
+ // we set a generous limit. The scanner starts at 64 KB and grows
+ // only as needed up to this cap.
+ maxLineSize = 10 * 1024 * 1024 // 10 MB
+)
+
+// sessionMeta holds per-session metadata stored in a .meta.json file.
+type sessionMeta struct {
+ Key string `json:"key"`
+ Summary string `json:"summary"`
+ Skip int `json:"skip"`
+ Count int `json:"count"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+// JSONLStore implements Store using append-only JSONL files.
+//
+// Each session is stored as two files:
+//
+// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only
+// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset)
+//
+// Messages are never physically deleted from the JSONL file. Instead,
+// TruncateHistory records a "skip" offset in the metadata file and
+// GetHistory ignores lines before that offset. This keeps all writes
+// append-only, which is both fast and crash-safe.
+type JSONLStore struct {
+ dir string
+ locks [numLockShards]sync.Mutex
+}
+
+// NewJSONLStore creates a new JSONL-backed store rooted at dir.
+func NewJSONLStore(dir string) (*JSONLStore, error) {
+ err := os.MkdirAll(dir, 0o755)
+ if err != nil {
+ return nil, fmt.Errorf("memory: create directory: %w", err)
+ }
+ return &JSONLStore{dir: dir}, nil
+}
+
+// sessionLock returns a mutex for the given session key.
+// Keys are mapped to a fixed pool of shards via FNV hash, so
+// memory usage is O(1) regardless of total session count.
+func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
+ h := fnv.New32a()
+ h.Write([]byte(key))
+ return &s.locks[h.Sum32()%numLockShards]
+}
+
+func (s *JSONLStore) jsonlPath(key string) string {
+ return filepath.Join(s.dir, sanitizeKey(key)+".jsonl")
+}
+
+func (s *JSONLStore) metaPath(key string) string {
+ return filepath.Join(s.dir, sanitizeKey(key)+".meta.json")
+}
+
+// sanitizeKey converts a session key to a safe filename component.
+// Mirrors pkg/session.sanitizeFilename so that migration paths match.
+//
+// Note: this is a lossy mapping — "telegram:123" and "telegram_123"
+// both produce the same filename. This is an intentional tradeoff:
+// keys with colons (e.g. from channels) are by far the common case,
+// and a bidirectional encoding (like URL-encoding) would complicate
+// file listings and debugging.
+func sanitizeKey(key string) string {
+ return strings.ReplaceAll(key, ":", "_")
+}
+
+// readMeta loads the metadata file for a session.
+// Returns a zero-value sessionMeta if the file does not exist.
+func (s *JSONLStore) readMeta(key string) (sessionMeta, error) {
+ data, err := os.ReadFile(s.metaPath(key))
+ if os.IsNotExist(err) {
+ return sessionMeta{Key: key}, nil
+ }
+ if err != nil {
+ return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err)
+ }
+ var meta sessionMeta
+ err = json.Unmarshal(data, &meta)
+ if err != nil {
+ return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err)
+ }
+ return meta, nil
+}
+
+// writeMeta atomically writes the metadata file using the project's
+// standard WriteFileAtomic (temp + fsync + rename).
+func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
+ data, err := json.MarshalIndent(meta, "", " ")
+ if err != nil {
+ return fmt.Errorf("memory: encode meta: %w", err)
+ }
+ return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644)
+}
+
+// readMessages reads valid JSON lines from a .jsonl file, skipping
+// the first `skip` lines without unmarshaling them. This avoids the
+// cost of json.Unmarshal on logically truncated messages.
+// Malformed trailing lines (e.g. from a crash) are silently skipped.
+func readMessages(path string, skip int) ([]providers.Message, error) {
+ f, err := os.Open(path)
+ if os.IsNotExist(err) {
+ return []providers.Message{}, nil
+ }
+ if err != nil {
+ return nil, fmt.Errorf("memory: open jsonl: %w", err)
+ }
+ defer f.Close()
+
+ var msgs []providers.Message
+ scanner := bufio.NewScanner(f)
+ // Allow large lines for tool results (read_file, web search, etc.).
+ scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
+
+ lineNum := 0
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ if len(line) == 0 {
+ continue
+ }
+ lineNum++
+ if lineNum <= skip {
+ continue
+ }
+ var msg providers.Message
+ if err := json.Unmarshal(line, &msg); err != nil {
+ // Corrupt line — likely a partial write from a crash.
+ // Log so operators know data was skipped, but don't
+ // fail the entire read; this is the standard JSONL
+ // recovery pattern.
+ log.Printf("memory: skipping corrupt line %d in %s: %v",
+ lineNum, filepath.Base(path), err)
+ continue
+ }
+ msgs = append(msgs, msg)
+ }
+ if scanner.Err() != nil {
+ return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err())
+ }
+
+ if msgs == nil {
+ msgs = []providers.Message{}
+ }
+ return msgs, nil
+}
+
+// countLines counts the total number of non-empty lines in a .jsonl file.
+// Used by TruncateHistory to reconcile a stale meta.Count without
+// the overhead of unmarshaling every message.
+func countLines(path string) (int, error) {
+ f, err := os.Open(path)
+ if os.IsNotExist(err) {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("memory: open jsonl: %w", err)
+ }
+ defer f.Close()
+
+ n := 0
+ scanner := bufio.NewScanner(f)
+ scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
+ for scanner.Scan() {
+ if len(scanner.Bytes()) > 0 {
+ n++
+ }
+ }
+ return n, scanner.Err()
+}
+
+func (s *JSONLStore) AddMessage(
+ _ context.Context, sessionKey, role, content string,
+) error {
+ return s.addMsg(sessionKey, providers.Message{
+ Role: role,
+ Content: content,
+ })
+}
+
+func (s *JSONLStore) AddFullMessage(
+ _ context.Context, sessionKey string, msg providers.Message,
+) error {
+ return s.addMsg(sessionKey, msg)
+}
+
+// addMsg is the shared implementation for AddMessage and AddFullMessage.
+func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ // Append the message as a single JSON line.
+ line, err := json.Marshal(msg)
+ if err != nil {
+ return fmt.Errorf("memory: marshal message: %w", err)
+ }
+ line = append(line, '\n')
+
+ f, err := os.OpenFile(
+ s.jsonlPath(sessionKey),
+ os.O_CREATE|os.O_WRONLY|os.O_APPEND,
+ 0o644,
+ )
+ if err != nil {
+ return fmt.Errorf("memory: open jsonl for append: %w", err)
+ }
+ _, writeErr := f.Write(line)
+ if writeErr != nil {
+ f.Close()
+ return fmt.Errorf("memory: append message: %w", writeErr)
+ }
+ // Flush to physical storage before closing. This matches the
+ // durability guarantee of writeMeta and rewriteJSONL (which use
+ // WriteFileAtomic with fsync). Without Sync, a power loss could
+ // leave the append in the kernel page cache only — lost on reboot.
+ if syncErr := f.Sync(); syncErr != nil {
+ f.Close()
+ return fmt.Errorf("memory: sync jsonl: %w", syncErr)
+ }
+ if closeErr := f.Close(); closeErr != nil {
+ return fmt.Errorf("memory: close jsonl: %w", closeErr)
+ }
+
+ // Update metadata.
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return err
+ }
+ now := time.Now()
+ if meta.Count == 0 && meta.CreatedAt.IsZero() {
+ meta.CreatedAt = now
+ }
+ meta.Count++
+ meta.UpdatedAt = now
+
+ return s.writeMeta(sessionKey, meta)
+}
+
+func (s *JSONLStore) GetHistory(
+ _ context.Context, sessionKey string,
+) ([]providers.Message, error) {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return nil, err
+ }
+
+ // Pass meta.Skip so readMessages skips those lines without
+ // unmarshaling them — avoids wasted CPU on truncated messages.
+ msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
+ if err != nil {
+ return nil, err
+ }
+
+ return msgs, nil
+}
+
+func (s *JSONLStore) GetSummary(
+ _ context.Context, sessionKey string,
+) (string, error) {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return "", err
+ }
+ return meta.Summary, nil
+}
+
+func (s *JSONLStore) SetSummary(
+ _ context.Context, sessionKey, summary string,
+) error {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return err
+ }
+ now := time.Now()
+ if meta.CreatedAt.IsZero() {
+ meta.CreatedAt = now
+ }
+ meta.Summary = summary
+ meta.UpdatedAt = now
+
+ return s.writeMeta(sessionKey, meta)
+}
+
+func (s *JSONLStore) TruncateHistory(
+ _ context.Context, sessionKey string, keepLast int,
+) error {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return err
+ }
+
+ // Always reconcile meta.Count with the actual line count on disk.
+ // A crash between the JSONL append and the meta update in addMsg
+ // leaves meta.Count stale (e.g. file has 101 lines but meta says
+ // 100). Counting lines is cheap — no unmarshal, just a scan — and
+ // TruncateHistory is not a hot path, so always re-count.
+ n, countErr := countLines(s.jsonlPath(sessionKey))
+ if countErr != nil {
+ return countErr
+ }
+ meta.Count = n
+
+ if keepLast <= 0 {
+ meta.Skip = meta.Count
+ } else {
+ effective := meta.Count - meta.Skip
+ if keepLast < effective {
+ meta.Skip = meta.Count - keepLast
+ }
+ }
+ meta.UpdatedAt = time.Now()
+
+ return s.writeMeta(sessionKey, meta)
+}
+
+func (s *JSONLStore) SetHistory(
+ _ context.Context,
+ sessionKey string,
+ history []providers.Message,
+) error {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return err
+ }
+ now := time.Now()
+ if meta.CreatedAt.IsZero() {
+ meta.CreatedAt = now
+ }
+ meta.Skip = 0
+ meta.Count = len(history)
+ meta.UpdatedAt = now
+
+ // Write meta BEFORE rewriting the JSONL file. If we crash between
+ // the two writes, meta has Skip=0 and the old file is still intact,
+ // so GetHistory reads from line 1 — returning "too many" messages
+ // rather than losing data. The next SetHistory call corrects this.
+ err = s.writeMeta(sessionKey, meta)
+ if err != nil {
+ return err
+ }
+
+ return s.rewriteJSONL(sessionKey, history)
+}
+
+// Compact physically rewrites the JSONL file, dropping all logically
+// skipped lines. This reclaims disk space that accumulates after
+// repeated TruncateHistory calls.
+//
+// It is safe to call at any time; if there is nothing to compact
+// (skip == 0) the method returns immediately.
+func (s *JSONLStore) Compact(
+ _ context.Context, sessionKey string,
+) error {
+ l := s.sessionLock(sessionKey)
+ l.Lock()
+ defer l.Unlock()
+
+ meta, err := s.readMeta(sessionKey)
+ if err != nil {
+ return err
+ }
+ if meta.Skip == 0 {
+ return nil
+ }
+
+ // Read only the active messages, skipping truncated lines
+ // without unmarshaling them.
+ active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
+ if err != nil {
+ return err
+ }
+
+ // Write meta BEFORE rewriting the JSONL file. If the process
+ // crashes between the two writes, meta has Skip=0 and the old
+ // (uncompacted) file is still intact, so GetHistory reads from
+ // line 1 — returning previously-truncated messages rather than
+ // losing data. The next Compact or TruncateHistory corrects this.
+ meta.Skip = 0
+ meta.Count = len(active)
+ meta.UpdatedAt = time.Now()
+
+ err = s.writeMeta(sessionKey, meta)
+ if err != nil {
+ return err
+ }
+
+ return s.rewriteJSONL(sessionKey, active)
+}
+
+// rewriteJSONL atomically replaces the JSONL file with the given messages
+// using the project's standard WriteFileAtomic (temp + fsync + rename).
+func (s *JSONLStore) rewriteJSONL(
+ sessionKey string, msgs []providers.Message,
+) error {
+ var buf bytes.Buffer
+ for i, msg := range msgs {
+ line, err := json.Marshal(msg)
+ if err != nil {
+ return fmt.Errorf("memory: marshal message %d: %w", i, err)
+ }
+ buf.Write(line)
+ buf.WriteByte('\n')
+ }
+ return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
+}
+
+func (s *JSONLStore) Close() error {
+ return nil
+}
diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go
new file mode 100644
index 000000000..356ff14ff
--- /dev/null
+++ b/pkg/memory/jsonl_test.go
@@ -0,0 +1,835 @@
+package memory
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+func newTestStore(t *testing.T) *JSONLStore {
+ t.Helper()
+ store, err := NewJSONLStore(t.TempDir())
+ if err != nil {
+ t.Fatalf("NewJSONLStore: %v", err)
+ }
+ return store
+}
+
+func TestNewJSONLStore_CreatesDirectory(t *testing.T) {
+ dir := filepath.Join(t.TempDir(), "nested", "sessions")
+ store, err := NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore: %v", err)
+ }
+ defer store.Close()
+
+ info, err := os.Stat(dir)
+ if err != nil {
+ t.Fatalf("Stat: %v", err)
+ }
+ if !info.IsDir() {
+ t.Errorf("expected directory, got file")
+ }
+}
+
+func TestAddMessage_BasicRoundtrip(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ err := store.AddMessage(ctx, "s1", "user", "hello")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ err = store.AddMessage(ctx, "s1", "assistant", "hi there")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "s1")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(history))
+ }
+ if history[0].Role != "user" || history[0].Content != "hello" {
+ t.Errorf("msg[0] = %+v", history[0])
+ }
+ if history[1].Role != "assistant" || history[1].Content != "hi there" {
+ t.Errorf("msg[1] = %+v", history[1])
+ }
+}
+
+func TestAddMessage_AutoCreatesSession(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Adding a message to a non-existent session should work.
+ err := store.AddMessage(ctx, "new-session", "user", "first message")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "new-session")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(history))
+ }
+}
+
+func TestAddFullMessage_WithToolCalls(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ msg := providers.Message{
+ Role: "assistant",
+ Content: "Let me search that.",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_abc",
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: "web_search",
+ Arguments: `{"q":"golang jsonl"}`,
+ },
+ },
+ },
+ }
+
+ err := store.AddFullMessage(ctx, "tc", msg)
+ if err != nil {
+ t.Fatalf("AddFullMessage: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "tc")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1, got %d", len(history))
+ }
+ if len(history[0].ToolCalls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
+ }
+ tc := history[0].ToolCalls[0]
+ if tc.ID != "call_abc" {
+ t.Errorf("tool call ID = %q", tc.ID)
+ }
+ if tc.Function == nil || tc.Function.Name != "web_search" {
+ t.Errorf("tool call function = %+v", tc.Function)
+ }
+}
+
+func TestAddFullMessage_ToolCallID(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ msg := providers.Message{
+ Role: "tool",
+ Content: "search results here",
+ ToolCallID: "call_abc",
+ }
+
+ err := store.AddFullMessage(ctx, "tr", msg)
+ if err != nil {
+ t.Fatalf("AddFullMessage: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "tr")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1, got %d", len(history))
+ }
+ if history[0].ToolCallID != "call_abc" {
+ t.Errorf("ToolCallID = %q", history[0].ToolCallID)
+ }
+}
+
+func TestGetHistory_EmptySession(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ history, err := store.GetHistory(ctx, "nonexistent")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if history == nil {
+ t.Fatal("expected non-nil empty slice")
+ }
+ if len(history) != 0 {
+ t.Errorf("expected 0 messages, got %d", len(history))
+ }
+}
+
+func TestGetHistory_Ordering(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 5; i++ {
+ err := store.AddMessage(
+ ctx, "order",
+ "user",
+ string(rune('a'+i)),
+ )
+ if err != nil {
+ t.Fatalf("AddMessage(%d): %v", i, err)
+ }
+ }
+
+ history, err := store.GetHistory(ctx, "order")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 5 {
+ t.Fatalf("expected 5, got %d", len(history))
+ }
+ for i := 0; i < 5; i++ {
+ expected := string(rune('a' + i))
+ if history[i].Content != expected {
+ t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected)
+ }
+ }
+}
+
+func TestSetSummary_GetSummary(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // No summary yet.
+ summary, err := store.GetSummary(ctx, "s1")
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if summary != "" {
+ t.Errorf("expected empty, got %q", summary)
+ }
+
+ // Set a summary.
+ err = store.SetSummary(ctx, "s1", "talked about Go")
+ if err != nil {
+ t.Fatalf("SetSummary: %v", err)
+ }
+
+ summary, err = store.GetSummary(ctx, "s1")
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if summary != "talked about Go" {
+ t.Errorf("summary = %q", summary)
+ }
+
+ // Update summary.
+ err = store.SetSummary(ctx, "s1", "updated summary")
+ if err != nil {
+ t.Fatalf("SetSummary: %v", err)
+ }
+
+ summary, err = store.GetSummary(ctx, "s1")
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if summary != "updated summary" {
+ t.Errorf("summary = %q", summary)
+ }
+}
+
+func TestTruncateHistory_KeepLast(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 10; i++ {
+ err := store.AddMessage(
+ ctx, "trunc",
+ "user",
+ string(rune('a'+i)),
+ )
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ err := store.TruncateHistory(ctx, "trunc", 4)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "trunc")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 4 {
+ t.Fatalf("expected 4, got %d", len(history))
+ }
+ // Should be the last 4: g, h, i, j
+ if history[0].Content != "g" {
+ t.Errorf("first kept = %q, want 'g'", history[0].Content)
+ }
+ if history[3].Content != "j" {
+ t.Errorf("last kept = %q, want 'j'", history[3].Content)
+ }
+}
+
+func TestTruncateHistory_KeepZero(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 5; i++ {
+ err := store.AddMessage(ctx, "empty", "user", "msg")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ err := store.TruncateHistory(ctx, "empty", 0)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "empty")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 0 {
+ t.Errorf("expected 0, got %d", len(history))
+ }
+}
+
+func TestTruncateHistory_KeepMoreThanExists(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 3; i++ {
+ err := store.AddMessage(ctx, "few", "user", "msg")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ // Keep 100, but only 3 exist — should keep all.
+ err := store.TruncateHistory(ctx, "few", 100)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "few")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 3 {
+ t.Errorf("expected 3, got %d", len(history))
+ }
+}
+
+func TestSetHistory_ReplacesAll(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Add some initial messages.
+ for i := 0; i < 5; i++ {
+ err := store.AddMessage(ctx, "replace", "user", "old")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ // Replace with new history.
+ newHistory := []providers.Message{
+ {Role: "user", Content: "new1"},
+ {Role: "assistant", Content: "new2"},
+ }
+ err := store.SetHistory(ctx, "replace", newHistory)
+ if err != nil {
+ t.Fatalf("SetHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "replace")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 2 {
+ t.Fatalf("expected 2, got %d", len(history))
+ }
+ if history[0].Content != "new1" || history[1].Content != "new2" {
+ t.Errorf("history = %+v", history)
+ }
+}
+
+func TestSetHistory_ResetsSkip(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Add messages and truncate.
+ for i := 0; i < 10; i++ {
+ err := store.AddMessage(ctx, "skip-reset", "user", "old")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+ err := store.TruncateHistory(ctx, "skip-reset", 3)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ // SetHistory should reset skip to 0.
+ newHistory := []providers.Message{
+ {Role: "user", Content: "fresh"},
+ }
+ err = store.SetHistory(ctx, "skip-reset", newHistory)
+ if err != nil {
+ t.Fatalf("SetHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "skip-reset")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1, got %d", len(history))
+ }
+ if history[0].Content != "fresh" {
+ t.Errorf("content = %q", history[0].Content)
+ }
+}
+
+func TestColonInKey(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ err := store.AddMessage(ctx, "telegram:123", "user", "hi")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "telegram:123")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1, got %d", len(history))
+ }
+
+ // Verify the file is named with underscore.
+ jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl")
+ if _, statErr := os.Stat(jsonlFile); statErr != nil {
+ t.Errorf("expected file %s to exist: %v", jsonlFile, statErr)
+ }
+}
+
+func TestCompact_RemovesSkippedMessages(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Write 10 messages, then truncate to keep last 3.
+ for i := 0; i < 10; i++ {
+ err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i)))
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+ err := store.TruncateHistory(ctx, "compact", 3)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ // Before compact: file still has 10 lines.
+ allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
+ if err != nil {
+ t.Fatalf("readMessages: %v", err)
+ }
+ if len(allOnDisk) != 10 {
+ t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk))
+ }
+
+ // Compact.
+ err = store.Compact(ctx, "compact")
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+
+ // After compact: file should have only 3 lines.
+ allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
+ if err != nil {
+ t.Fatalf("readMessages: %v", err)
+ }
+ if len(allOnDisk) != 3 {
+ t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk))
+ }
+
+ // GetHistory should still return the same 3 messages.
+ history, err := store.GetHistory(ctx, "compact")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 3 {
+ t.Fatalf("expected 3, got %d", len(history))
+ }
+ if history[0].Content != "h" || history[2].Content != "j" {
+ t.Errorf("wrong content: %+v", history)
+ }
+}
+
+func TestCompact_NoOpWhenNoSkip(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 5; i++ {
+ err := store.AddMessage(ctx, "noop", "user", "msg")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ // Compact without prior truncation — should be a no-op.
+ err := store.Compact(ctx, "noop")
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "noop")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 5 {
+ t.Errorf("expected 5, got %d", len(history))
+ }
+}
+
+func TestCompact_ThenAppend(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 8; i++ {
+ err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i)))
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ err := store.TruncateHistory(ctx, "cap", 2)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+ err = store.Compact(ctx, "cap")
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+
+ // Append after compaction should work correctly.
+ err = store.AddMessage(ctx, "cap", "user", "new")
+ if err != nil {
+ t.Fatalf("AddMessage after compact: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "cap")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 3 {
+ t.Fatalf("expected 3, got %d", len(history))
+ }
+ // g, h (kept from truncation), new (appended after compaction).
+ if history[0].Content != "g" {
+ t.Errorf("first = %q, want 'g'", history[0].Content)
+ }
+ if history[2].Content != "new" {
+ t.Errorf("last = %q, want 'new'", history[2].Content)
+ }
+}
+
+func TestTruncateHistory_StaleMetaCount(t *testing.T) {
+ // Simulates a crash between JSONL append and meta update in addMsg:
+ // file has N+1 lines but meta.Count is still N. TruncateHistory must
+ // reconcile with the real line count so that keepLast is accurate.
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Write 10 messages normally (meta.Count = 10).
+ for i := 0; i < 10; i++ {
+ err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i)))
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ // Simulate crash: append a line to JSONL but do NOT update meta.
+ // This leaves meta.Count = 10 while the file has 11 lines.
+ jsonlPath := store.jsonlPath("stale")
+ f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ t.Fatalf("open for append: %v", err)
+ }
+ _, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n")
+ if err != nil {
+ t.Fatalf("write orphan: %v", err)
+ }
+ f.Close()
+
+ // TruncateHistory(keepLast=4) should keep the last 4 of 11 lines,
+ // not the last 4 of 10.
+ err = store.TruncateHistory(ctx, "stale", 4)
+ if err != nil {
+ t.Fatalf("TruncateHistory: %v", err)
+ }
+
+ history, err := store.GetHistory(ctx, "stale")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 4 {
+ t.Fatalf("expected 4, got %d", len(history))
+ }
+ // Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan]
+ if history[0].Content != "h" {
+ t.Errorf("first kept = %q, want 'h'", history[0].Content)
+ }
+ if history[3].Content != "orphan" {
+ t.Errorf("last kept = %q, want 'orphan'", history[3].Content)
+ }
+}
+
+func TestCrashRecovery_PartialLine(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Write a valid message first.
+ err := store.AddMessage(ctx, "crash", "user", "valid")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ // Simulate a crash by appending a partial JSON line directly.
+ jsonlPath := store.jsonlPath("crash")
+ f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ t.Fatalf("open for append: %v", err)
+ }
+ _, err = f.WriteString(`{"role":"user","content":"incomple`)
+ if err != nil {
+ t.Fatalf("write partial: %v", err)
+ }
+ f.Close()
+
+ // GetHistory should return only the valid message.
+ history, err := store.GetHistory(ctx, "crash")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1 valid message, got %d", len(history))
+ }
+ if history[0].Content != "valid" {
+ t.Errorf("content = %q", history[0].Content)
+ }
+}
+
+func TestPersistence_AcrossInstances(t *testing.T) {
+ dir := t.TempDir()
+ ctx := context.Background()
+
+ // Write with first instance.
+ store1, err := NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore: %v", err)
+ }
+ err = store1.AddMessage(ctx, "persist", "user", "remember me")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ err = store1.SetSummary(ctx, "persist", "a test session")
+ if err != nil {
+ t.Fatalf("SetSummary: %v", err)
+ }
+ store1.Close()
+
+ // Read with second instance.
+ store2, err := NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore: %v", err)
+ }
+ defer store2.Close()
+
+ history, err := store2.GetHistory(ctx, "persist")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 || history[0].Content != "remember me" {
+ t.Errorf("history = %+v", history)
+ }
+
+ summary, err := store2.GetSummary(ctx, "persist")
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if summary != "a test session" {
+ t.Errorf("summary = %q", summary)
+ }
+}
+
+func TestConcurrent_AddAndRead(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ var wg sync.WaitGroup
+ const goroutines = 10
+ const msgsPerGoroutine = 20
+
+ // Concurrent writes.
+ for g := 0; g < goroutines; g++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < msgsPerGoroutine; i++ {
+ _ = store.AddMessage(ctx, "concurrent", "user", "msg")
+ }
+ }()
+ }
+ wg.Wait()
+
+ history, err := store.GetHistory(ctx, "concurrent")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ expected := goroutines * msgsPerGoroutine
+ if len(history) != expected {
+ t.Errorf("expected %d messages, got %d", expected, len(history))
+ }
+}
+
+func TestConcurrent_SummarizeRace(t *testing.T) {
+ // Simulates the #704 race: one goroutine adds messages while
+ // another truncates + sets summary — like summarizeSession().
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // Seed with some messages.
+ for i := 0; i < 20; i++ {
+ err := store.AddMessage(ctx, "race", "user", "seed")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ }
+
+ var wg sync.WaitGroup
+
+ // Writer goroutine (main agent loop).
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < 50; i++ {
+ _ = store.AddMessage(ctx, "race", "user", "new")
+ }
+ }()
+
+ // Summarizer goroutine (background task).
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < 10; i++ {
+ _ = store.SetSummary(ctx, "race", "summary")
+ _ = store.TruncateHistory(ctx, "race", 5)
+ }
+ }()
+
+ wg.Wait()
+
+ // Verify the store is still in a consistent state.
+ _, err := store.GetHistory(ctx, "race")
+ if err != nil {
+ t.Fatalf("GetHistory after race: %v", err)
+ }
+ _, err = store.GetSummary(ctx, "race")
+ if err != nil {
+ t.Fatalf("GetSummary after race: %v", err)
+ }
+}
+
+func TestMultipleSessions_Isolation(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ err := store.AddMessage(ctx, "s1", "user", "msg for s1")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ err = store.AddMessage(ctx, "s2", "user", "msg for s2")
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ h1, err := store.GetHistory(ctx, "s1")
+ if err != nil {
+ t.Fatalf("GetHistory s1: %v", err)
+ }
+ h2, err := store.GetHistory(ctx, "s2")
+ if err != nil {
+ t.Fatalf("GetHistory s2: %v", err)
+ }
+
+ if len(h1) != 1 || h1[0].Content != "msg for s1" {
+ t.Errorf("s1 history = %+v", h1)
+ }
+ if len(h2) != 1 || h2[0].Content != "msg for s2" {
+ t.Errorf("s2 history = %+v", h2)
+ }
+}
+
+func BenchmarkAddMessage(b *testing.B) {
+ dir := b.TempDir()
+ store, err := NewJSONLStore(dir)
+ if err != nil {
+ b.Fatalf("NewJSONLStore: %v", err)
+ }
+ defer store.Close()
+ ctx := context.Background()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = store.AddMessage(ctx, "bench", "user", "benchmark message content")
+ }
+}
+
+func BenchmarkGetHistory_100(b *testing.B) {
+ dir := b.TempDir()
+ store, err := NewJSONLStore(dir)
+ if err != nil {
+ b.Fatalf("NewJSONLStore: %v", err)
+ }
+ defer store.Close()
+ ctx := context.Background()
+
+ for i := 0; i < 100; i++ {
+ _ = store.AddMessage(ctx, "bench", "user", "message content")
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, _ = store.GetHistory(ctx, "bench")
+ }
+}
+
+func BenchmarkGetHistory_1000(b *testing.B) {
+ dir := b.TempDir()
+ store, err := NewJSONLStore(dir)
+ if err != nil {
+ b.Fatalf("NewJSONLStore: %v", err)
+ }
+ defer store.Close()
+ ctx := context.Background()
+
+ for i := 0; i < 1000; i++ {
+ _ = store.AddMessage(ctx, "bench", "user", "message content")
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, _ = store.GetHistory(ctx, "bench")
+ }
+}
diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go
new file mode 100644
index 000000000..c9d5176ab
--- /dev/null
+++ b/pkg/memory/migration.go
@@ -0,0 +1,108 @@
+package memory
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// jsonSession mirrors pkg/session.Session for migration purposes.
+type jsonSession struct {
+ Key string `json:"key"`
+ Messages []providers.Message `json:"messages"`
+ Summary string `json:"summary,omitempty"`
+ Created time.Time `json:"created"`
+ Updated time.Time `json:"updated"`
+}
+
+// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir,
+// writes them into the Store, and renames each migrated file to
+// .json.migrated as a backup. Returns the number of sessions migrated.
+//
+// Files that fail to parse are logged and skipped. Already-migrated
+// files (.json.migrated) are ignored, making the function idempotent.
+func MigrateFromJSON(
+ ctx context.Context, sessionsDir string, store Store,
+) (int, error) {
+ entries, err := os.ReadDir(sessionsDir)
+ if os.IsNotExist(err) {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("memory: read sessions dir: %w", err)
+ }
+
+ migrated := 0
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ name := entry.Name()
+ if !strings.HasSuffix(name, ".json") {
+ continue
+ }
+ // Skip already-migrated files.
+ if strings.HasSuffix(name, ".migrated") {
+ continue
+ }
+
+ srcPath := filepath.Join(sessionsDir, name)
+
+ data, readErr := os.ReadFile(srcPath)
+ if readErr != nil {
+ log.Printf("memory: migrate: skip %s: %v", name, readErr)
+ continue
+ }
+
+ var sess jsonSession
+ if parseErr := json.Unmarshal(data, &sess); parseErr != nil {
+ log.Printf("memory: migrate: skip %s: %v", name, parseErr)
+ continue
+ }
+
+ // Use the key from the JSON content, not the filename.
+ // Filenames are sanitized (":" → "_") but keys are not.
+ key := sess.Key
+ if key == "" {
+ key = strings.TrimSuffix(name, ".json")
+ }
+
+ // Use SetHistory (atomic replace) instead of per-message
+ // AddFullMessage. This makes migration idempotent: if the
+ // process crashes after writing messages but before the
+ // rename below, a retry replaces the partial data cleanly
+ // instead of duplicating messages.
+ if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil {
+ return migrated, fmt.Errorf(
+ "memory: migrate %s: set history: %w",
+ name, setErr,
+ )
+ }
+
+ if sess.Summary != "" {
+ if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil {
+ return migrated, fmt.Errorf(
+ "memory: migrate %s: set summary: %w",
+ name, sumErr,
+ )
+ }
+ }
+
+ // Rename to .migrated as backup (not delete).
+ renameErr := os.Rename(srcPath, srcPath+".migrated")
+ if renameErr != nil {
+ log.Printf("memory: migrate: rename %s: %v", name, renameErr)
+ }
+
+ migrated++
+ }
+
+ return migrated, nil
+}
diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go
new file mode 100644
index 000000000..3170758b7
--- /dev/null
+++ b/pkg/memory/migration_test.go
@@ -0,0 +1,384 @@
+package memory
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+func writeJSONSession(
+ t *testing.T, dir string, filename string, sess jsonSession,
+) {
+ t.Helper()
+ data, err := json.MarshalIndent(sess, "", " ")
+ if err != nil {
+ t.Fatalf("marshal session: %v", err)
+ }
+ err = os.WriteFile(filepath.Join(dir, filename), data, 0o644)
+ if err != nil {
+ t.Fatalf("write session file: %v", err)
+ }
+}
+
+func TestMigrateFromJSON_Basic(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ writeJSONSession(t, sessionsDir, "test.json", jsonSession{
+ Key: "test",
+ Messages: []providers.Message{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "hi"},
+ },
+ Summary: "A greeting.",
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1 migrated, got %d", count)
+ }
+
+ history, err := store.GetHistory(ctx, "test")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(history))
+ }
+ if history[0].Content != "hello" || history[1].Content != "hi" {
+ t.Errorf("unexpected messages: %+v", history)
+ }
+
+ summary, err := store.GetSummary(ctx, "test")
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if summary != "A greeting." {
+ t.Errorf("summary = %q", summary)
+ }
+}
+
+func TestMigrateFromJSON_WithToolCalls(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ writeJSONSession(t, sessionsDir, "tools.json", jsonSession{
+ Key: "tools",
+ Messages: []providers.Message{
+ {
+ Role: "assistant",
+ Content: "Searching...",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: "web_search",
+ Arguments: `{"q":"test"}`,
+ },
+ },
+ },
+ },
+ {
+ Role: "tool",
+ Content: "result",
+ ToolCallID: "call_1",
+ },
+ },
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1, got %d", count)
+ }
+
+ history, err := store.GetHistory(ctx, "tools")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(history))
+ }
+ if len(history[0].ToolCalls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls))
+ }
+ if history[0].ToolCalls[0].Function.Name != "web_search" {
+ t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name)
+ }
+ if history[1].ToolCallID != "call_1" {
+ t.Errorf("ToolCallID = %q", history[1].ToolCallID)
+ }
+}
+
+func TestMigrateFromJSON_MultipleFiles(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ for i := 0; i < 3; i++ {
+ key := string(rune('a' + i))
+ writeJSONSession(t, sessionsDir, key+".json", jsonSession{
+ Key: key,
+ Messages: []providers.Message{{Role: "user", Content: "msg " + key}},
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+ }
+
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 3 {
+ t.Errorf("expected 3, got %d", count)
+ }
+
+ for i := 0; i < 3; i++ {
+ key := string(rune('a' + i))
+ history, histErr := store.GetHistory(ctx, key)
+ if histErr != nil {
+ t.Fatalf("GetHistory(%q): %v", key, histErr)
+ }
+ if len(history) != 1 {
+ t.Errorf("session %q: expected 1 msg, got %d", key, len(history))
+ }
+ }
+}
+
+func TestMigrateFromJSON_InvalidJSON(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // One valid, one invalid.
+ writeJSONSession(t, sessionsDir, "good.json", jsonSession{
+ Key: "good",
+ Messages: []providers.Message{{Role: "user", Content: "ok"}},
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+ err := os.WriteFile(
+ filepath.Join(sessionsDir, "bad.json"),
+ []byte("{invalid json"),
+ 0o644,
+ )
+ if err != nil {
+ t.Fatalf("write bad file: %v", err)
+ }
+
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1 (bad file skipped), got %d", count)
+ }
+
+ history, err := store.GetHistory(ctx, "good")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Errorf("expected 1 message, got %d", len(history))
+ }
+}
+
+func TestMigrateFromJSON_RenamesFiles(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ writeJSONSession(t, sessionsDir, "rename.json", jsonSession{
+ Key: "rename",
+ Messages: []providers.Message{{Role: "user", Content: "hi"}},
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ _, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+
+ // Original .json should not exist.
+ _, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json"))
+ if !os.IsNotExist(statErr) {
+ t.Error("rename.json should have been renamed")
+ }
+ // .json.migrated should exist.
+ _, statErr = os.Stat(
+ filepath.Join(sessionsDir, "rename.json.migrated"),
+ )
+ if statErr != nil {
+ t.Errorf("rename.json.migrated should exist: %v", statErr)
+ }
+}
+
+func TestMigrateFromJSON_Idempotent(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ writeJSONSession(t, sessionsDir, "idem.json", jsonSession{
+ Key: "idem",
+ Messages: []providers.Message{{Role: "user", Content: "once"}},
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ count1, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("first migration: %v", err)
+ }
+ if count1 != 1 {
+ t.Errorf("first run: expected 1, got %d", count1)
+ }
+
+ // Second run should find only .migrated files, skip them.
+ count2, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("second migration: %v", err)
+ }
+ if count2 != 0 {
+ t.Errorf("second run: expected 0, got %d", count2)
+ }
+
+ history, err := store.GetHistory(ctx, "idem")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Errorf("expected 1 message, got %d", len(history))
+ }
+}
+
+func TestMigrateFromJSON_ColonInKey(t *testing.T) {
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ // File is named telegram_123 (sanitized), but the key inside is telegram:123.
+ writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{
+ Key: "telegram:123",
+ Messages: []providers.Message{{Role: "user", Content: "from telegram"}},
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("expected 1, got %d", count)
+ }
+
+ // Accessible via the original key "telegram:123".
+ history, err := store.GetHistory(ctx, "telegram:123")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(history))
+ }
+ if history[0].Content != "from telegram" {
+ t.Errorf("content = %q", history[0].Content)
+ }
+
+ // In the file-based store, "telegram:123" and "telegram_123" both
+ // sanitize to the same filename, so they share storage. This is
+ // expected — the colon-to-underscore mapping is a one-way function.
+ history2, err := store.GetHistory(ctx, "telegram_123")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ if len(history2) != 1 {
+ t.Errorf("expected 1 (same file), got %d", len(history2))
+ }
+}
+
+func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) {
+ // Simulates a crash during migration: first run writes messages
+ // but doesn't rename the .json file. Second run must replace
+ // (not duplicate) the messages thanks to SetHistory semantics.
+ sessionsDir := t.TempDir()
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ writeJSONSession(t, sessionsDir, "retry.json", jsonSession{
+ Key: "retry",
+ Messages: []providers.Message{
+ {Role: "user", Content: "one"},
+ {Role: "assistant", Content: "two"},
+ },
+ Created: time.Now(),
+ Updated: time.Now(),
+ })
+
+ // First migration succeeds — writes messages and renames file.
+ count, err := MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("first migration: %v", err)
+ }
+ if count != 1 {
+ t.Fatalf("expected 1, got %d", count)
+ }
+
+ // Simulate "crash before rename": restore the .json file.
+ src := filepath.Join(sessionsDir, "retry.json.migrated")
+ dst := filepath.Join(sessionsDir, "retry.json")
+ if renameErr := os.Rename(src, dst); renameErr != nil {
+ t.Fatalf("restore .json: %v", renameErr)
+ }
+
+ // Second migration should re-import without duplicating messages.
+ count, err = MigrateFromJSON(ctx, sessionsDir, store)
+ if err != nil {
+ t.Fatalf("second migration: %v", err)
+ }
+ if count != 1 {
+ t.Fatalf("expected 1, got %d", count)
+ }
+
+ history, err := store.GetHistory(ctx, "retry")
+ if err != nil {
+ t.Fatalf("GetHistory: %v", err)
+ }
+ // Must be exactly 2 messages (not 4 from duplication).
+ if len(history) != 2 {
+ t.Fatalf("expected 2 messages (no duplicates), got %d", len(history))
+ }
+ if history[0].Content != "one" || history[1].Content != "two" {
+ t.Errorf("unexpected messages: %+v", history)
+ }
+}
+
+func TestMigrateFromJSON_NonexistentDir(t *testing.T) {
+ store := newTestStore(t)
+ ctx := context.Background()
+
+ count, err := MigrateFromJSON(ctx, "/nonexistent/path", store)
+ if err != nil {
+ t.Fatalf("MigrateFromJSON: %v", err)
+ }
+ if count != 0 {
+ t.Errorf("expected 0, got %d", count)
+ }
+}
diff --git a/pkg/memory/store.go b/pkg/memory/store.go
new file mode 100644
index 000000000..b6e11707d
--- /dev/null
+++ b/pkg/memory/store.go
@@ -0,0 +1,42 @@
+package memory
+
+import (
+ "context"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// Store defines an interface for persistent session storage.
+// Each method is an atomic operation — there is no separate Save() call.
+type Store interface {
+ // AddMessage appends a simple text message to a session.
+ AddMessage(ctx context.Context, sessionKey, role, content string) error
+
+ // AddFullMessage appends a complete message (with tool calls, etc.) to a session.
+ AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error
+
+ // GetHistory returns all messages for a session in insertion order.
+ // Returns an empty slice (not nil) if the session does not exist.
+ GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error)
+
+ // GetSummary returns the conversation summary for a session.
+ // Returns an empty string if no summary exists.
+ GetSummary(ctx context.Context, sessionKey string) (string, error)
+
+ // SetSummary updates the conversation summary for a session.
+ SetSummary(ctx context.Context, sessionKey, summary string) error
+
+ // TruncateHistory removes all but the last keepLast messages from a session.
+ // If keepLast <= 0, all messages are removed.
+ TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error
+
+ // SetHistory replaces all messages in a session with the provided history.
+ SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error
+
+ // Compact reclaims storage by physically removing logically truncated
+ // data. Backends that do not accumulate dead data may return nil.
+ Compact(ctx context.Context, sessionKey string) error
+
+ // Close releases any resources held by the store.
+ Close() error
+}
diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go
deleted file mode 100644
index 869b39827..000000000
--- a/pkg/migrate/config.go
+++ /dev/null
@@ -1,408 +0,0 @@
-package migrate
-
-import (
- "encoding/json"
- "fmt"
- "os"
- "path/filepath"
- "strings"
- "unicode"
-
- "github.com/sipeed/picoclaw/pkg/config"
-)
-
-var supportedProviders = map[string]bool{
- "anthropic": true,
- "openai": true,
- "openrouter": true,
- "groq": true,
- "zhipu": true,
- "vllm": true,
- "gemini": true,
- "qwen": true,
- "deepseek": true,
- "github_copilot": true,
- "mistral": true,
-}
-
-var supportedChannels = map[string]bool{
- "telegram": true,
- "discord": true,
- "whatsapp": true,
- "feishu": true,
- "qq": true,
- "dingtalk": true,
- "maixcam": true,
-}
-
-func findOpenClawConfig(openclawHome string) (string, error) {
- candidates := []string{
- filepath.Join(openclawHome, "openclaw.json"),
- filepath.Join(openclawHome, "config.json"),
- }
- for _, p := range candidates {
- if _, err := os.Stat(p); err == nil {
- return p, nil
- }
- }
- return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
-}
-
-func LoadOpenClawConfig(configPath string) (map[string]any, error) {
- data, err := os.ReadFile(configPath)
- if err != nil {
- return nil, fmt.Errorf("reading OpenClaw config: %w", err)
- }
-
- var raw map[string]any
- if err := json.Unmarshal(data, &raw); err != nil {
- return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
- }
-
- converted := convertKeysToSnake(raw)
- result, ok := converted.(map[string]any)
- if !ok {
- return nil, fmt.Errorf("unexpected config format")
- }
- return result, nil
-}
-
-func ConvertConfig(data map[string]any) (*config.Config, []string, error) {
- cfg := config.DefaultConfig()
- var warnings []string
-
- if agents, ok := getMap(data, "agents"); ok {
- if defaults, ok := getMap(agents, "defaults"); ok {
- // Prefer model_name, fallback to model for backward compatibility
- if v, ok := getString(defaults, "model_name"); ok {
- cfg.Agents.Defaults.ModelName = v
- } else if v, ok := getString(defaults, "model"); ok {
- cfg.Agents.Defaults.Model = v
- }
- if v, ok := getFloat(defaults, "max_tokens"); ok {
- cfg.Agents.Defaults.MaxTokens = int(v)
- }
- if v, ok := getFloat(defaults, "temperature"); ok {
- cfg.Agents.Defaults.Temperature = &v
- }
- if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
- cfg.Agents.Defaults.MaxToolIterations = int(v)
- }
- if v, ok := getString(defaults, "workspace"); ok {
- cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
- }
- }
- }
-
- if providers, ok := getMap(data, "providers"); ok {
- for name, val := range providers {
- pMap, ok := val.(map[string]any)
- if !ok {
- continue
- }
- apiKey, _ := getString(pMap, "api_key")
- apiBase, _ := getString(pMap, "api_base")
-
- if !supportedProviders[name] {
- if apiKey != "" || apiBase != "" {
- warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
- }
- continue
- }
-
- pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
- switch name {
- case "anthropic":
- cfg.Providers.Anthropic = pc
- case "openai":
- cfg.Providers.OpenAI = config.OpenAIProviderConfig{
- ProviderConfig: pc,
- WebSearch: getBoolOrDefault(pMap, "web_search", true),
- }
- case "openrouter":
- cfg.Providers.OpenRouter = pc
- case "groq":
- cfg.Providers.Groq = pc
- case "zhipu":
- cfg.Providers.Zhipu = pc
- case "vllm":
- cfg.Providers.VLLM = pc
- case "gemini":
- cfg.Providers.Gemini = pc
- }
- }
- }
-
- if channels, ok := getMap(data, "channels"); ok {
- for name, val := range channels {
- cMap, ok := val.(map[string]any)
- if !ok {
- continue
- }
- if !supportedChannels[name] {
- warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
- continue
- }
- enabled, _ := getBool(cMap, "enabled")
- allowFrom := getStringSlice(cMap, "allow_from")
-
- switch name {
- case "telegram":
- cfg.Channels.Telegram.Enabled = enabled
- cfg.Channels.Telegram.AllowFrom = allowFrom
- if v, ok := getString(cMap, "token"); ok {
- cfg.Channels.Telegram.Token = v
- }
- case "discord":
- cfg.Channels.Discord.Enabled = enabled
- cfg.Channels.Discord.AllowFrom = allowFrom
- if v, ok := getString(cMap, "token"); ok {
- cfg.Channels.Discord.Token = v
- }
- case "whatsapp":
- cfg.Channels.WhatsApp.Enabled = enabled
- cfg.Channels.WhatsApp.AllowFrom = allowFrom
- if v, ok := getString(cMap, "bridge_url"); ok {
- cfg.Channels.WhatsApp.BridgeURL = v
- }
- case "feishu":
- cfg.Channels.Feishu.Enabled = enabled
- cfg.Channels.Feishu.AllowFrom = allowFrom
- if v, ok := getString(cMap, "app_id"); ok {
- cfg.Channels.Feishu.AppID = v
- }
- if v, ok := getString(cMap, "app_secret"); ok {
- cfg.Channels.Feishu.AppSecret = v
- }
- if v, ok := getString(cMap, "encrypt_key"); ok {
- cfg.Channels.Feishu.EncryptKey = v
- }
- if v, ok := getString(cMap, "verification_token"); ok {
- cfg.Channels.Feishu.VerificationToken = v
- }
- case "qq":
- cfg.Channels.QQ.Enabled = enabled
- cfg.Channels.QQ.AllowFrom = allowFrom
- if v, ok := getString(cMap, "app_id"); ok {
- cfg.Channels.QQ.AppID = v
- }
- if v, ok := getString(cMap, "app_secret"); ok {
- cfg.Channels.QQ.AppSecret = v
- }
- case "dingtalk":
- cfg.Channels.DingTalk.Enabled = enabled
- cfg.Channels.DingTalk.AllowFrom = allowFrom
- if v, ok := getString(cMap, "client_id"); ok {
- cfg.Channels.DingTalk.ClientID = v
- }
- if v, ok := getString(cMap, "client_secret"); ok {
- cfg.Channels.DingTalk.ClientSecret = v
- }
- case "maixcam":
- cfg.Channels.MaixCam.Enabled = enabled
- cfg.Channels.MaixCam.AllowFrom = allowFrom
- if v, ok := getString(cMap, "host"); ok {
- cfg.Channels.MaixCam.Host = v
- }
- if v, ok := getFloat(cMap, "port"); ok {
- cfg.Channels.MaixCam.Port = int(v)
- }
- }
- }
- }
-
- if gateway, ok := getMap(data, "gateway"); ok {
- if v, ok := getString(gateway, "host"); ok {
- cfg.Gateway.Host = v
- }
- if v, ok := getFloat(gateway, "port"); ok {
- cfg.Gateway.Port = int(v)
- }
- }
-
- if tools, ok := getMap(data, "tools"); ok {
- if web, ok := getMap(tools, "web"); ok {
- // Migrate old "search" config to "brave" if api_key is present
- if search, ok := getMap(web, "search"); ok {
- if v, ok := getString(search, "api_key"); ok {
- cfg.Tools.Web.Brave.APIKey = v
- if v != "" {
- cfg.Tools.Web.Brave.Enabled = true
- }
- }
- if v, ok := getFloat(search, "max_results"); ok {
- cfg.Tools.Web.Brave.MaxResults = int(v)
- cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
- }
- }
- }
- }
-
- return cfg, warnings, nil
-}
-
-func MergeConfig(existing, incoming *config.Config) *config.Config {
- if existing.Providers.Anthropic.APIKey == "" {
- existing.Providers.Anthropic = incoming.Providers.Anthropic
- }
- if existing.Providers.OpenAI.APIKey == "" {
- existing.Providers.OpenAI = incoming.Providers.OpenAI
- }
- if existing.Providers.OpenRouter.APIKey == "" {
- existing.Providers.OpenRouter = incoming.Providers.OpenRouter
- }
- if existing.Providers.Groq.APIKey == "" {
- existing.Providers.Groq = incoming.Providers.Groq
- }
- if existing.Providers.Zhipu.APIKey == "" {
- existing.Providers.Zhipu = incoming.Providers.Zhipu
- }
- if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
- existing.Providers.VLLM = incoming.Providers.VLLM
- }
- if existing.Providers.Gemini.APIKey == "" {
- existing.Providers.Gemini = incoming.Providers.Gemini
- }
- if existing.Providers.DeepSeek.APIKey == "" {
- existing.Providers.DeepSeek = incoming.Providers.DeepSeek
- }
- if existing.Providers.GitHubCopilot.APIBase == "" {
- existing.Providers.GitHubCopilot = incoming.Providers.GitHubCopilot
- }
- if existing.Providers.Qwen.APIKey == "" {
- existing.Providers.Qwen = incoming.Providers.Qwen
- }
-
- if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
- existing.Channels.Telegram = incoming.Channels.Telegram
- }
- if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
- existing.Channels.Discord = incoming.Channels.Discord
- }
- if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
- existing.Channels.WhatsApp = incoming.Channels.WhatsApp
- }
- if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
- existing.Channels.Feishu = incoming.Channels.Feishu
- }
- if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
- existing.Channels.QQ = incoming.Channels.QQ
- }
- if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
- existing.Channels.DingTalk = incoming.Channels.DingTalk
- }
- if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
- existing.Channels.MaixCam = incoming.Channels.MaixCam
- }
-
- if existing.Tools.Web.Brave.APIKey == "" {
- existing.Tools.Web.Brave = incoming.Tools.Web.Brave
- }
-
- return existing
-}
-
-func camelToSnake(s string) string {
- var result strings.Builder
- for i, r := range s {
- if unicode.IsUpper(r) {
- if i > 0 {
- prev := rune(s[i-1])
- if unicode.IsLower(prev) || unicode.IsDigit(prev) {
- result.WriteRune('_')
- } else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
- result.WriteRune('_')
- }
- }
- result.WriteRune(unicode.ToLower(r))
- } else {
- result.WriteRune(r)
- }
- }
- return result.String()
-}
-
-func convertKeysToSnake(data any) any {
- switch v := data.(type) {
- case map[string]any:
- result := make(map[string]any, len(v))
- for key, val := range v {
- result[camelToSnake(key)] = convertKeysToSnake(val)
- }
- return result
- case []any:
- result := make([]any, len(v))
- for i, val := range v {
- result[i] = convertKeysToSnake(val)
- }
- return result
- default:
- return data
- }
-}
-
-func rewriteWorkspacePath(path string) string {
- path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
- return path
-}
-
-func getMap(data map[string]any, key string) (map[string]any, bool) {
- v, ok := data[key]
- if !ok {
- return nil, false
- }
- m, ok := v.(map[string]any)
- return m, ok
-}
-
-func getString(data map[string]any, key string) (string, bool) {
- v, ok := data[key]
- if !ok {
- return "", false
- }
- s, ok := v.(string)
- return s, ok
-}
-
-func getFloat(data map[string]any, key string) (float64, bool) {
- v, ok := data[key]
- if !ok {
- return 0, false
- }
- f, ok := v.(float64)
- return f, ok
-}
-
-func getBool(data map[string]any, key string) (bool, bool) {
- v, ok := data[key]
- if !ok {
- return false, false
- }
- b, ok := v.(bool)
- return b, ok
-}
-
-func getBoolOrDefault(data map[string]any, key string, defaultVal bool) bool {
- if v, ok := getBool(data, key); ok {
- return v
- }
- return defaultVal
-}
-
-func getStringSlice(data map[string]any, key string) []string {
- v, ok := data[key]
- if !ok {
- return []string{}
- }
- arr, ok := v.([]any)
- if !ok {
- return []string{}
- }
- result := make([]string, 0, len(arr))
- for _, item := range arr {
- if s, ok := item.(string); ok {
- result = append(result, s)
- }
- }
- return result
-}
diff --git a/pkg/migrate/workspace.go b/pkg/migrate/internal/common.go
similarity index 55%
rename from pkg/migrate/workspace.go
rename to pkg/migrate/internal/common.go
index f45748fac..c77ab9f26 100644
--- a/pkg/migrate/workspace.go
+++ b/pkg/migrate/internal/common.go
@@ -1,24 +1,50 @@
-package migrate
+package internal
import (
+ "fmt"
+ "io"
"os"
"path/filepath"
)
-var migrateableFiles = []string{
- "AGENTS.md",
- "SOUL.md",
- "USER.md",
- "TOOLS.md",
- "HEARTBEAT.md",
+func ResolveTargetHome(override string) (string, error) {
+ if override != "" {
+ return ExpandHome(override), nil
+ }
+ if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
+ return ExpandHome(envHome), nil
+ }
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", fmt.Errorf("resolving home directory: %w", err)
+ }
+ return filepath.Join(home, ".picoclaw"), nil
}
-var migrateableDirs = []string{
- "memory",
- "skills",
+func ExpandHome(path string) string {
+ if path == "" {
+ return path
+ }
+ if path[0] == '~' {
+ home, _ := os.UserHomeDir()
+ if len(path) > 1 && path[1] == '/' {
+ return home + path[1:]
+ }
+ return home
+ }
+ return path
}
-func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
+func ResolveWorkspace(homeDir string) string {
+ return filepath.Join(homeDir, "workspace")
+}
+
+func PlanWorkspaceMigration(
+ srcWorkspace, dstWorkspace string,
+ migrateableFiles []string,
+ migrateableDirs []string,
+ force bool,
+) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
@@ -50,7 +76,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionSkip,
Source: src,
- Destination: dst,
+ Target: dst,
Description: "source file not found",
}
}
@@ -60,7 +86,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionBackup,
Source: src,
- Destination: dst,
+ Target: dst,
Description: "destination exists, will backup and overwrite",
}
}
@@ -68,7 +94,7 @@ func planFileCopy(src, dst string, force bool) Action {
return Action{
Type: ActionCopy,
Source: src,
- Destination: dst,
+ Target: dst,
Description: "copy file",
}
}
@@ -91,7 +117,7 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
- Destination: dst,
+ Target: dst,
Description: "create directory",
})
return nil
@@ -104,3 +130,33 @@ func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
return actions, err
}
+
+func RelPath(path, base string) string {
+ rel, err := filepath.Rel(base, path)
+ if err != nil {
+ return filepath.Base(path)
+ }
+ return rel
+}
+
+func CopyFile(src, dst string) error {
+ srcFile, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer srcFile.Close()
+
+ info, err := srcFile.Stat()
+ if err != nil {
+ return err
+ }
+
+ dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
+ if err != nil {
+ return err
+ }
+ defer dstFile.Close()
+
+ _, err = io.Copy(dstFile, srcFile)
+ return err
+}
diff --git a/pkg/migrate/internal/common_test.go b/pkg/migrate/internal/common_test.go
new file mode 100644
index 000000000..a67293c19
--- /dev/null
+++ b/pkg/migrate/internal/common_test.go
@@ -0,0 +1,186 @@
+package internal
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestExpandHome(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"", ""},
+ {"/absolute/path", "/absolute/path"},
+ {"relative/path", "relative/path"},
+ }
+
+ for _, tt := range tests {
+ result := ExpandHome(tt.input)
+ assert.Equal(t, tt.expected, result)
+ }
+}
+
+func TestExpandHomeWithTilde(t *testing.T) {
+ home, err := os.UserHomeDir()
+ require.NoError(t, err)
+
+ result := ExpandHome("~/path")
+ assert.Equal(t, home+"/path", result)
+
+ result = ExpandHome("~")
+ assert.Equal(t, home, result)
+}
+
+func TestResolveWorkspace(t *testing.T) {
+ result := ResolveWorkspace("/home/user/.picoclaw")
+ assert.Equal(t, "/home/user/.picoclaw/workspace", result)
+}
+
+func TestRelPath(t *testing.T) {
+ result := RelPath("/home/user/.picoclaw/workspace/file.txt", "/home/user/.picoclaw")
+ assert.Equal(t, "workspace/file.txt", result)
+}
+
+func TestRelPathError(t *testing.T) {
+ result := RelPath("relative/path", "/different/base")
+ assert.Equal(t, "path", result)
+}
+
+func TestResolveTargetHome(t *testing.T) {
+ home, err := os.UserHomeDir()
+ require.NoError(t, err)
+
+ result, err := ResolveTargetHome("")
+ require.NoError(t, err)
+ assert.Equal(t, filepath.Join(home, ".picoclaw"), result)
+}
+
+func TestResolveTargetHomeWithOverride(t *testing.T) {
+ result, err := ResolveTargetHome("/custom/path")
+ require.NoError(t, err)
+ assert.Equal(t, "/custom/path", result)
+}
+
+func TestCopyFile(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ sourceFile := filepath.Join(tmpDir, "source.txt")
+ err := os.WriteFile(sourceFile, []byte("test content"), 0o644)
+ require.NoError(t, err)
+
+ dstFile := filepath.Join(tmpDir, "dest.txt")
+ err = CopyFile(sourceFile, dstFile)
+ require.NoError(t, err)
+
+ content, err := os.ReadFile(dstFile)
+ require.NoError(t, err)
+ assert.Equal(t, "test content", string(content))
+}
+
+func TestCopyFileSourceNotFound(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ err := CopyFile(filepath.Join(tmpDir, "nonexistent.txt"), filepath.Join(tmpDir, "dest.txt"))
+ require.Error(t, err)
+}
+
+func TestPlanWorkspaceMigration(t *testing.T) {
+ tmpDir := t.TempDir()
+ srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
+ dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
+
+ err := os.MkdirAll(srcWorkspace, 0o755)
+ require.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("content"), 0o644)
+ require.NoError(t, err)
+
+ err = os.MkdirAll(filepath.Join(srcWorkspace, "subdir"), 0o755)
+ require.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(srcWorkspace, "subdir", "file2.txt"), []byte("content"), 0o644)
+ require.NoError(t, err)
+
+ actions, err := PlanWorkspaceMigration(
+ srcWorkspace,
+ dstWorkspace,
+ []string{"file1.txt"},
+ []string{"subdir"},
+ false,
+ )
+ require.NoError(t, err)
+
+ assert.GreaterOrEqual(t, len(actions), 1)
+}
+
+func TestPlanWorkspaceMigrationExistingFile(t *testing.T) {
+ tests := []struct {
+ name string
+ force bool
+ wantActionType ActionType
+ }{
+ {
+ name: "backup when not forced",
+ force: false,
+ wantActionType: ActionBackup,
+ },
+ {
+ name: "copy when forced",
+ force: true,
+ wantActionType: ActionCopy,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ srcWorkspace := filepath.Join(tmpDir, "src", "workspace")
+ dstWorkspace := filepath.Join(tmpDir, "dst", "workspace")
+
+ err := os.MkdirAll(srcWorkspace, 0o755)
+ require.NoError(t, err)
+
+ err = os.MkdirAll(dstWorkspace, 0o755)
+ require.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644)
+ require.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644)
+ require.NoError(t, err)
+
+ actions, err := PlanWorkspaceMigration(
+ srcWorkspace,
+ dstWorkspace,
+ []string{"file1.txt"},
+ []string{},
+ tt.force,
+ )
+ require.NoError(t, err)
+
+ require.GreaterOrEqual(t, len(actions), 1)
+ assert.Equal(t, tt.wantActionType, actions[0].Type)
+ })
+ }
+}
+
+func TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ actions, err := PlanWorkspaceMigration(
+ filepath.Join(tmpDir, "nonexistent"),
+ filepath.Join(tmpDir, "dst", "workspace"),
+ []string{"file1.txt"},
+ []string{},
+ false,
+ )
+ require.NoError(t, err)
+ require.Len(t, actions, 1)
+ assert.Equal(t, ActionSkip, actions[0].Type)
+ assert.Contains(t, actions[0].Description, "source file not found")
+}
diff --git a/pkg/migrate/internal/types.go b/pkg/migrate/internal/types.go
new file mode 100644
index 000000000..e86a4dea1
--- /dev/null
+++ b/pkg/migrate/internal/types.go
@@ -0,0 +1,52 @@
+package internal
+
+type Options struct {
+ DryRun bool
+ ConfigOnly bool
+ WorkspaceOnly bool
+ Force bool
+ Refresh bool
+ Source string
+ SourceHome string
+ TargetHome string
+}
+
+type Operation interface {
+ GetSourceName() string
+ GetSourceHome() (string, error)
+ GetSourceWorkspace() (string, error)
+ GetSourceConfigFile() (string, error)
+ ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error
+ GetMigrateableFiles() []string
+ GetMigrateableDirs() []string
+}
+
+type HandlerFactory func(opts Options) Operation
+
+type ActionType int
+
+const (
+ ActionCopy ActionType = iota
+ ActionSkip
+ ActionBackup
+ ActionConvertConfig
+ ActionCreateDir
+ ActionMergeConfig
+)
+
+type Action struct {
+ Type ActionType
+ Source string
+ Target string
+ Description string
+}
+
+type Result struct {
+ FilesCopied int
+ FilesSkipped int
+ BackupsCreated int
+ ConfigMigrated bool
+ DirsCreated int
+ Warnings []string
+ Errors []error
+}
diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go
index cfa82b7d7..51fecf438 100644
--- a/pkg/migrate/migrate.go
+++ b/pkg/migrate/migrate.go
@@ -2,53 +2,73 @@ package migrate
import (
"fmt"
- "io"
"os"
"path/filepath"
"strings"
- "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/migrate/internal"
+ "github.com/sipeed/picoclaw/pkg/migrate/sources/openclaw"
)
-type ActionType int
+type (
+ Options = internal.Options
+ Operation = internal.Operation
+ ActionType = internal.ActionType
+ Action = internal.Action
+ Result = internal.Result
+ HandlerFactory = internal.HandlerFactory
+)
const (
- ActionCopy ActionType = iota
- ActionSkip
- ActionBackup
- ActionConvertConfig
- ActionCreateDir
- ActionMergeConfig
+ ActionCopy = internal.ActionCopy
+ ActionSkip = internal.ActionSkip
+ ActionBackup = internal.ActionBackup
+ ActionConvertConfig = internal.ActionConvertConfig
+ ActionCreateDir = internal.ActionCreateDir
+ ActionMergeConfig = internal.ActionMergeConfig
)
-type Options struct {
- DryRun bool
- ConfigOnly bool
- WorkspaceOnly bool
- Force bool
- Refresh bool
- OpenClawHome string
- PicoClawHome string
+type MigrateInstance struct {
+ options Options
+ handlers map[string]Operation
}
-type Action struct {
- Type ActionType
- Source string
- Destination string
- Description string
+func NewMigrateInstance(opts Options) *MigrateInstance {
+ instance := &MigrateInstance{
+ options: opts,
+ handlers: make(map[string]Operation),
+ }
+
+ openclaw_handler, err := openclaw.NewOpenclawHandler(opts)
+ if err == nil {
+ instance.Register(openclaw_handler.GetSourceName(), openclaw_handler)
+ }
+
+ return instance
}
-type Result struct {
- FilesCopied int
- FilesSkipped int
- BackupsCreated int
- ConfigMigrated bool
- DirsCreated int
- Warnings []string
- Errors []error
+func (m *MigrateInstance) Register(moduleName string, module Operation) {
+ m.handlers[moduleName] = module
}
-func Run(opts Options) (*Result, error) {
+func (m *MigrateInstance) getCurrentHandler() (Operation, error) {
+ source := m.options.Source
+ if source == "" {
+ source = "openclaw"
+ }
+ handler, ok := m.handlers[source]
+ if !ok {
+ return nil, fmt.Errorf("Source '%s' not found", source)
+ }
+ return handler, nil
+}
+
+func (m *MigrateInstance) Run(opts Options) (*Result, error) {
+ handler, err := m.getCurrentHandler()
+ if err != nil {
+ return nil, err
+ }
+
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
@@ -57,28 +77,28 @@ func Run(opts Options) (*Result, error) {
opts.WorkspaceOnly = true
}
- openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
+ sourceHome, err := handler.GetSourceHome()
if err != nil {
return nil, err
}
- picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
+ targetHome, err := internal.ResolveTargetHome(opts.TargetHome)
if err != nil {
return nil, err
}
- if _, err = os.Stat(openclawHome); os.IsNotExist(err) {
- return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
+ if _, err = os.Stat(sourceHome); os.IsNotExist(err) {
+ return nil, fmt.Errorf("Source installation not found at %s", sourceHome)
}
- actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
+ actions, warnings, err := m.Plan(opts, sourceHome, targetHome)
if err != nil {
return nil, err
}
- fmt.Println("Migrating from OpenClaw to PicoClaw")
- fmt.Printf(" Source: %s\n", openclawHome)
- fmt.Printf(" Destination: %s\n", picoClawHome)
+ fmt.Println("Migrating from Source to PicoClaw")
+ fmt.Printf(" Source: %s\n", sourceHome)
+ fmt.Printf(" Target: %s\n", targetHome)
fmt.Println()
if opts.DryRun {
@@ -95,19 +115,23 @@ func Run(opts Options) (*Result, error) {
fmt.Println()
}
- result := Execute(actions, openclawHome, picoClawHome)
+ result := m.Execute(actions, sourceHome, targetHome)
result.Warnings = warnings
return result, nil
}
-func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
+func (m *MigrateInstance) Plan(opts Options, sourceHome, targetHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
+ handler, err := m.getCurrentHandler()
+ if err != nil {
+ return nil, nil, err
+ }
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
- configPath, err := findOpenClawConfig(openclawHome)
+ configPath, err := handler.GetSourceConfigFile()
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
@@ -117,91 +141,95 @@ func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string,
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
- Destination: filepath.Join(picoClawHome, "config.json"),
- Description: "convert OpenClaw config to PicoClaw format",
+ Target: filepath.Join(targetHome, "config.json"),
+ Description: "convert Source config to PicoClaw format",
})
-
- data, err := LoadOpenClawConfig(configPath)
- if err == nil {
- _, configWarnings, _ := ConvertConfig(data)
- warnings = append(warnings, configWarnings...)
- }
}
}
if !opts.ConfigOnly {
- srcWorkspace := resolveWorkspace(openclawHome)
- dstWorkspace := resolveWorkspace(picoClawHome)
+ srcWorkspace, err := handler.GetSourceWorkspace()
+ if err != nil {
+ return nil, nil, fmt.Errorf("getting source workspace: %w", err)
+ }
+ dstWorkspace := internal.ResolveWorkspace(targetHome)
if _, err := os.Stat(srcWorkspace); err == nil {
- wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
+ wsActions, err := internal.PlanWorkspaceMigration(srcWorkspace, dstWorkspace,
+ handler.GetMigrateableFiles(),
+ handler.GetMigrateableDirs(),
+ force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
- warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
+ warnings = append(warnings, "Source workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
-func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
+func (m *MigrateInstance) Execute(actions []Action, sourceHome, targetHome string) *Result {
result := &Result{}
+ handler, err := m.getCurrentHandler()
+ if err != nil {
+ return result
+ }
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
- if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
+ if err := handler.ExecuteConfigMigration(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
- fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
+ fmt.Printf(" ✓ Converted config: %s\n", action.Target)
}
case ActionCreateDir:
- if err := os.MkdirAll(action.Destination, 0o755); err != nil {
+ if err := os.MkdirAll(action.Target, 0o755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
- bakPath := action.Destination + ".bak"
- if err := copyFile(action.Destination, bakPath); err != nil {
- result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
- fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
+ bakPath := action.Target + ".bak"
+ if err := internal.CopyFile(action.Target, bakPath); err != nil {
+ result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Target, err))
+ fmt.Printf(" ✗ Backup failed: %s\n", action.Target)
continue
}
result.BackupsCreated++
fmt.Printf(
" ✓ Backed up %s -> %s.bak\n",
- filepath.Base(action.Destination),
- filepath.Base(action.Destination),
+ filepath.Base(action.Target),
+ filepath.Base(action.Target),
)
- if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
- if err := copyFile(action.Source, action.Destination); err != nil {
+ if err := internal.CopyFile(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
- fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
+ fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
}
case ActionCopy:
- if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(action.Target), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
- if err := copyFile(action.Source, action.Destination); err != nil {
+ if err := internal.CopyFile(action.Source, action.Target); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
- fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
+ fmt.Printf(" ✓ Copied %s\n", internal.RelPath(action.Source, sourceHome))
}
case ActionSkip:
result.FilesSkipped++
@@ -211,31 +239,6 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
return result
}
-func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
- data, err := LoadOpenClawConfig(srcConfigPath)
- if err != nil {
- return err
- }
-
- incoming, _, err := ConvertConfig(data)
- if err != nil {
- return err
- }
-
- if _, err := os.Stat(dstConfigPath); err == nil {
- existing, err := config.LoadConfig(dstConfigPath)
- if err != nil {
- return fmt.Errorf("loading existing PicoClaw config: %w", err)
- }
- incoming = MergeConfig(existing, incoming)
- }
-
- if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
- return err
- }
- return config.SaveConfig(dstConfigPath, incoming)
-}
-
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
@@ -243,49 +246,7 @@ func Confirm() bool {
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
-func PrintPlan(actions []Action, warnings []string) {
- fmt.Println("Planned actions:")
- copies := 0
- skips := 0
- backups := 0
- configCount := 0
-
- for _, action := range actions {
- switch action.Type {
- case ActionConvertConfig:
- fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
- configCount++
- case ActionCopy:
- fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
- copies++
- case ActionBackup:
- fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
- backups++
- copies++
- case ActionSkip:
- if action.Description != "" {
- fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
- }
- skips++
- case ActionCreateDir:
- fmt.Printf(" [mkdir] %s\n", action.Destination)
- }
- }
-
- if len(warnings) > 0 {
- fmt.Println()
- fmt.Println("Warnings:")
- for _, w := range warnings {
- fmt.Printf(" - %s\n", w)
- }
- }
-
- fmt.Println()
- fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
- copies, configCount, backups, skips)
-}
-
-func PrintSummary(result *Result) {
+func (m *MigrateInstance) PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
@@ -316,83 +277,44 @@ func PrintSummary(result *Result) {
}
}
-func resolveOpenClawHome(override string) (string, error) {
- if override != "" {
- return expandHome(override), nil
- }
- if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
- return expandHome(envHome), nil
- }
- home, err := os.UserHomeDir()
- if err != nil {
- return "", fmt.Errorf("resolving home directory: %w", err)
- }
- return filepath.Join(home, ".openclaw"), nil
-}
+func PrintPlan(actions []Action, warnings []string) {
+ fmt.Println("Planned actions:")
+ copies := 0
+ skips := 0
+ backups := 0
+ configCount := 0
-func resolvePicoClawHome(override string) (string, error) {
- if override != "" {
- return expandHome(override), nil
- }
- if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
- return expandHome(envHome), nil
- }
- home, err := os.UserHomeDir()
- if err != nil {
- return "", fmt.Errorf("resolving home directory: %w", err)
- }
- return filepath.Join(home, ".picoclaw"), nil
-}
-
-func resolveWorkspace(homeDir string) string {
- return filepath.Join(homeDir, "workspace")
-}
-
-func expandHome(path string) string {
- if path == "" {
- return path
- }
- if path[0] == '~' {
- home, _ := os.UserHomeDir()
- if len(path) > 1 && path[1] == '/' {
- return home + path[1:]
+ for _, action := range actions {
+ switch action.Type {
+ case ActionConvertConfig:
+ fmt.Printf(" [config] %s -> %s\n", action.Source, action.Target)
+ configCount++
+ case ActionCopy:
+ fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
+ copies++
+ case ActionBackup:
+ fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Target))
+ backups++
+ copies++
+ case ActionSkip:
+ if action.Description != "" {
+ fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
+ }
+ skips++
+ case ActionCreateDir:
+ fmt.Printf(" [mkdir] %s\n", action.Target)
}
- return home
}
- return path
-}
-
-func backupFile(path string) error {
- bakPath := path + ".bak"
- return copyFile(path, bakPath)
-}
-
-func copyFile(src, dst string) error {
- srcFile, err := os.Open(src)
- if err != nil {
- return err
- }
- defer srcFile.Close()
-
- info, err := srcFile.Stat()
- if err != nil {
- return err
- }
-
- dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
- if err != nil {
- return err
- }
- defer dstFile.Close()
-
- _, err = io.Copy(dstFile, srcFile)
- return err
-}
-
-func relPath(path, base string) string {
- rel, err := filepath.Rel(base, path)
- if err != nil {
- return filepath.Base(path)
- }
- return rel
+
+ if len(warnings) > 0 {
+ fmt.Println()
+ fmt.Println("Warnings:")
+ for _, w := range warnings {
+ fmt.Printf(" - %s\n", w)
+ }
+ }
+
+ fmt.Println()
+ fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
+ copies, configCount, backups, skips)
}
diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go
index b6b3d70aa..fc9c2c3a7 100644
--- a/pkg/migrate/migrate_test.go
+++ b/pkg/migrate/migrate_test.go
@@ -1,875 +1,411 @@
package migrate
import (
- "encoding/json"
"os"
"path/filepath"
"testing"
- "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-func TestCamelToSnake(t *testing.T) {
- tests := []struct {
- name string
- input string
- want string
- }{
- {"simple", "apiKey", "api_key"},
- {"two words", "apiBase", "api_base"},
- {"three words", "maxToolIterations", "max_tool_iterations"},
- {"already snake", "api_key", "api_key"},
- {"single word", "enabled", "enabled"},
- {"all lower", "model", "model"},
- {"consecutive caps", "apiURL", "api_url"},
- {"starts upper", "Model", "model"},
- {"bridge url", "bridgeUrl", "bridge_url"},
- {"client id", "clientId", "client_id"},
- {"app secret", "appSecret", "app_secret"},
- {"verification token", "verificationToken", "verification_token"},
- {"allow from", "allowFrom", "allow_from"},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := camelToSnake(tt.input)
- if got != tt.want {
- t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
- }
- })
+func TestNewMigrateInstance(t *testing.T) {
+ opts := Options{
+ Source: "openclaw",
}
+ instance := NewMigrateInstance(opts)
+ require.NotNil(t, instance)
+ assert.Equal(t, "openclaw", instance.options.Source)
}
-func TestConvertKeysToSnake(t *testing.T) {
- input := map[string]any{
- "apiKey": "test-key",
- "apiBase": "https://example.com",
- "nested": map[string]any{
- "maxTokens": float64(8192),
- "allowFrom": []any{"user1", "user2"},
- "deeperLevel": map[string]any{
- "clientId": "abc",
- },
- },
- }
+func TestMigrateInstanceRegister(t *testing.T) {
+ instance := NewMigrateInstance(Options{})
+ require.NotNil(t, instance)
- result := convertKeysToSnake(input)
- m, ok := result.(map[string]any)
- if !ok {
- t.Fatal("expected map[string]interface{}")
- }
+ mockHandler := &mockOperation{}
+ instance.Register("test-source", mockHandler)
- if _, ok = m["api_key"]; !ok {
- t.Error("expected key 'api_key' after conversion")
- }
- if _, ok = m["api_base"]; !ok {
- t.Error("expected key 'api_base' after conversion")
- }
-
- nested, ok := m["nested"].(map[string]any)
- if !ok {
- t.Fatal("expected nested map")
- }
- if _, ok = nested["max_tokens"]; !ok {
- t.Error("expected key 'max_tokens' in nested map")
- }
- if _, ok = nested["allow_from"]; !ok {
- t.Error("expected key 'allow_from' in nested map")
- }
-
- deeper, ok := nested["deeper_level"].(map[string]any)
- if !ok {
- t.Fatal("expected deeper_level map")
- }
- if _, ok := deeper["client_id"]; !ok {
- t.Error("expected key 'client_id' in deeper level")
- }
+ handler, ok := instance.handlers["test-source"]
+ require.True(t, ok)
+ assert.Equal(t, mockHandler, handler)
}
-func TestLoadOpenClawConfig(t *testing.T) {
+func TestMigrateInstanceGetCurrentHandler(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
- openclawConfig := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "apiKey": "sk-ant-test123",
- "apiBase": "https://api.anthropic.com",
- },
- },
- "agents": map[string]any{
- "defaults": map[string]any{
- "maxTokens": float64(4096),
- "model": "claude-3-opus",
- },
- },
- }
+ instance := NewMigrateInstance(Options{SourceHome: tmpDir})
+ require.NotNil(t, instance)
- data, err := json.Marshal(openclawConfig)
- if err != nil {
- t.Fatal(err)
- }
- if err = os.WriteFile(configPath, data, 0o644); err != nil {
- t.Fatal(err)
- }
-
- result, err := LoadOpenClawConfig(configPath)
- if err != nil {
- t.Fatalf("LoadOpenClawConfig: %v", err)
- }
-
- providers, ok := result["providers"].(map[string]any)
- if !ok {
- t.Fatal("expected providers map")
- }
- anthropic, ok := providers["anthropic"].(map[string]any)
- if !ok {
- t.Fatal("expected anthropic map")
- }
- if anthropic["api_key"] != "sk-ant-test123" {
- t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
- }
-
- agents, ok := result["agents"].(map[string]any)
- if !ok {
- t.Fatal("expected agents map")
- }
- defaults, ok := agents["defaults"].(map[string]any)
- if !ok {
- t.Fatal("expected defaults map")
- }
- if defaults["max_tokens"] != float64(4096) {
- t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
- }
+ handler, err := instance.getCurrentHandler()
+ require.NoError(t, err)
+ require.NotNil(t, handler)
+ assert.Equal(t, "openclaw", handler.GetSourceName())
}
-func TestConvertConfig(t *testing.T) {
- t.Run("providers mapping", func(t *testing.T) {
- data := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "api_key": "sk-ant-test",
- "api_base": "https://api.anthropic.com",
- },
- "openrouter": map[string]any{
- "api_key": "sk-or-test",
- },
- "groq": map[string]any{
- "api_key": "gsk-test",
- },
- },
- }
-
- cfg, warnings, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if len(warnings) != 0 {
- t.Errorf("expected no warnings, got %v", warnings)
- }
- if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
- t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
- }
- if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
- t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
- }
- if cfg.Providers.Groq.APIKey != "gsk-test" {
- t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
- }
- })
-
- t.Run("unsupported provider warning", func(t *testing.T) {
- data := map[string]any{
- "providers": map[string]any{
- "unknown_provider": map[string]any{
- "api_key": "sk-test",
- },
- },
- }
-
- _, warnings, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if len(warnings) != 1 {
- t.Fatalf("expected 1 warning, got %d", len(warnings))
- }
- if warnings[0] != "Provider 'unknown_provider' not supported in PicoClaw, skipping" {
- t.Errorf("unexpected warning: %s", warnings[0])
- }
- })
-
- t.Run("channels mapping", func(t *testing.T) {
- data := map[string]any{
- "channels": map[string]any{
- "telegram": map[string]any{
- "enabled": true,
- "token": "tg-token-123",
- "allow_from": []any{"user1"},
- },
- "discord": map[string]any{
- "enabled": true,
- "token": "disc-token-456",
- },
- },
- }
-
- cfg, _, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if !cfg.Channels.Telegram.Enabled {
- t.Error("Telegram should be enabled")
- }
- if cfg.Channels.Telegram.Token != "tg-token-123" {
- t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
- }
- if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
- t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
- }
- if !cfg.Channels.Discord.Enabled {
- t.Error("Discord should be enabled")
- }
- })
-
- t.Run("unsupported channel warning", func(t *testing.T) {
- data := map[string]any{
- "channels": map[string]any{
- "email": map[string]any{
- "enabled": true,
- },
- },
- }
-
- _, warnings, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if len(warnings) != 1 {
- t.Fatalf("expected 1 warning, got %d", len(warnings))
- }
- if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
- t.Errorf("unexpected warning: %s", warnings[0])
- }
- })
-
- t.Run("agent defaults", func(t *testing.T) {
- data := map[string]any{
- "agents": map[string]any{
- "defaults": map[string]any{
- "model": "claude-3-opus",
- "max_tokens": float64(4096),
- "temperature": 0.5,
- "max_tool_iterations": float64(10),
- "workspace": "~/.openclaw/workspace",
- },
- },
- }
-
- cfg, _, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if cfg.Agents.Defaults.Model != "claude-3-opus" {
- t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
- }
- if cfg.Agents.Defaults.MaxTokens != 4096 {
- t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
- }
- if cfg.Agents.Defaults.Temperature == nil {
- t.Fatalf("Temperature is nil, want %f", 0.5)
- }
- if *cfg.Agents.Defaults.Temperature != 0.5 {
- t.Errorf("Temperature = %f, want %f", *cfg.Agents.Defaults.Temperature, 0.5)
- }
- if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
- t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
- }
- })
-
- t.Run("empty config", func(t *testing.T) {
- data := map[string]any{}
-
- cfg, warnings, err := ConvertConfig(data)
- if err != nil {
- t.Fatalf("ConvertConfig: %v", err)
- }
- if len(warnings) != 0 {
- t.Errorf("expected no warnings, got %v", warnings)
- }
- if cfg.Agents.Defaults.Model != "glm-4.7" {
- t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
- }
- })
-}
-
-func TestSupportedProvidersCompatibility(t *testing.T) {
- expected := []string{
- "anthropic",
- "openai",
- "openrouter",
- "groq",
- "zhipu",
- "vllm",
- "gemini",
- }
-
- for _, provider := range expected {
- if !supportedProviders[provider] {
- t.Fatalf("supportedProviders missing expected key %q", provider)
- }
- }
-}
-
-func TestMergeConfig(t *testing.T) {
- t.Run("fills empty fields", func(t *testing.T) {
- existing := config.DefaultConfig()
- incoming := config.DefaultConfig()
- incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
- incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
-
- result := MergeConfig(existing, incoming)
- if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
- t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
- }
- if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
- t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
- }
- })
-
- t.Run("preserves existing non-empty fields", func(t *testing.T) {
- existing := config.DefaultConfig()
- existing.Providers.Anthropic.APIKey = "sk-ant-existing"
-
- incoming := config.DefaultConfig()
- incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
- incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
-
- result := MergeConfig(existing, incoming)
- if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
- t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
- }
- if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
- t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
- }
- })
-
- t.Run("merges enabled channels", func(t *testing.T) {
- existing := config.DefaultConfig()
- incoming := config.DefaultConfig()
- incoming.Channels.Telegram.Enabled = true
- incoming.Channels.Telegram.Token = "tg-token"
-
- result := MergeConfig(existing, incoming)
- if !result.Channels.Telegram.Enabled {
- t.Error("Telegram should be enabled after merge")
- }
- if result.Channels.Telegram.Token != "tg-token" {
- t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
- }
- })
-
- t.Run("preserves existing enabled channels", func(t *testing.T) {
- existing := config.DefaultConfig()
- existing.Channels.Telegram.Enabled = true
- existing.Channels.Telegram.Token = "existing-token"
-
- incoming := config.DefaultConfig()
- incoming.Channels.Telegram.Enabled = true
- incoming.Channels.Telegram.Token = "incoming-token"
-
- result := MergeConfig(existing, incoming)
- if result.Channels.Telegram.Token != "existing-token" {
- t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
- }
- })
-}
-
-func TestPlanWorkspaceMigration(t *testing.T) {
- t.Run("copies available files", func(t *testing.T) {
- srcDir := t.TempDir()
- dstDir := t.TempDir()
-
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644)
- os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0o644)
- os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0o644)
-
- actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
- if err != nil {
- t.Fatalf("PlanWorkspaceMigration: %v", err)
- }
-
- copyCount := 0
- skipCount := 0
- for _, a := range actions {
- if a.Type == ActionCopy {
- copyCount++
- }
- if a.Type == ActionSkip {
- skipCount++
- }
- }
- if copyCount != 3 {
- t.Errorf("expected 3 copies, got %d", copyCount)
- }
- if skipCount != 2 {
- t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
- }
- })
-
- t.Run("plans backup for existing destination files", func(t *testing.T) {
- srcDir := t.TempDir()
- dstDir := t.TempDir()
-
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644)
- os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0o644)
-
- actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
- if err != nil {
- t.Fatalf("PlanWorkspaceMigration: %v", err)
- }
-
- backupCount := 0
- for _, a := range actions {
- if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
- backupCount++
- }
- }
- if backupCount != 1 {
- t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
- }
- })
-
- t.Run("force skips backup", func(t *testing.T) {
- srcDir := t.TempDir()
- dstDir := t.TempDir()
-
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644)
- os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0o644)
-
- actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
- if err != nil {
- t.Fatalf("PlanWorkspaceMigration: %v", err)
- }
-
- for _, a := range actions {
- if a.Type == ActionBackup {
- t.Error("expected no backup actions with force=true")
- }
- }
- })
-
- t.Run("handles memory directory", func(t *testing.T) {
- srcDir := t.TempDir()
- dstDir := t.TempDir()
-
- memDir := filepath.Join(srcDir, "memory")
- os.MkdirAll(memDir, 0o755)
- os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0o644)
-
- actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
- if err != nil {
- t.Fatalf("PlanWorkspaceMigration: %v", err)
- }
-
- hasCopy := false
- hasDir := false
- for _, a := range actions {
- if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
- hasCopy = true
- }
- if a.Type == ActionCreateDir {
- hasDir = true
- }
- }
- if !hasCopy {
- t.Error("expected copy action for memory/MEMORY.md")
- }
- if !hasDir {
- t.Error("expected create dir action for memory/")
- }
- })
-
- t.Run("handles skills directory", func(t *testing.T) {
- srcDir := t.TempDir()
- dstDir := t.TempDir()
-
- skillDir := filepath.Join(srcDir, "skills", "weather")
- os.MkdirAll(skillDir, 0o755)
- os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0o644)
-
- actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
- if err != nil {
- t.Fatalf("PlanWorkspaceMigration: %v", err)
- }
-
- hasCopy := false
- for _, a := range actions {
- if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
- hasCopy = true
- }
- }
- if !hasCopy {
- t.Error("expected copy action for skills/weather/SKILL.md")
- }
- })
-}
-
-func TestFindOpenClawConfig(t *testing.T) {
- t.Run("finds openclaw.json", func(t *testing.T) {
- tmpDir := t.TempDir()
- configPath := filepath.Join(tmpDir, "openclaw.json")
- os.WriteFile(configPath, []byte("{}"), 0o644)
-
- found, err := findOpenClawConfig(tmpDir)
- if err != nil {
- t.Fatalf("findOpenClawConfig: %v", err)
- }
- if found != configPath {
- t.Errorf("found %q, want %q", found, configPath)
- }
- })
-
- t.Run("falls back to config.json", func(t *testing.T) {
- tmpDir := t.TempDir()
- configPath := filepath.Join(tmpDir, "config.json")
- os.WriteFile(configPath, []byte("{}"), 0o644)
-
- found, err := findOpenClawConfig(tmpDir)
- if err != nil {
- t.Fatalf("findOpenClawConfig: %v", err)
- }
- if found != configPath {
- t.Errorf("found %q, want %q", found, configPath)
- }
- })
-
- t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
- tmpDir := t.TempDir()
- openclawPath := filepath.Join(tmpDir, "openclaw.json")
- os.WriteFile(openclawPath, []byte("{}"), 0o644)
- os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0o644)
-
- found, err := findOpenClawConfig(tmpDir)
- if err != nil {
- t.Fatalf("findOpenClawConfig: %v", err)
- }
- if found != openclawPath {
- t.Errorf("should prefer openclaw.json, got %q", found)
- }
- })
-
- t.Run("error when no config found", func(t *testing.T) {
- tmpDir := t.TempDir()
-
- _, err := findOpenClawConfig(tmpDir)
- if err == nil {
- t.Fatal("expected error when no config found")
- }
- })
-}
-
-func TestRewriteWorkspacePath(t *testing.T) {
- tests := []struct {
- name string
- input string
- want string
- }{
- {"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
- {"custom path", "/custom/path", "/custom/path"},
- {"empty", "", ""},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := rewriteWorkspacePath(tt.input)
- if got != tt.want {
- t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
- }
- })
- }
-}
-
-func TestRunDryRun(t *testing.T) {
- openclawHome := t.TempDir()
- picoClawHome := t.TempDir()
-
- wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0o755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
- os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0o644)
-
- configData := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "apiKey": "test-key",
- },
- },
- }
- data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
+func TestMigrateInstanceGetCurrentHandlerWithSource(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
opts := Options{
- DryRun: true,
- OpenClawHome: openclawHome,
- PicoClawHome: picoClawHome,
+ Source: "openclaw",
+ SourceHome: tmpDir,
}
+ instance := NewMigrateInstance(opts)
- result, err := Run(opts)
- if err != nil {
- t.Fatalf("Run: %v", err)
- }
-
- picoWs := filepath.Join(picoClawHome, "workspace")
- if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
- t.Error("dry run should not create files")
- }
- if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
- t.Error("dry run should not create config")
- }
-
- _ = result
+ handler, err := instance.getCurrentHandler()
+ require.NoError(t, err)
+ require.NotNil(t, handler)
+ assert.Equal(t, "openclaw", handler.GetSourceName())
}
-func TestRunFullMigration(t *testing.T) {
- openclawHome := t.TempDir()
- picoClawHome := t.TempDir()
-
- wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0o755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0o644)
- os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644)
- os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0o644)
-
- memDir := filepath.Join(wsDir, "memory")
- os.MkdirAll(memDir, 0o755)
- os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0o644)
-
- configData := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "apiKey": "sk-ant-migrate-test",
- },
- "openrouter": map[string]any{
- "apiKey": "sk-or-migrate-test",
- },
- },
- "channels": map[string]any{
- "telegram": map[string]any{
- "enabled": true,
- "token": "tg-migrate-test",
- },
- },
- }
- data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
-
- opts := Options{
- Force: true,
- OpenClawHome: openclawHome,
- PicoClawHome: picoClawHome,
+func TestMigrateInstanceGetCurrentHandlerNotFound(t *testing.T) {
+ instance := &MigrateInstance{
+ options: Options{},
+ handlers: make(map[string]Operation),
}
- result, err := Run(opts)
- if err != nil {
- t.Fatalf("Run: %v", err)
- }
-
- picoWs := filepath.Join(picoClawHome, "workspace")
-
- soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
- if err != nil {
- t.Fatalf("reading SOUL.md: %v", err)
- }
- if string(soulData) != "# Soul from OpenClaw" {
- t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
- }
-
- agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
- if err != nil {
- t.Fatalf("reading AGENTS.md: %v", err)
- }
- if string(agentsData) != "# Agents from OpenClaw" {
- t.Errorf("AGENTS.md content = %q", string(agentsData))
- }
-
- memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
- if err != nil {
- t.Fatalf("reading memory/MEMORY.md: %v", err)
- }
- if string(memData) != "# Memory notes" {
- t.Errorf("MEMORY.md content = %q", string(memData))
- }
-
- picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
- if err != nil {
- t.Fatalf("loading PicoClaw config: %v", err)
- }
- if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
- t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
- }
- if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
- t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
- }
- if !picoConfig.Channels.Telegram.Enabled {
- t.Error("Telegram should be enabled")
- }
- if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
- t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
- }
-
- if result.FilesCopied < 3 {
- t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
- }
- if !result.ConfigMigrated {
- t.Error("config should have been migrated")
- }
- if len(result.Errors) > 0 {
- t.Errorf("expected no errors, got %v", result.Errors)
- }
+ _, err := instance.getCurrentHandler()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "not found")
}
-func TestRunOpenClawNotFound(t *testing.T) {
- opts := Options{
- OpenClawHome: "/nonexistent/path/to/openclaw",
- PicoClawHome: t.TempDir(),
+func TestMigrateInstancePlanWithInvalidSource(t *testing.T) {
+ instance := &MigrateInstance{
+ options: Options{},
+ handlers: make(map[string]Operation),
}
- _, err := Run(opts)
- if err == nil {
- t.Fatal("expected error when OpenClaw not found")
- }
+ _, _, err := instance.Plan(Options{}, "/tmp/source", "/tmp/target")
+ require.Error(t, err)
}
-func TestRunMutuallyExclusiveFlags(t *testing.T) {
- opts := Options{
+func TestMigrateInstancePlanConfigOnlyAndWorkspaceOnlyMutuallyExclusive(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ instance := NewMigrateInstance(Options{SourceHome: tmpDir})
+ require.NotNil(t, instance)
+
+ _, err = instance.Run(Options{
ConfigOnly: true,
WorkspaceOnly: true,
- }
-
- _, err := Run(opts)
- if err == nil {
- t.Fatal("expected error for mutually exclusive flags")
- }
+ })
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "mutually exclusive")
}
-func TestBackupFile(t *testing.T) {
+func TestMigrateInstancePlanRefreshSetsWorkspaceOnly(t *testing.T) {
+ opts := Options{
+ Refresh: true,
+ SourceHome: "/tmp/nonexistent",
+ }
+ instance := NewMigrateInstance(opts)
+ require.NotNil(t, instance)
+
+ _, err := instance.Run(opts)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "not found")
+}
+
+func TestMigrateInstancePlanSourceNotFound(t *testing.T) {
+ opts := Options{
+ SourceHome: "/tmp/nonexistent-source-home",
+ }
+ instance := NewMigrateInstance(opts)
+
+ _, err := instance.Run(opts)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "not found")
+}
+
+func TestMigrateInstanceExecute(t *testing.T) {
tmpDir := t.TempDir()
- filePath := filepath.Join(tmpDir, "test.md")
- os.WriteFile(filePath, []byte("original content"), 0o644)
+ sourceDir := filepath.Join(tmpDir, "source")
+ targetDir := filepath.Join(tmpDir, "target")
+ workspaceDir := filepath.Join(sourceDir, "workspace")
- if err := backupFile(filePath); err != nil {
- t.Fatalf("backupFile: %v", err)
+ err := os.MkdirAll(workspaceDir, 0o755)
+ require.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(workspaceDir, "test.txt"), []byte("test"), 0o644)
+ require.NoError(t, err)
+
+ instance := &MigrateInstance{
+ options: Options{Source: "mock"},
+ handlers: make(map[string]Operation),
}
+ instance.Register("mock", &mockOperation{sourceHome: sourceDir, sourceWs: workspaceDir})
- bakPath := filePath + ".bak"
- bakData, err := os.ReadFile(bakPath)
- if err != nil {
- t.Fatalf("reading backup: %v", err)
- }
- if string(bakData) != "original content" {
- t.Errorf("backup content = %q, want %q", string(bakData), "original content")
- }
-}
-
-func TestCopyFile(t *testing.T) {
- tmpDir := t.TempDir()
- srcPath := filepath.Join(tmpDir, "src.md")
- dstPath := filepath.Join(tmpDir, "dst.md")
-
- os.WriteFile(srcPath, []byte("file content"), 0o644)
-
- if err := copyFile(srcPath, dstPath); err != nil {
- t.Fatalf("copyFile: %v", err)
- }
-
- data, err := os.ReadFile(dstPath)
- if err != nil {
- t.Fatalf("reading copy: %v", err)
- }
- if string(data) != "file content" {
- t.Errorf("copy content = %q, want %q", string(data), "file content")
- }
-}
-
-func TestRunConfigOnly(t *testing.T) {
- openclawHome := t.TempDir()
- picoClawHome := t.TempDir()
-
- wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0o755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
-
- configData := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "apiKey": "sk-config-only",
- },
+ actions := []Action{
+ {
+ Type: ActionCopy,
+ Source: filepath.Join(workspaceDir, "test.txt"),
+ Target: filepath.Join(targetDir, "workspace", "test.txt"),
+ Description: "copy file",
},
}
- data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
- opts := Options{
- Force: true,
- ConfigOnly: true,
- OpenClawHome: openclawHome,
- PicoClawHome: picoClawHome,
- }
+ result := instance.Execute(actions, workspaceDir, targetDir)
+ require.NotNil(t, result)
+ assert.Equal(t, 1, result.FilesCopied)
- result, err := Run(opts)
- if err != nil {
- t.Fatalf("Run: %v", err)
- }
-
- if !result.ConfigMigrated {
- t.Error("config should have been migrated")
- }
-
- picoWs := filepath.Join(picoClawHome, "workspace")
- if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
- t.Error("config-only should not copy workspace files")
- }
+ _, err = os.Stat(filepath.Join(targetDir, "workspace", "test.txt"))
+ assert.NoError(t, err)
}
-func TestRunWorkspaceOnly(t *testing.T) {
- openclawHome := t.TempDir()
- picoClawHome := t.TempDir()
+func TestMigrateInstanceExecuteWithInvalidSource(t *testing.T) {
+ tmpDir := t.TempDir()
+ sourceDir := filepath.Join(tmpDir, "source")
+ err := os.MkdirAll(sourceDir, 0o755)
+ require.NoError(t, err)
- wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0o755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
+ instance := &MigrateInstance{
+ options: Options{Source: "mock"},
+ handlers: make(map[string]Operation),
+ }
+ instance.Register("mock", &mockOperation{sourceHome: sourceDir})
- configData := map[string]any{
- "providers": map[string]any{
- "anthropic": map[string]any{
- "apiKey": "sk-ws-only",
- },
+ actions := []Action{
+ {
+ Type: ActionCopy,
+ Source: filepath.Join(sourceDir, "nonexistent.txt"),
+ Target: filepath.Join(tmpDir, "target.txt"),
+ Description: "copy file",
},
}
- data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
- opts := Options{
- Force: true,
- WorkspaceOnly: true,
- OpenClawHome: openclawHome,
- PicoClawHome: picoClawHome,
- }
-
- result, err := Run(opts)
- if err != nil {
- t.Fatalf("Run: %v", err)
- }
-
- if result.ConfigMigrated {
- t.Error("workspace-only should not migrate config")
- }
-
- picoWs := filepath.Join(picoClawHome, "workspace")
- soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
- if err != nil {
- t.Fatalf("reading SOUL.md: %v", err)
- }
- if string(soulData) != "# Soul" {
- t.Errorf("SOUL.md content = %q", string(soulData))
- }
+ result := instance.Execute(actions, sourceDir, tmpDir)
+ require.NotNil(t, result)
+ assert.Equal(t, 0, result.FilesCopied)
+ assert.Greater(t, len(result.Errors), 0)
+}
+
+func TestMigrateInstanceExecuteCreateDir(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ instance := &MigrateInstance{
+ options: Options{Source: "mock"},
+ handlers: make(map[string]Operation),
+ }
+ instance.Register("mock", &mockOperation{})
+
+ actions := []Action{
+ {
+ Type: ActionCreateDir,
+ Target: filepath.Join(tmpDir, "new", "dir"),
+ Description: "create directory",
+ },
+ }
+
+ result := instance.Execute(actions, "", "")
+ require.NotNil(t, result)
+ assert.Equal(t, 1, result.DirsCreated)
+
+ _, err := os.Stat(filepath.Join(tmpDir, "new", "dir"))
+ assert.NoError(t, err)
+}
+
+func TestMigrateInstanceExecuteBackup(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ sourceFile := filepath.Join(tmpDir, "source.txt")
+ targetFile := filepath.Join(tmpDir, "target.txt")
+
+ err := os.WriteFile(sourceFile, []byte("source"), 0o644)
+ require.NoError(t, err)
+
+ err = os.WriteFile(targetFile, []byte("target"), 0o644)
+ require.NoError(t, err)
+
+ instance := &MigrateInstance{
+ options: Options{Source: "mock"},
+ handlers: make(map[string]Operation),
+ }
+ instance.Register("mock", &mockOperation{})
+
+ actions := []Action{
+ {
+ Type: ActionBackup,
+ Source: sourceFile,
+ Target: targetFile,
+ Description: "backup and overwrite",
+ },
+ }
+
+ result := instance.Execute(actions, tmpDir, tmpDir)
+ require.NotNil(t, result)
+ assert.Equal(t, 1, result.BackupsCreated)
+ assert.Equal(t, 1, result.FilesCopied)
+
+ bakFile := targetFile + ".bak"
+ _, err = os.Stat(bakFile)
+ assert.NoError(t, err)
+
+ content, err := os.ReadFile(targetFile)
+ assert.NoError(t, err)
+ assert.Equal(t, "source", string(content))
+}
+
+func TestMigrateInstanceExecuteSkip(t *testing.T) {
+ instance := &MigrateInstance{
+ options: Options{Source: "mock"},
+ handlers: make(map[string]Operation),
+ }
+ instance.Register("mock", &mockOperation{})
+
+ actions := []Action{
+ {
+ Type: ActionSkip,
+ Source: "/tmp/source.txt",
+ Target: "/tmp/target.txt",
+ Description: "skip file",
+ },
+ }
+
+ result := instance.Execute(actions, "", "")
+ require.NotNil(t, result)
+ assert.Equal(t, 1, result.FilesSkipped)
+}
+
+func TestMigrateInstancePrintSummary(t *testing.T) {
+ instance := NewMigrateInstance(Options{})
+
+ result := &Result{
+ FilesCopied: 5,
+ ConfigMigrated: true,
+ BackupsCreated: 2,
+ FilesSkipped: 3,
+ Warnings: []string{"warning 1"},
+ Errors: []error{},
+ }
+
+ instance.PrintSummary(result)
+}
+
+func TestMigrateInstancePrintSummaryWithErrors(t *testing.T) {
+ instance := NewMigrateInstance(Options{})
+
+ result := &Result{
+ FilesCopied: 0,
+ ConfigMigrated: false,
+ BackupsCreated: 0,
+ FilesSkipped: 0,
+ Warnings: []string{},
+ Errors: []error{assert.AnError},
+ }
+
+ instance.PrintSummary(result)
+}
+
+func TestMigrateInstancePrintSummaryNoActions(t *testing.T) {
+ instance := NewMigrateInstance(Options{})
+
+ result := &Result{
+ FilesCopied: 0,
+ ConfigMigrated: false,
+ BackupsCreated: 0,
+ FilesSkipped: 0,
+ Warnings: []string{},
+ Errors: []error{},
+ }
+
+ instance.PrintSummary(result)
+}
+
+func TestPrintPlan(t *testing.T) {
+ actions := []Action{
+ {
+ Type: ActionConvertConfig,
+ Source: "/source/config.json",
+ Target: "/target/config.json",
+ Description: "convert config",
+ },
+ {
+ Type: ActionCopy,
+ Source: "/source/file.txt",
+ Target: "/target/file.txt",
+ Description: "copy file",
+ },
+ {
+ Type: ActionBackup,
+ Source: "/source/existing.txt",
+ Target: "/target/existing.txt",
+ Description: "backup and overwrite",
+ },
+ {
+ Type: ActionSkip,
+ Source: "/source/skipped.txt",
+ Target: "/target/skipped.txt",
+ Description: "skip file",
+ },
+ {
+ Type: ActionCreateDir,
+ Target: "/target/newdir",
+ Description: "create directory",
+ },
+ }
+
+ warnings := []string{
+ "Warning: source directory not found",
+ }
+
+ PrintPlan(actions, warnings)
+}
+
+func TestPrintPlanEmpty(t *testing.T) {
+ PrintPlan([]Action{}, []string{})
+}
+
+type mockOperation struct {
+ sourceHome string
+ sourceConfig string
+ sourceWs string
+ migrateFiles []string
+ migrateDirs []string
+}
+
+func (m *mockOperation) GetSourceName() string { return "mock" }
+func (m *mockOperation) GetSourceHome() (string, error) {
+ if m.sourceHome != "" {
+ return m.sourceHome, nil
+ }
+ return "/tmp/mock", nil
+}
+
+func (m *mockOperation) GetSourceWorkspace() (string, error) {
+ if m.sourceWs != "" {
+ return m.sourceWs, nil
+ }
+ if m.sourceHome != "" {
+ return filepath.Join(m.sourceHome, "workspace"), nil
+ }
+ return "/tmp/mock/workspace", nil
+}
+
+func (m *mockOperation) GetSourceConfigFile() (string, error) {
+ if m.sourceConfig != "" {
+ return m.sourceConfig, nil
+ }
+ return "/tmp/mock/config.json", nil
+}
+func (m *mockOperation) ExecuteConfigMigration(src, dst string) error { return nil }
+func (m *mockOperation) GetMigrateableFiles() []string {
+ if m.migrateFiles != nil {
+ return m.migrateFiles
+ }
+ return []string{}
+}
+
+func (m *mockOperation) GetMigrateableDirs() []string {
+ if m.migrateDirs != nil {
+ return m.migrateDirs
+ }
+ return []string{}
}
diff --git a/pkg/migrate/sources/openclaw/common.go b/pkg/migrate/sources/openclaw/common.go
new file mode 100644
index 000000000..dddd98089
--- /dev/null
+++ b/pkg/migrate/sources/openclaw/common.go
@@ -0,0 +1,29 @@
+package openclaw
+
+var migrateableFiles = []string{
+ "AGENTS.md",
+ "SOUL.md",
+ "USER.md",
+ "TOOLS.md",
+ "HEARTBEAT.md",
+}
+
+var migrateableDirs = []string{
+ "memory",
+ "skills",
+}
+
+var supportedChannels = map[string]bool{
+ "whatsapp": true,
+ "telegram": true,
+ "feishu": true,
+ "discord": true,
+ "maixcam": true,
+ "qq": true,
+ "dingtalk": true,
+ "slack": true,
+ "line": true,
+ "onebot": true,
+ "wecom": true,
+ "wecom_app": true,
+}
diff --git a/pkg/migrate/sources/openclaw/openclaw_config.go b/pkg/migrate/sources/openclaw/openclaw_config.go
new file mode 100644
index 000000000..39ad48fad
--- /dev/null
+++ b/pkg/migrate/sources/openclaw/openclaw_config.go
@@ -0,0 +1,1074 @@
+package openclaw
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+type OpenClawConfig struct {
+ Auth *OpenClawAuth `json:"auth"`
+ Models *OpenClawModels `json:"models"`
+ Agents *OpenClawAgents `json:"agents"`
+ Tools *OpenClawTools `json:"tools"`
+ Channels *OpenClawChannels `json:"channels"`
+ Cron json.RawMessage `json:"cron"`
+ Hooks json.RawMessage `json:"hooks"`
+ Skills *OpenClawSkills `json:"skills"`
+ Memory json.RawMessage `json:"memory"`
+ Session json.RawMessage `json:"session"`
+}
+
+type OpenClawAuth struct {
+ Profiles json.RawMessage `json:"profiles"`
+ Order json.RawMessage `json:"order"`
+}
+
+type OpenClawModels struct {
+ Providers map[string]json.RawMessage `json:"providers"`
+}
+
+type ProviderConfig struct {
+ BaseUrl string `json:"baseUrl"`
+ Api string `json:"api"`
+ Models []ModelConfig `json:"models"`
+ ApiKey string `json:"apiKey"`
+}
+
+type OpenClawModelConfig struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Reasoning bool `json:"reasoning"`
+ Input []string `json:"input"`
+ Cost Cost `json:"cost"`
+ ContextWindow int `json:"contextWindow"`
+ MaxTokens int `json:"maxTokens"`
+ Api string `json:"api,omitempty"`
+}
+
+type Cost struct {
+ Input float64 `json:"input"`
+ Output float64 `json:"output"`
+ CacheRead float64 `json:"cacheRead"`
+ CacheWrite float64 `json:"cacheWrite"`
+}
+
+type OpenClawTools struct {
+ Profile *string `json:"profile"`
+ Allow []string `json:"allow"`
+ Deny []string `json:"deny"`
+}
+
+type OpenClawAgents struct {
+ Defaults *OpenClawAgentDefaults `json:"defaults"`
+ List []OpenClawAgentEntry `json:"list"`
+}
+
+type OpenClawAgentDefaults struct {
+ Model *OpenClawAgentModel `json:"model"`
+ Workspace *string `json:"workspace"`
+ Tools *OpenClawAgentTools `json:"tools"`
+ Identity *string `json:"identity"`
+}
+
+type OpenClawAgentModel struct {
+ Simple string `json:"-"`
+ Primary *string `json:"primary"`
+ Fallbacks []string `json:"fallbacks"`
+}
+
+func (m *OpenClawAgentModel) GetPrimary() string {
+ if m.Simple != "" {
+ return m.Simple
+ }
+ if m.Primary != nil {
+ return *m.Primary
+ }
+ return ""
+}
+
+func (m *OpenClawAgentModel) GetFallbacks() []string {
+ return m.Fallbacks
+}
+
+type OpenClawAgentEntry struct {
+ ID string `json:"id"`
+ Name *string `json:"name"`
+ Model *OpenClawAgentModel `json:"model"`
+ Tools *OpenClawAgentTools `json:"tools"`
+ Workspace *string `json:"workspace"`
+ Skills []string `json:"skills"`
+ Identity *string `json:"identity"`
+}
+
+type OpenClawAgentTools struct {
+ Profile *string `json:"profile"`
+ Allow []string `json:"allow"`
+ Deny []string `json:"deny"`
+ AlsoAllow []string `json:"alsoAllow"`
+}
+
+type OpenClawChannels struct {
+ Telegram *OpenClawTelegramConfig `json:"telegram"`
+ Discord *OpenClawDiscordConfig `json:"discord"`
+ Slack *OpenClawSlackConfig `json:"slack"`
+ WhatsApp *OpenClawWhatsAppConfig `json:"whatsapp"`
+ Signal *OpenClawSignalConfig `json:"signal"`
+ Matrix *OpenClawMatrixConfig `json:"matrix"`
+ GoogleChat *OpenClawGoogleChatConfig `json:"googlechat"`
+ Teams *OpenClawTeamsConfig `json:"msteams"`
+ IRC *OpenClawIrcConfig `json:"irc"`
+ Mattermost *OpenClawMattermostConfig `json:"mattermost"`
+ Feishu *OpenClawFeishuConfig `json:"feishu"`
+ IMessage *OpenClawIMessageConfig `json:"imessage"`
+ BlueBubbles *OpenClawBlueBubblesConfig `json:"bluebubbles"`
+ QQ *OpenClawQQConfig `json:"qq"`
+ DingTalk *OpenClawDingTalkConfig `json:"dingtalk"`
+ MaixCam *OpenClawMaixCamConfig `json:"maixcam"`
+}
+
+type OpenClawTelegramConfig struct {
+ BotToken *string `json:"botToken"`
+ AllowFrom []string `json:"allowFrom"`
+ GroupPolicy *string `json:"groupPolicy"`
+ DmPolicy *string `json:"dmPolicy"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawDiscordConfig struct {
+ Token *string `json:"token"`
+ Guilds json.RawMessage `json:"guilds"`
+ DmPolicy *string `json:"dmPolicy"`
+ GroupPolicy *string `json:"groupPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawSlackConfig struct {
+ BotToken *string `json:"botToken"`
+ AppToken *string `json:"appToken"`
+ DmPolicy *string `json:"dmPolicy"`
+ GroupPolicy *string `json:"groupPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawWhatsAppConfig struct {
+ AuthDir *string `json:"authDir"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ GroupPolicy *string `json:"groupPolicy"`
+ Enabled *bool `json:"enabled"`
+ BridgeURL *string `json:"bridgeUrl"`
+}
+
+type OpenClawSignalConfig struct {
+ HttpUrl *string `json:"httpUrl"`
+ HttpHost *string `json:"httpHost"`
+ HttpPort *int `json:"httpPort"`
+ Account *string `json:"account"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawMatrixConfig struct {
+ Homeserver *string `json:"homeserver"`
+ UserID *string `json:"userId"`
+ AccessToken *string `json:"accessToken"`
+ Rooms []string `json:"rooms"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawGoogleChatConfig struct {
+ ServiceAccountFile *string `json:"serviceAccountFile"`
+ WebhookPath *string `json:"webhookPath"`
+ BotUser *string `json:"botUser"`
+ DmPolicy *string `json:"dmPolicy"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawTeamsConfig struct {
+ AppID *string `json:"appId"`
+ AppPassword *string `json:"appPassword"`
+ TenantID *string `json:"tenantId"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawIrcConfig struct {
+ Host *string `json:"host"`
+ Port *int `json:"port"`
+ TLS *bool `json:"tls"`
+ Nick *string `json:"nick"`
+ Password *string `json:"password"`
+ Channels []string `json:"channels"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawMattermostConfig struct {
+ BotToken *string `json:"botToken"`
+ BaseURL *string `json:"baseUrl"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawFeishuConfig struct {
+ AppID *string `json:"appId"`
+ AppSecret *string `json:"appSecret"`
+ Domain *string `json:"domain"`
+ DmPolicy *string `json:"dmPolicy"`
+ Enabled *bool `json:"enabled"`
+ VerificationToken *string `json:"verificationToken"`
+ EncryptKey *string `json:"encryptKey"`
+ AllowFrom []string `json:"allowFrom"`
+}
+
+type OpenClawIMessageConfig struct {
+ CliPath *string `json:"cliPath"`
+ DbPath *string `json:"dbPath"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawBlueBubblesConfig struct {
+ ServerURL *string `json:"serverUrl"`
+ Password *string `json:"password"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawQQConfig struct {
+ AppID *string `json:"appId"`
+ AppSecret *string `json:"appSecret"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawDingTalkConfig struct {
+ AppID *string `json:"appId"`
+ AppSecret *string `json:"appSecret"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawMaixCamConfig struct {
+ Host *string `json:"host"`
+ Port *int `json:"port"`
+ DmPolicy *string `json:"dmPolicy"`
+ AllowFrom []string `json:"allowFrom"`
+ Enabled *bool `json:"enabled"`
+}
+
+type OpenClawSkills struct {
+ Entries map[string]json.RawMessage `json:"entries"`
+ Load json.RawMessage `json:"load"`
+}
+
+type OpenClawProviderConfig struct {
+ APIKey string `json:"api_key"`
+ BaseURL string `json:"base_url"`
+}
+
+func (c *OpenClawConfig) GetEnabled() bool {
+ return true
+}
+
+func LoadOpenClawConfig(path string) (*OpenClawConfig, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read config: %w", err)
+ }
+
+ var config OpenClawConfig
+ if err := json.Unmarshal(data, &config); err != nil {
+ return nil, fmt.Errorf("failed to parse JSON: %w", err)
+ }
+
+ return &config, nil
+}
+
+func LoadOpenClawConfigFromDir(dir string) (*OpenClawConfig, error) {
+ candidates := []string{
+ filepath.Join(dir, "openclaw.json"),
+ filepath.Join(dir, "config.json"),
+ }
+
+ for _, p := range candidates {
+ if _, err := os.Stat(p); err == nil {
+ return LoadOpenClawConfig(p)
+ }
+ }
+
+ return nil, fmt.Errorf("no config file found in %s", dir)
+}
+
+func GetProviderConfig(models *OpenClawModels) map[string]OpenClawProviderConfig {
+ result := make(map[string]OpenClawProviderConfig)
+ if models == nil || models.Providers == nil {
+ return result
+ }
+
+ for name, raw := range models.Providers {
+ var prov OpenClawProviderConfig
+ if err := json.Unmarshal(raw, &prov); err != nil {
+ continue
+ }
+ mappedName := mapProvider(name)
+ result[mappedName] = prov
+ }
+
+ return result
+}
+
+func GetProviderConfigFromDir(dir string) map[string]ProviderConfig {
+ result := make(map[string]ProviderConfig)
+ p := filepath.Join(dir, "agents", "main", "agent", "models.json")
+
+ if _, err := os.Stat(p); err != nil {
+ return result
+ }
+
+ data, err := os.ReadFile(p)
+ if err != nil {
+ return result
+ }
+ var models OpenClawModels
+ if err := json.Unmarshal(data, &models); err != nil {
+ return result
+ }
+
+ for name, raw := range models.Providers {
+ var prov ProviderConfig
+ if err := json.Unmarshal(raw, &prov); err != nil {
+ continue
+ }
+ mappedName := mapProvider(name)
+ result[mappedName] = prov
+ }
+ return result
+}
+
+func (c *OpenClawConfig) IsChannelEnabled(name string) bool {
+ switch name {
+ case "telegram":
+ return c.Channels.Telegram == nil || c.Channels.Telegram.Enabled == nil || *c.Channels.Telegram.Enabled
+ case "discord":
+ return c.Channels.Discord == nil || c.Channels.Discord.Enabled == nil || *c.Channels.Discord.Enabled
+ case "slack":
+ return c.Channels.Slack == nil || c.Channels.Slack.Enabled == nil || *c.Channels.Slack.Enabled
+ case "whatsapp":
+ return c.Channels.WhatsApp == nil || c.Channels.WhatsApp.Enabled == nil || *c.Channels.WhatsApp.Enabled
+ case "feishu":
+ return c.Channels.Feishu == nil || c.Channels.Feishu.Enabled == nil || *c.Channels.Feishu.Enabled
+ default:
+ return false
+ }
+}
+
+func GetChannelAllowFrom(ch any) []string {
+ switch c := ch.(type) {
+ case *OpenClawTelegramConfig:
+ if c == nil {
+ return nil
+ }
+ return c.AllowFrom
+ case *OpenClawDiscordConfig:
+ if c == nil {
+ return nil
+ }
+ return c.AllowFrom
+ case *OpenClawSlackConfig:
+ if c == nil {
+ return nil
+ }
+ return c.AllowFrom
+ case *OpenClawWhatsAppConfig:
+ if c == nil {
+ return nil
+ }
+ return c.AllowFrom
+ case *OpenClawFeishuConfig:
+ if c == nil {
+ return nil
+ }
+ return c.AllowFrom
+ default:
+ return nil
+ }
+}
+
+func (c *OpenClawConfig) GetDefaultModel() (provider, model string) {
+ if c.Agents == nil || c.Agents.Defaults == nil || c.Agents.Defaults.Model == nil {
+ return "anthropic", "claude-sonnet-4-20250514"
+ }
+
+ primary := c.Agents.Defaults.Model.GetPrimary()
+ if primary == "" {
+ return "anthropic", "claude-sonnet-4-20250514"
+ }
+
+ parts := strings.Split(primary, "/")
+ if len(parts) > 1 {
+ return mapProvider(parts[0]), parts[1]
+ }
+
+ return "anthropic", primary
+}
+
+func (c *OpenClawConfig) GetDefaultWorkspace() string {
+ if c.Agents == nil || c.Agents.Defaults == nil || c.Agents.Defaults.Workspace == nil {
+ return ""
+ }
+ return rewriteWorkspacePath(*c.Agents.Defaults.Workspace)
+}
+
+func (c *OpenClawConfig) GetAgents() []OpenClawAgentEntry {
+ if c.Agents == nil {
+ return nil
+ }
+ return c.Agents.List
+}
+
+func (c *OpenClawConfig) HasSkills() bool {
+ return c.Skills != nil && c.Skills.Entries != nil && len(c.Skills.Entries) > 0
+}
+
+func (c *OpenClawConfig) HasMemory() bool {
+ return c.Memory != nil && len(c.Memory) > 0
+}
+
+func (c *OpenClawConfig) HasCron() bool {
+ return c.Cron != nil && len(c.Cron) > 0
+}
+
+func (c *OpenClawConfig) HasHooks() bool {
+ return c.Hooks != nil && len(c.Hooks) > 0
+}
+
+func (c *OpenClawConfig) HasSession() bool {
+ return c.Session != nil && len(c.Session) > 0
+}
+
+func (c *OpenClawConfig) HasAuthProfiles() bool {
+ return c.Auth != nil && c.Auth.Profiles != nil && len(c.Auth.Profiles) > 0
+}
+
+func (c *OpenClawConfig) ConvertToPicoClaw(sourceHome string) (*PicoClawConfig, []string, error) {
+ cfg := &PicoClawConfig{}
+ var warnings []string
+
+ provider, modelName := c.GetDefaultModel()
+ cfg.Agents.Defaults.Workspace = c.GetDefaultWorkspace()
+ cfg.Agents.Defaults.ModelName = modelName
+
+ providerConfigs := GetProviderConfigFromDir(sourceHome)
+ defaultAPIKey := ""
+ defaultBaseURL := ""
+
+ if provCfg, ok := providerConfigs[provider]; ok {
+ defaultAPIKey = provCfg.ApiKey
+ defaultBaseURL = provCfg.BaseUrl
+ }
+
+ cfg.ModelList = []ModelConfig{
+ {
+ ModelName: modelName,
+ Model: fmt.Sprintf("%s/%s", provider, modelName),
+ APIKey: defaultAPIKey,
+ APIBase: defaultBaseURL,
+ },
+ }
+
+ for provName, provCfg := range providerConfigs {
+ if provName == provider {
+ continue
+ }
+ if provCfg.ApiKey != "" {
+ continue
+ }
+ cfg.ModelList = append(cfg.ModelList, ModelConfig{
+ ModelName: fmt.Sprintf("%s", provName),
+ Model: fmt.Sprintf("%s/%s", provName, provName),
+ APIKey: provCfg.ApiKey,
+ APIBase: provCfg.BaseUrl,
+ })
+ }
+
+ cfg.Channels = c.convertChannels(&warnings)
+
+ agentList := c.convertAgents(&warnings)
+ if len(agentList) > 0 {
+ cfg.Agents.List = agentList
+ }
+
+ if c.HasSkills() {
+ warnings = append(
+ warnings,
+ fmt.Sprintf(
+ "Skills (%d entries) not automatically migrated - reinstall via picoclaw CLI",
+ len(c.Skills.Entries),
+ ),
+ )
+ }
+ if c.HasMemory() {
+ warnings = append(warnings, "Memory backend config not migrated - PicoClaw uses SQLite with vector embeddings")
+ }
+ if c.HasCron() {
+ warnings = append(
+ warnings,
+ "Cron job scheduling not supported in PicoClaw - consider using external schedulers",
+ )
+ }
+ if c.HasHooks() {
+ warnings = append(warnings, "Webhook hooks not supported in PicoClaw - use event system instead")
+ }
+ if c.HasSession() {
+ warnings = append(warnings, "Session scope config differs - PicoClaw uses per-agent sessions by default")
+ }
+ if c.HasAuthProfiles() {
+ warnings = append(
+ warnings,
+ "Auth profiles (API keys, OAuth tokens) not migrated for security - set env vars manually",
+ )
+ }
+
+ return cfg, warnings, nil
+}
+
+type ModelConfig struct {
+ ModelName string `json:"model_name"`
+ Model string `json:"model"`
+ APIBase string `json:"api_base,omitempty"`
+ APIKey string `json:"api_key"`
+ Proxy string `json:"proxy,omitempty"`
+}
+
+type PicoClawConfig struct {
+ Agents AgentsConfig `json:"agents"`
+ Bindings []AgentBinding `json:"bindings,omitempty"`
+ Channels ChannelsConfig `json:"channels"`
+ ModelList []ModelConfig `json:"model_list"`
+ Gateway GatewayConfig `json:"gateway"`
+ Tools ToolsConfig `json:"tools"`
+}
+
+type AgentsConfig struct {
+ Defaults AgentDefaults `json:"defaults"`
+ List []AgentConfig `json:"list,omitempty"`
+}
+
+type AgentDefaults struct {
+ Workspace string `json:"workspace"`
+ RestrictToWorkspace bool `json:"restrict_to_workspace"`
+ Provider string `json:"provider"`
+ ModelName string `json:"model_name"`
+ Model string `json:"model,omitempty"`
+ ModelFallbacks []string `json:"model_fallbacks,omitempty"`
+ ImageModel string `json:"image_model,omitempty"`
+ ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
+ MaxTokens int `json:"max_tokens"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxToolIterations int `json:"max_tool_iterations"`
+}
+
+type AgentConfig struct {
+ ID string `json:"id"`
+ Default bool `json:"default,omitempty"`
+ Name string `json:"name,omitempty"`
+ Workspace string `json:"workspace,omitempty"`
+ Model *AgentModelConfig `json:"model,omitempty"`
+ Skills []string `json:"skills,omitempty"`
+}
+
+type AgentModelConfig struct {
+ Primary string `json:"primary,omitempty"`
+ Fallbacks []string `json:"fallbacks,omitempty"`
+}
+
+type AgentBinding struct {
+ AgentID string `json:"agent_id"`
+ Match BindingMatch `json:"match"`
+}
+
+type BindingMatch struct {
+ Channel string `json:"channel"`
+ AccountID string `json:"account_id,omitempty"`
+ Peer *PeerMatch `json:"peer,omitempty"`
+ GuildID string `json:"guild_id,omitempty"`
+ TeamID string `json:"team_id,omitempty"`
+}
+
+type PeerMatch struct {
+ Kind string `json:"kind"`
+ ID string `json:"id"`
+}
+
+type ChannelsConfig struct {
+ WhatsApp WhatsAppConfig `json:"whatsapp"`
+ Telegram TelegramConfig `json:"telegram"`
+ Feishu FeishuConfig `json:"feishu"`
+ Discord DiscordConfig `json:"discord"`
+ MaixCam MaixCamConfig `json:"maixcam"`
+ QQ QQConfig `json:"qq"`
+ DingTalk DingTalkConfig `json:"dingtalk"`
+ Slack SlackConfig `json:"slack"`
+ LINE LINEConfig `json:"line"`
+}
+
+type WhatsAppConfig struct {
+ Enabled bool `json:"enabled"`
+ BridgeURL string `json:"bridge_url"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type TelegramConfig struct {
+ Enabled bool `json:"enabled"`
+ Token string `json:"token"`
+ Proxy string `json:"proxy"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type FeishuConfig struct {
+ Enabled bool `json:"enabled"`
+ AppID string `json:"app_id"`
+ AppSecret string `json:"app_secret"`
+ EncryptKey string `json:"encrypt_key"`
+ VerificationToken string `json:"verification_token"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type DiscordConfig struct {
+ Enabled bool `json:"enabled"`
+ Token string `json:"token"`
+ MentionOnly bool `json:"mention_only"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type MaixCamConfig struct {
+ Enabled bool `json:"enabled"`
+ Host string `json:"host"`
+ Port int `json:"port"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type QQConfig struct {
+ Enabled bool `json:"enabled"`
+ AppID string `json:"app_id"`
+ AppSecret string `json:"app_secret"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type DingTalkConfig struct {
+ Enabled bool `json:"enabled"`
+ ClientID string `json:"client_id"`
+ ClientSecret string `json:"client_secret"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type SlackConfig struct {
+ Enabled bool `json:"enabled"`
+ BotToken string `json:"bot_token"`
+ AppToken string `json:"app_token"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type LINEConfig struct {
+ Enabled bool `json:"enabled"`
+ ChannelSecret string `json:"channel_secret"`
+ ChannelAccessToken string `json:"channel_access_token"`
+ WebhookHost string `json:"webhook_host"`
+ WebhookPort int `json:"webhook_port"`
+ WebhookPath string `json:"webhook_path"`
+ AllowFrom []string `json:"allow_from"`
+}
+
+type GatewayConfig struct {
+ Host string `json:"host"`
+ Port int `json:"port"`
+}
+
+type ToolsConfig struct {
+ Web WebToolsConfig `json:"web"`
+ Cron CronConfig `json:"cron"`
+ Exec ExecConfig `json:"exec"`
+}
+
+type WebToolsConfig struct {
+ Brave BraveConfig `json:"brave"`
+ Tavily TavilyConfig `json:"tavily"`
+ DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
+ Perplexity PerplexityConfig `json:"perplexity"`
+ Proxy string `json:"proxy,omitempty"`
+}
+
+type BraveConfig struct {
+ Enabled bool `json:"enabled"`
+ APIKey string `json:"api_key"`
+ MaxResults int `json:"max_results"`
+}
+
+type TavilyConfig struct {
+ Enabled bool `json:"enabled"`
+ APIKey string `json:"api_key"`
+ BaseURL string `json:"base_url"`
+ MaxResults int `json:"max_results"`
+}
+
+type DuckDuckGoConfig struct {
+ Enabled bool `json:"enabled"`
+ MaxResults int `json:"max_results"`
+}
+
+type PerplexityConfig struct {
+ Enabled bool `json:"enabled"`
+ APIKey string `json:"api_key"`
+ MaxResults int `json:"max_results"`
+}
+
+type CronConfig struct {
+ ExecTimeoutMinutes int `json:"exec_timeout_minutes"`
+}
+
+type ExecConfig struct {
+ EnableDenyPatterns bool `json:"enable_deny_patterns"`
+ CustomDenyPatterns []string `json:"custom_deny_patterns"`
+}
+
+func (c *OpenClawConfig) convertChannels(warnings *[]string) ChannelsConfig {
+ channels := ChannelsConfig{}
+
+ if c.Channels == nil {
+ return channels
+ }
+
+ if c.Channels.Telegram != nil {
+ enabled := c.Channels.Telegram.Enabled == nil || *c.Channels.Telegram.Enabled
+ channels.Telegram = TelegramConfig{
+ Enabled: enabled,
+ AllowFrom: c.Channels.Telegram.AllowFrom,
+ }
+ if c.Channels.Telegram.BotToken != nil {
+ channels.Telegram.Token = *c.Channels.Telegram.BotToken
+ }
+ }
+
+ if c.Channels.Discord != nil {
+ enabled := c.Channels.Discord.Enabled == nil || *c.Channels.Discord.Enabled
+ channels.Discord = DiscordConfig{
+ Enabled: enabled,
+ AllowFrom: c.Channels.Discord.AllowFrom,
+ }
+ if c.Channels.Discord.Token != nil {
+ channels.Discord.Token = *c.Channels.Discord.Token
+ }
+ }
+
+ if c.Channels.Slack != nil {
+ enabled := c.Channels.Slack.Enabled == nil || *c.Channels.Slack.Enabled
+ channels.Slack = SlackConfig{
+ Enabled: enabled,
+ AllowFrom: c.Channels.Slack.AllowFrom,
+ }
+ if c.Channels.Slack.BotToken != nil {
+ channels.Slack.BotToken = *c.Channels.Slack.BotToken
+ }
+ if c.Channels.Slack.AppToken != nil {
+ channels.Slack.AppToken = *c.Channels.Slack.AppToken
+ }
+ }
+
+ if c.Channels.WhatsApp != nil {
+ enabled := c.Channels.WhatsApp.Enabled == nil || *c.Channels.WhatsApp.Enabled
+ channels.WhatsApp = WhatsAppConfig{
+ Enabled: enabled,
+ AllowFrom: c.Channels.WhatsApp.AllowFrom,
+ }
+ if c.Channels.WhatsApp.BridgeURL != nil {
+ channels.WhatsApp.BridgeURL = *c.Channels.WhatsApp.BridgeURL
+ }
+ }
+
+ if c.Channels.Feishu != nil {
+ enabled := c.Channels.Feishu.Enabled == nil || *c.Channels.Feishu.Enabled
+ channels.Feishu = FeishuConfig{
+ Enabled: enabled,
+ AllowFrom: c.Channels.Feishu.AllowFrom,
+ }
+ if c.Channels.Feishu.AppID != nil {
+ channels.Feishu.AppID = *c.Channels.Feishu.AppID
+ }
+ if c.Channels.Feishu.AppSecret != nil {
+ channels.Feishu.AppSecret = *c.Channels.Feishu.AppSecret
+ }
+ if c.Channels.Feishu.EncryptKey != nil {
+ channels.Feishu.EncryptKey = *c.Channels.Feishu.EncryptKey
+ }
+ if c.Channels.Feishu.VerificationToken != nil {
+ channels.Feishu.VerificationToken = *c.Channels.Feishu.VerificationToken
+ }
+ }
+
+ if c.Channels.QQ != nil && supportedChannels["qq"] {
+ channels.QQ = QQConfig{
+ Enabled: true,
+ AllowFrom: c.Channels.QQ.AllowFrom,
+ }
+ if c.Channels.QQ.AppID != nil {
+ channels.QQ.AppID = *c.Channels.QQ.AppID
+ }
+ if c.Channels.QQ.AppSecret != nil {
+ channels.QQ.AppSecret = *c.Channels.QQ.AppSecret
+ }
+ }
+
+ if c.Channels.DingTalk != nil && supportedChannels["dingtalk"] {
+ channels.DingTalk = DingTalkConfig{
+ Enabled: true,
+ AllowFrom: c.Channels.DingTalk.AllowFrom,
+ }
+ if c.Channels.DingTalk.AppID != nil {
+ channels.DingTalk.ClientID = *c.Channels.DingTalk.AppID
+ }
+ if c.Channels.DingTalk.AppSecret != nil {
+ channels.DingTalk.ClientSecret = *c.Channels.DingTalk.AppSecret
+ }
+ }
+
+ if c.Channels.MaixCam != nil && supportedChannels["maixcam"] {
+ channels.MaixCam = MaixCamConfig{
+ Enabled: true,
+ AllowFrom: c.Channels.MaixCam.AllowFrom,
+ }
+ if c.Channels.MaixCam.Host != nil {
+ channels.MaixCam.Host = *c.Channels.MaixCam.Host
+ }
+ if c.Channels.MaixCam.Port != nil {
+ channels.MaixCam.Port = *c.Channels.MaixCam.Port
+ }
+ }
+
+ if c.Channels.Signal != nil {
+ *warnings = append(*warnings, "Channel 'signal': No PicoClaw adapter available")
+ }
+ if c.Channels.Matrix != nil {
+ *warnings = append(*warnings, "Channel 'matrix': No PicoClaw adapter available")
+ }
+ if c.Channels.IRC != nil {
+ *warnings = append(*warnings, "Channel 'irc': No PicoClaw adapter available")
+ }
+ if c.Channels.Mattermost != nil {
+ *warnings = append(*warnings, "Channel 'mattermost': No PicoClaw adapter available")
+ }
+ if c.Channels.IMessage != nil {
+ *warnings = append(*warnings, "Channel 'imessage': macOS-only channel - requires manual setup")
+ }
+ if c.Channels.BlueBubbles != nil {
+ *warnings = append(
+ *warnings,
+ "Channel 'bluebubbles': No PicoClaw adapter available - consider iMessage instead",
+ )
+ }
+
+ return channels
+}
+
+func (c *OpenClawConfig) convertAgents(warnings *[]string) []AgentConfig {
+ var agents []AgentConfig
+
+ if c.Agents == nil {
+ return agents
+ }
+
+ for _, entry := range c.Agents.List {
+ agentID := entry.ID
+ if agentID == "" {
+ continue
+ }
+
+ agentName := agentID
+ if entry.Name != nil {
+ agentName = *entry.Name
+ }
+
+ agentCfg := AgentConfig{
+ ID: agentID,
+ Name: agentName,
+ Default: len(agents) == 0,
+ }
+
+ if entry.Workspace != nil {
+ agentCfg.Workspace = rewriteWorkspacePath(*entry.Workspace)
+ }
+
+ if entry.Model != nil {
+ primary := entry.Model.GetPrimary()
+ if primary != "" {
+ agentCfg.Model = &AgentModelConfig{
+ Primary: primary,
+ Fallbacks: entry.Model.GetFallbacks(),
+ }
+ }
+ }
+
+ if len(entry.Skills) > 0 {
+ agentCfg.Skills = entry.Skills
+ }
+
+ agents = append(agents, agentCfg)
+ }
+
+ return agents
+}
+
+func (c *PicoClawConfig) ToStandardConfig() *config.Config {
+ cfg := config.DefaultConfig()
+
+ cfg.Agents.Defaults.Workspace = c.Agents.Defaults.Workspace
+ cfg.Agents.Defaults.Provider = c.Agents.Defaults.Provider
+ cfg.Agents.Defaults.ModelName = c.Agents.Defaults.ModelName
+ cfg.Agents.Defaults.ModelFallbacks = c.Agents.Defaults.ModelFallbacks
+
+ for _, m := range c.ModelList {
+ cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
+ ModelName: m.ModelName,
+ Model: m.Model,
+ APIBase: m.APIBase,
+ APIKey: m.APIKey,
+ Proxy: m.Proxy,
+ })
+ }
+
+ cfg.Channels = c.Channels.ToStandardChannels()
+ cfg.Gateway = c.Gateway.ToStandardGateway()
+ cfg.Tools = c.Tools.ToStandardTools()
+
+ cfg.Agents.List = make([]config.AgentConfig, len(c.Agents.List))
+ for i, a := range c.Agents.List {
+ cfg.Agents.List[i] = config.AgentConfig{
+ ID: a.ID,
+ Default: a.Default,
+ Name: a.Name,
+ Workspace: a.Workspace,
+ Skills: a.Skills,
+ }
+ if a.Model != nil {
+ cfg.Agents.List[i].Model = &config.AgentModelConfig{
+ Primary: a.Model.Primary,
+ Fallbacks: a.Model.Fallbacks,
+ }
+ }
+ }
+
+ return cfg
+}
+
+func (c ChannelsConfig) ToStandardChannels() config.ChannelsConfig {
+ return config.ChannelsConfig{
+ WhatsApp: config.WhatsAppConfig{
+ Enabled: c.WhatsApp.Enabled,
+ BridgeURL: c.WhatsApp.BridgeURL,
+ },
+ Telegram: config.TelegramConfig{
+ Enabled: c.Telegram.Enabled,
+ Token: c.Telegram.Token,
+ Proxy: c.Telegram.Proxy,
+ },
+ Feishu: config.FeishuConfig{
+ Enabled: c.Feishu.Enabled,
+ AppID: c.Feishu.AppID,
+ AppSecret: c.Feishu.AppSecret,
+ EncryptKey: c.Feishu.EncryptKey,
+ VerificationToken: c.Feishu.VerificationToken,
+ },
+ Discord: config.DiscordConfig{
+ Enabled: c.Discord.Enabled,
+ Token: c.Discord.Token,
+ MentionOnly: c.Discord.MentionOnly,
+ },
+ MaixCam: config.MaixCamConfig{
+ Enabled: c.MaixCam.Enabled,
+ Host: c.MaixCam.Host,
+ Port: c.MaixCam.Port,
+ },
+ QQ: config.QQConfig{
+ Enabled: c.QQ.Enabled,
+ AppID: c.QQ.AppID,
+ AppSecret: c.QQ.AppSecret,
+ },
+ DingTalk: config.DingTalkConfig{
+ Enabled: c.DingTalk.Enabled,
+ ClientID: c.DingTalk.ClientID,
+ ClientSecret: c.DingTalk.ClientSecret,
+ },
+ Slack: config.SlackConfig{
+ Enabled: c.Slack.Enabled,
+ BotToken: c.Slack.BotToken,
+ AppToken: c.Slack.AppToken,
+ },
+ LINE: config.LINEConfig{
+ Enabled: c.LINE.Enabled,
+ ChannelSecret: c.LINE.ChannelSecret,
+ ChannelAccessToken: c.LINE.ChannelAccessToken,
+ WebhookHost: c.LINE.WebhookHost,
+ WebhookPort: c.LINE.WebhookPort,
+ WebhookPath: c.LINE.WebhookPath,
+ },
+ }
+}
+
+func (c GatewayConfig) ToStandardGateway() config.GatewayConfig {
+ return config.GatewayConfig{
+ Host: c.Host,
+ Port: c.Port,
+ }
+}
+
+func (c ToolsConfig) ToStandardTools() config.ToolsConfig {
+ return config.ToolsConfig{
+ Web: config.WebToolsConfig{
+ Brave: config.BraveConfig{
+ Enabled: c.Web.Brave.Enabled,
+ APIKey: c.Web.Brave.APIKey,
+ MaxResults: c.Web.Brave.MaxResults,
+ },
+ Tavily: config.TavilyConfig{
+ Enabled: c.Web.Tavily.Enabled,
+ APIKey: c.Web.Tavily.APIKey,
+ BaseURL: c.Web.Tavily.BaseURL,
+ MaxResults: c.Web.Tavily.MaxResults,
+ },
+ DuckDuckGo: config.DuckDuckGoConfig{
+ Enabled: c.Web.DuckDuckGo.Enabled,
+ MaxResults: c.Web.DuckDuckGo.MaxResults,
+ },
+ Perplexity: config.PerplexityConfig{
+ Enabled: c.Web.Perplexity.Enabled,
+ APIKey: c.Web.Perplexity.APIKey,
+ MaxResults: c.Web.Perplexity.MaxResults,
+ },
+ Proxy: c.Web.Proxy,
+ },
+ Cron: config.CronToolsConfig{
+ ExecTimeoutMinutes: c.Cron.ExecTimeoutMinutes,
+ },
+ Exec: config.ExecConfig{
+ EnableDenyPatterns: c.Exec.EnableDenyPatterns,
+ CustomDenyPatterns: c.Exec.CustomDenyPatterns,
+ },
+ }
+}
diff --git a/pkg/migrate/sources/openclaw/openclaw_config_test.go b/pkg/migrate/sources/openclaw/openclaw_config_test.go
new file mode 100644
index 000000000..7d884522c
--- /dev/null
+++ b/pkg/migrate/sources/openclaw/openclaw_config_test.go
@@ -0,0 +1,714 @@
+package openclaw
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestLoadOpenClawConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+
+ testConfig := `{
+ "agents": {
+ "defaults": {
+ "model": {
+ "primary": "anthropic/claude-sonnet-4-20250514"
+ },
+ "workspace": "~/.openclaw/workspace"
+ },
+ "list": [
+ {
+ "id": "main",
+ "name": "Main Agent",
+ "model": {
+ "primary": "openai/gpt-4o",
+ "fallbacks": ["claude-3-opus"]
+ }
+ }
+ ]
+ },
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "botToken": "test-token",
+ "allowFrom": ["user1", "user2"]
+ },
+ "discord": {
+ "enabled": true,
+ "token": "discord-token"
+ }
+ },
+ "models": {
+ "providers": {
+ "anthropic": {
+ "api_key": "sk-ant-test",
+ "base_url": "https://api.anthropic.com"
+ },
+ "openai": {
+ "api_key": "sk-test"
+ }
+ }
+ }
+ }`
+
+ err := os.WriteFile(configPath, []byte(testConfig), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write test config: %v", err)
+ }
+
+ cfg, err := LoadOpenClawConfig(configPath)
+ if err != nil {
+ t.Fatalf("failed to load config: %v", err)
+ }
+
+ if cfg.Agents == nil {
+ t.Error("agents should not be nil")
+ }
+
+ if cfg.Agents.Defaults == nil {
+ t.Error("agents.defaults should not be nil")
+ }
+
+ provider, model := cfg.GetDefaultModel()
+ if provider != "anthropic" {
+ t.Errorf("expected provider 'anthropic', got '%s'", provider)
+ }
+ if model != "claude-sonnet-4-20250514" {
+ t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", model)
+ }
+
+ workspace := cfg.GetDefaultWorkspace()
+ if workspace != "~/.picoclaw/workspace" {
+ t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", workspace)
+ }
+
+ agents := cfg.GetAgents()
+ if len(agents) != 1 {
+ t.Errorf("expected 1 agent, got %d", len(agents))
+ }
+ if agents[0].ID != "main" {
+ t.Errorf("expected agent id 'main', got '%s'", agents[0].ID)
+ }
+
+ if cfg.Channels == nil {
+ t.Error("channels should not be nil")
+ }
+ if cfg.Channels.Telegram == nil {
+ t.Error("telegram channel should not be nil")
+ }
+ if cfg.Channels.Telegram.BotToken == nil || *cfg.Channels.Telegram.BotToken != "test-token" {
+ t.Error("telegram bot token not parsed correctly")
+ }
+}
+
+func TestGetProviderConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+
+ testConfig := `{
+ "models": {
+ "providers": {
+ "anthropic": {
+ "api_key": "sk-ant-test",
+ "base_url": "https://api.anthropic.com",
+ "max_tokens": 4096
+ },
+ "openai": {
+ "api_key": "sk-test",
+ "base_url": "https://api.openai.com"
+ }
+ }
+ }
+ }`
+
+ err := os.WriteFile(configPath, []byte(testConfig), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write test config: %v", err)
+ }
+
+ cfg, err := LoadOpenClawConfig(configPath)
+ if err != nil {
+ t.Fatalf("failed to load config: %v", err)
+ }
+
+ providers := GetProviderConfig(cfg.Models)
+ if len(providers) != 2 {
+ t.Errorf("expected 2 providers, got %d", len(providers))
+ }
+
+ if anthropic, ok := providers["anthropic"]; ok {
+ if anthropic.APIKey != "sk-ant-test" {
+ t.Errorf("expected anthropic api_key 'sk-ant-test', got '%s'", anthropic.APIKey)
+ }
+ if anthropic.BaseURL != "https://api.anthropic.com" {
+ t.Errorf("expected anthropic base_url 'https://api.anthropic.com', got '%s'", anthropic.BaseURL)
+ }
+ } else {
+ t.Error("anthropic provider not found")
+ }
+
+ if openai, ok := providers["openai"]; ok {
+ if openai.APIKey != "sk-test" {
+ t.Errorf("expected openai api_key 'sk-test', got '%s'", openai.APIKey)
+ }
+ } else {
+ t.Error("openai provider not found")
+ }
+}
+
+func TestConvertToPicoClaw(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+
+ testConfig := `{
+ "agents": {
+ "defaults": {
+ "model": {
+ "primary": "anthropic/claude-sonnet-4-20250514"
+ },
+ "workspace": "~/.openclaw/workspace"
+ },
+ "list": [
+ {
+ "id": "main",
+ "name": "Main Agent"
+ },
+ {
+ "id": "assistant",
+ "name": "Assistant",
+ "skills": ["skill1", "skill2"]
+ }
+ ]
+ },
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "botToken": "test-token",
+ "allowFrom": ["user1", "user2"]
+ },
+ "discord": {
+ "enabled": false,
+ "token": "discord-token"
+ },
+ "whatsapp": {
+ "enabled": true,
+ "bridgeUrl": "http://localhost:3000"
+ },
+ "feishu": {
+ "enabled": true,
+ "appId": "app-id",
+ "appSecret": "app-secret",
+ "allowFrom": ["user3"]
+ },
+ "signal": {
+ "enabled": true
+ }
+ },
+ "models": {
+ "providers": {
+ "anthropic": {
+ "api_key": "sk-ant-test"
+ },
+ "openai": {
+ "api_key": "sk-test"
+ }
+ }
+ },
+ "skills": {
+ "entries": {
+ "skill1": {}
+ }
+ },
+ "memory": {"enabled": true},
+ "cron": {"enabled": true}
+ }`
+
+ err := os.WriteFile(configPath, []byte(testConfig), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write test config: %v", err)
+ }
+
+ cfg, err := LoadOpenClawConfig(configPath)
+ if err != nil {
+ t.Fatalf("failed to load config: %v", err)
+ }
+
+ picoCfg, warnings, err := cfg.ConvertToPicoClaw("")
+ if err != nil {
+ t.Fatalf("failed to convert config: %v", err)
+ }
+
+ if picoCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
+ t.Errorf("expected model 'claude-sonnet-4-20250514', got '%s'", picoCfg.Agents.Defaults.ModelName)
+ }
+ if picoCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
+ t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", picoCfg.Agents.Defaults.Workspace)
+ }
+
+ if len(picoCfg.Agents.List) != 2 {
+ t.Errorf("expected 2 agents, got %d", len(picoCfg.Agents.List))
+ }
+ if picoCfg.Agents.List[0].ID != "main" {
+ t.Errorf("expected first agent id 'main', got '%s'", picoCfg.Agents.List[0].ID)
+ }
+ if picoCfg.Agents.List[1].Skills == nil || len(picoCfg.Agents.List[1].Skills) != 2 {
+ t.Errorf("expected 2 skills for assistant agent")
+ }
+
+ if !picoCfg.Channels.Telegram.Enabled {
+ t.Error("telegram should be enabled")
+ }
+ if picoCfg.Channels.Telegram.Token != "test-token" {
+ t.Errorf("expected telegram token 'test-token', got '%s'", picoCfg.Channels.Telegram.Token)
+ }
+
+ if picoCfg.Channels.WhatsApp.BridgeURL != "http://localhost:3000" {
+ t.Errorf("expected whatsapp bridge URL 'http://localhost:3000', got '%s'", picoCfg.Channels.WhatsApp.BridgeURL)
+ }
+
+ if picoCfg.Channels.Feishu.AppID != "app-id" {
+ t.Errorf("expected feishu app ID 'app-id', got '%s'", picoCfg.Channels.Feishu.AppID)
+ }
+
+ if len(picoCfg.ModelList) != 1 {
+ t.Errorf("expected 1 model config (no models.json provided), got %d", len(picoCfg.ModelList))
+ }
+
+ foundWarning := false
+ for _, w := range warnings {
+ if len(w) > 0 {
+ foundWarning = true
+ break
+ }
+ }
+ if !foundWarning {
+ t.Log("warnings should be generated for skills, memory, cron, and unsupported channels")
+ }
+}
+
+func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+
+ testConfig := `{
+ "agents": {
+ "defaults": {
+ "model": {
+ "primary": "anthropic/claude-sonnet-4-20250514"
+ }
+ }
+ },
+ "channels": {
+ "qq": {
+ "enabled": true,
+ "appId": "qq-app-id",
+ "appSecret": "qq-app-secret"
+ },
+ "dingtalk": {
+ "enabled": true,
+ "appId": "ding-app-id",
+ "appSecret": "ding-app-secret"
+ },
+ "maixcam": {
+ "enabled": true,
+ "host": "192.168.1.100",
+ "port": 9000
+ },
+ "slack": {
+ "enabled": true,
+ "botToken": "xoxb-test",
+ "appToken": "xapp-test"
+ }
+ }
+ }`
+
+ err := os.WriteFile(configPath, []byte(testConfig), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write test config: %v", err)
+ }
+
+ cfg, err := LoadOpenClawConfig(configPath)
+ if err != nil {
+ t.Fatalf("failed to load config: %v", err)
+ }
+
+ picoCfg, _, err := cfg.ConvertToPicoClaw("")
+ if err != nil {
+ t.Fatalf("failed to convert config: %v", err)
+ }
+
+ if !picoCfg.Channels.QQ.Enabled {
+ t.Error("qq should be enabled")
+ }
+ if picoCfg.Channels.QQ.AppID != "qq-app-id" {
+ t.Errorf("expected qq app ID 'qq-app-id', got '%s'", picoCfg.Channels.QQ.AppID)
+ }
+
+ if !picoCfg.Channels.DingTalk.Enabled {
+ t.Error("dingtalk should be enabled")
+ }
+ if picoCfg.Channels.DingTalk.ClientID != "ding-app-id" {
+ t.Errorf("expected dingtalk client ID 'ding-app-id', got '%s'", picoCfg.Channels.DingTalk.ClientID)
+ }
+
+ if !picoCfg.Channels.MaixCam.Enabled {
+ t.Error("maixcam should be enabled")
+ }
+ if picoCfg.Channels.MaixCam.Host != "192.168.1.100" {
+ t.Errorf("expected maixcam host '192.168.1.100', got '%s'", picoCfg.Channels.MaixCam.Host)
+ }
+ if picoCfg.Channels.MaixCam.Port != 9000 {
+ t.Errorf("expected maixcam port 9000, got %d", picoCfg.Channels.MaixCam.Port)
+ }
+
+ if !picoCfg.Channels.Slack.Enabled {
+ t.Error("slack should be enabled")
+ }
+ if picoCfg.Channels.Slack.BotToken != "xoxb-test" {
+ t.Errorf("expected slack bot token 'xoxb-test', got '%s'", picoCfg.Channels.Slack.BotToken)
+ }
+ if picoCfg.Channels.Slack.AppToken != "xapp-test" {
+ t.Errorf("expected slack app token 'xapp-test', got '%s'", picoCfg.Channels.Slack.AppToken)
+ }
+}
+
+func TestOpenClawAgentModel(t *testing.T) {
+ model := &OpenClawAgentModel{
+ Primary: strPtr("anthropic/claude-3-opus"),
+ Fallbacks: []string{"claude-3-sonnet", "claude-3-haiku"},
+ }
+
+ primary := model.GetPrimary()
+ if primary != "anthropic/claude-3-opus" {
+ t.Errorf("expected primary 'anthropic/claude-3-opus', got '%s'", primary)
+ }
+
+ fallbacks := model.GetFallbacks()
+ if len(fallbacks) != 2 {
+ t.Errorf("expected 2 fallbacks, got %d", len(fallbacks))
+ }
+
+ model2 := &OpenClawAgentModel{
+ Simple: "claude-3-opus",
+ }
+
+ primary2 := model2.GetPrimary()
+ if primary2 != "claude-3-opus" {
+ t.Errorf("expected primary 'claude-3-opus' from Simple, got '%s'", primary2)
+ }
+}
+
+func TestChannelEnabled(t *testing.T) {
+ cfg := &OpenClawConfig{
+ Channels: &OpenClawChannels{
+ Telegram: &OpenClawTelegramConfig{
+ Enabled: boolPtr(true),
+ },
+ Discord: &OpenClawDiscordConfig{
+ Enabled: boolPtr(false),
+ },
+ Slack: &OpenClawSlackConfig{
+ Enabled: boolPtr(true),
+ },
+ },
+ }
+
+ if !cfg.IsChannelEnabled("telegram") {
+ t.Error("telegram should be enabled")
+ }
+ if cfg.IsChannelEnabled("discord") {
+ t.Error("discord should be disabled")
+ }
+ if !cfg.IsChannelEnabled("slack") {
+ t.Error("slack should be enabled (explicitly set)")
+ }
+ if cfg.IsChannelEnabled("line") {
+ t.Error("line should return false (not in switch cases)")
+ }
+}
+
+func TestGetDefaultModel(t *testing.T) {
+ cfg := &OpenClawConfig{
+ Agents: &OpenClawAgents{
+ Defaults: &OpenClawAgentDefaults{
+ Model: &OpenClawAgentModel{
+ Primary: strPtr("openai/gpt-4"),
+ },
+ },
+ },
+ }
+
+ provider, model := cfg.GetDefaultModel()
+ if provider != "openai" {
+ t.Errorf("expected provider 'openai', got '%s'", provider)
+ }
+ if model != "gpt-4" {
+ t.Errorf("expected model 'gpt-4', got '%s'", model)
+ }
+}
+
+func TestGetDefaultModelWithNoDefaults(t *testing.T) {
+ cfg := &OpenClawConfig{}
+
+ provider, model := cfg.GetDefaultModel()
+ if provider != "anthropic" {
+ t.Errorf("expected default provider 'anthropic', got '%s'", provider)
+ }
+ if model != "claude-sonnet-4-20250514" {
+ t.Errorf("expected default model 'claude-sonnet-4-20250514', got '%s'", model)
+ }
+}
+
+func TestHasFunctions(t *testing.T) {
+ cfg := &OpenClawConfig{
+ Skills: &OpenClawSkills{Entries: map[string]json.RawMessage{"skill1": nil}},
+ Memory: json.RawMessage(`{"enabled": true}`),
+ Cron: json.RawMessage(`{"enabled": true}`),
+ Hooks: json.RawMessage(`{"enabled": true}`),
+ Session: json.RawMessage(`{"enabled": true}`),
+ Auth: &OpenClawAuth{Profiles: json.RawMessage(`{"profile1": {}}`)},
+ }
+
+ if !cfg.HasSkills() {
+ t.Error("should have skills")
+ }
+ if !cfg.HasMemory() {
+ t.Error("should have memory")
+ }
+ if !cfg.HasCron() {
+ t.Error("should have cron")
+ }
+ if !cfg.HasHooks() {
+ t.Error("should have hooks")
+ }
+ if !cfg.HasSession() {
+ t.Error("should have session")
+ }
+ if !cfg.HasAuthProfiles() {
+ t.Error("should have auth profiles")
+ }
+
+ cfg2 := &OpenClawConfig{}
+ if cfg2.HasSkills() {
+ t.Error("should not have skills")
+ }
+ if cfg2.HasMemory() {
+ t.Error("should not have memory")
+ }
+}
+
+func TestLoadOpenClawConfigFromDir(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+
+ testConfig := `{"agents": {}}`
+ err := os.WriteFile(configPath, []byte(testConfig), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write test config: %v", err)
+ }
+
+ cfg, err := LoadOpenClawConfigFromDir(tmpDir)
+ if err != nil {
+ t.Fatalf("failed to load config from dir: %v", err)
+ }
+
+ if cfg.Agents == nil {
+ t.Error("agents should not be nil")
+ }
+
+ _, err = LoadOpenClawConfigFromDir("/nonexistent/dir")
+ if err == nil {
+ t.Error("should return error for nonexistent dir")
+ }
+}
+
+func TestToStandardConfig(t *testing.T) {
+ picoCfg := &PicoClawConfig{
+ Agents: AgentsConfig{
+ Defaults: AgentDefaults{
+ Provider: "anthropic",
+ ModelName: "claude-sonnet-4-20250514",
+ Workspace: "~/.picoclaw/workspace",
+ },
+ List: []AgentConfig{
+ {
+ ID: "main",
+ Name: "Main Agent",
+ Default: true,
+ },
+ },
+ },
+ ModelList: []ModelConfig{
+ {
+ ModelName: "claude-sonnet-4-20250514",
+ Model: "anthropic/claude-sonnet-4-20250514",
+ APIKey: "sk-ant-test",
+ },
+ },
+ Channels: ChannelsConfig{
+ Telegram: TelegramConfig{
+ Enabled: true,
+ Token: "test-token",
+ AllowFrom: []string{"user1"},
+ },
+ WhatsApp: WhatsAppConfig{
+ Enabled: true,
+ BridgeURL: "http://localhost:3000",
+ },
+ },
+ Gateway: GatewayConfig{
+ Host: "0.0.0.0",
+ Port: 8080,
+ },
+ }
+
+ stdCfg := picoCfg.ToStandardConfig()
+
+ if stdCfg.Agents.Defaults.Provider != "anthropic" {
+ t.Errorf("expected provider 'anthropic', got '%s'", stdCfg.Agents.Defaults.Provider)
+ }
+ if stdCfg.Agents.Defaults.ModelName != "claude-sonnet-4-20250514" {
+ t.Errorf("expected model name 'claude-sonnet-4-20250514', got '%s'", stdCfg.Agents.Defaults.ModelName)
+ }
+ if stdCfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
+ t.Errorf("expected workspace '~/.picoclaw/workspace', got '%s'", stdCfg.Agents.Defaults.Workspace)
+ }
+
+ if len(stdCfg.Agents.List) != 1 {
+ t.Errorf("expected 1 agent, got %d", len(stdCfg.Agents.List))
+ }
+ if stdCfg.Agents.List[0].ID != "main" {
+ t.Errorf("expected agent id 'main', got '%s'", stdCfg.Agents.List[0].ID)
+ }
+
+ foundModel := false
+ var foundAPIKey string
+ for _, m := range stdCfg.ModelList {
+ if m.ModelName == "claude-sonnet-4-20250514" {
+ foundModel = true
+ foundAPIKey = m.APIKey
+ break
+ }
+ }
+ if !foundModel {
+ t.Error("expected to find claude-sonnet-4-20250514 model config")
+ }
+ if foundAPIKey != "sk-ant-test" {
+ t.Errorf("expected api key 'sk-ant-test', got '%s'", foundAPIKey)
+ }
+
+ if !stdCfg.Channels.Telegram.Enabled {
+ t.Error("telegram should be enabled")
+ }
+ if stdCfg.Channels.Telegram.Token != "test-token" {
+ t.Errorf("expected token 'test-token', got '%s'", stdCfg.Channels.Telegram.Token)
+ }
+
+ if stdCfg.Gateway.Port != 8080 {
+ t.Errorf("expected gateway port 8080, got %d", stdCfg.Gateway.Port)
+ }
+}
+
+func TestLoadProviderConfigFromAgentsDir(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ agentsDir := filepath.Join(tmpDir, "agents", "main", "agent")
+ err := os.MkdirAll(agentsDir, 0o755)
+ if err != nil {
+ t.Fatalf("failed to create agents dir: %v", err)
+ }
+
+ modelsJSON := `{
+ "providers": {
+ "anthropic": {
+ "baseUrl": "https://api.anthropic.com",
+ "api": "anthropic",
+ "apiKey": "sk-ant-from-models",
+ "models": [
+ {
+ "id": "claude-sonnet-4-20250514",
+ "name": "Claude Sonnet 4"
+ }
+ ]
+ },
+ "openai": {
+ "baseUrl": "https://api.openai.com",
+ "api": "openai",
+ "apiKey": "sk-from-models",
+ "models": [
+ {
+ "id": "gpt-4o",
+ "name": "GPT-4o"
+ }
+ ]
+ },
+ "zhipu": {
+ "baseUrl": "https://open.bigmodel.cn/api/paas/v4",
+ "api": "openai",
+ "apiKey": "zhipu-key",
+ "models": []
+ }
+ }
+ }`
+
+ err = os.WriteFile(filepath.Join(agentsDir, "models.json"), []byte(modelsJSON), 0o644)
+ if err != nil {
+ t.Fatalf("failed to write models.json: %v", err)
+ }
+
+ providers := GetProviderConfigFromDir(tmpDir)
+ if len(providers) != 3 {
+ t.Errorf("expected 3 providers, got %d", len(providers))
+ }
+
+ if anthropic, ok := providers["anthropic"]; ok {
+ if anthropic.ApiKey != "sk-ant-from-models" {
+ t.Errorf("expected anthropic apiKey 'sk-ant-from-models', got '%s'", anthropic.ApiKey)
+ }
+ if anthropic.BaseUrl != "https://api.anthropic.com" {
+ t.Errorf("expected anthropic baseUrl 'https://api.anthropic.com', got '%s'", anthropic.BaseUrl)
+ }
+ } else {
+ t.Error("anthropic provider not found")
+ }
+
+ if openai, ok := providers["openai"]; ok {
+ if openai.ApiKey != "sk-from-models" {
+ t.Errorf("expected openai apiKey 'sk-from-models', got '%s'", openai.ApiKey)
+ }
+ if openai.BaseUrl != "https://api.openai.com" {
+ t.Errorf("expected openai baseUrl 'https://api.openai.com', got '%s'", openai.BaseUrl)
+ }
+ } else {
+ t.Error("openai provider not found")
+ }
+
+ if zhipu, ok := providers["zhipu"]; ok {
+ if zhipu.ApiKey != "zhipu-key" {
+ t.Errorf("expected zhipu apiKey 'zhipu-key', got '%s'", zhipu.ApiKey)
+ }
+ if zhipu.BaseUrl != "https://open.bigmodel.cn/api/paas/v4" {
+ t.Errorf("expected zhipu baseUrl 'https://open.bigmodel.cn/api/paas/v4', got '%s'", zhipu.BaseUrl)
+ }
+ } else {
+ t.Error("zhipu provider not found")
+ }
+}
+
+func TestGetProviderConfigFromDirNotExist(t *testing.T) {
+ providers := GetProviderConfigFromDir("/nonexistent/path")
+ if len(providers) != 0 {
+ t.Errorf("expected 0 providers for nonexistent path, got %d", len(providers))
+ }
+}
+
+func strPtr(s string) *string {
+ return &s
+}
+
+func boolPtr(b bool) *bool {
+ return &b
+}
diff --git a/pkg/migrate/sources/openclaw/openclaw_handler.go b/pkg/migrate/sources/openclaw/openclaw_handler.go
new file mode 100644
index 000000000..aaff119f1
--- /dev/null
+++ b/pkg/migrate/sources/openclaw/openclaw_handler.go
@@ -0,0 +1,148 @@
+package openclaw
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/migrate/internal"
+)
+
+var providerMapping = map[string]string{
+ "anthropic": "anthropic",
+ "claude": "anthropic",
+ "openai": "openai",
+ "gpt": "openai",
+ "groq": "groq",
+ "ollama": "ollama",
+ "openrouter": "openrouter",
+ "deepseek": "deepseek",
+ "together": "together",
+ "mistral": "mistral",
+ "fireworks": "fireworks",
+ "google": "google",
+ "gemini": "google",
+ "xai": "xai",
+ "grok": "xai",
+ "cerebras": "cerebras",
+ "sambanova": "sambanova",
+}
+
+type OpenclawHandler struct {
+ opts Options
+ sourceConfigFile string
+ sourceWorkspace string
+}
+
+type (
+ Options = internal.Options
+ Action = internal.Action
+ Result = internal.Result
+ Operation = internal.Operation
+)
+
+func NewOpenclawHandler(opts Options) (Operation, error) {
+ home, err := resolveSourceHome(opts.SourceHome)
+ if err != nil {
+ return nil, err
+ }
+ opts.SourceHome = home
+
+ configFile, err := findSourceConfig(home)
+ if err != nil {
+ return nil, err
+ }
+ return &OpenclawHandler{
+ opts: opts,
+ sourceWorkspace: filepath.Join(opts.SourceHome, "workspace"),
+ sourceConfigFile: configFile,
+ }, nil
+}
+
+func (o *OpenclawHandler) GetSourceName() string {
+ return "openclaw"
+}
+
+func (o *OpenclawHandler) GetSourceHome() (string, error) {
+ return o.opts.SourceHome, nil
+}
+
+func (o *OpenclawHandler) GetSourceWorkspace() (string, error) {
+ return o.sourceWorkspace, nil
+}
+
+func (o *OpenclawHandler) GetSourceConfigFile() (string, error) {
+ return o.sourceConfigFile, nil
+}
+
+func (o *OpenclawHandler) GetMigrateableFiles() []string {
+ return migrateableFiles
+}
+
+func (o *OpenclawHandler) GetMigrateableDirs() []string {
+ return migrateableDirs
+}
+
+func (o *OpenclawHandler) ExecuteConfigMigration(srcConfigPath, dstConfigPath string) error {
+ openclawCfg, err := LoadOpenClawConfig(srcConfigPath)
+ if err != nil {
+ return err
+ }
+
+ picoCfg, warnings, err := openclawCfg.ConvertToPicoClaw(o.opts.SourceHome)
+ if err != nil {
+ return err
+ }
+
+ for _, w := range warnings {
+ fmt.Printf(" Warning: %s\n", w)
+ }
+
+ incoming := picoCfg.ToStandardConfig()
+ if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
+ return err
+ }
+
+ return config.SaveConfig(dstConfigPath, incoming)
+}
+
+func resolveSourceHome(override string) (string, error) {
+ if override != "" {
+ return internal.ExpandHome(override), nil
+ }
+ if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
+ return internal.ExpandHome(envHome), nil
+ }
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", fmt.Errorf("resolving home directory: %w", err)
+ }
+ return filepath.Join(home, ".openclaw"), nil
+}
+
+func findSourceConfig(sourceHome string) (string, error) {
+ candidates := []string{
+ filepath.Join(sourceHome, "openclaw.json"),
+ filepath.Join(sourceHome, "config.json"),
+ }
+ for _, p := range candidates {
+ if _, err := os.Stat(p); err == nil {
+ return p, nil
+ }
+ }
+ return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", sourceHome)
+}
+
+func rewriteWorkspacePath(path string) string {
+ path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
+ return path
+}
+
+func mapProvider(provider string) string {
+ if mapped, ok := providerMapping[strings.ToLower(provider)]; ok {
+ return mapped
+ }
+ return strings.ToLower(provider)
+}
diff --git a/pkg/migrate/sources/openclaw/openclaw_handler_test.go b/pkg/migrate/sources/openclaw/openclaw_handler_test.go
new file mode 100644
index 000000000..35bd09be0
--- /dev/null
+++ b/pkg/migrate/sources/openclaw/openclaw_handler_test.go
@@ -0,0 +1,247 @@
+package openclaw
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewOpenclawHandler(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, handler)
+}
+
+func TestNewOpenclawHandlerNoConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ _, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.Error(t, err)
+}
+
+func TestOpenclawHandlerGetSourceName(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ assert.Equal(t, "openclaw", handler.GetSourceName())
+}
+
+func TestOpenclawHandlerGetSourceHome(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ home, err := handler.GetSourceHome()
+ require.NoError(t, err)
+ assert.Equal(t, tmpDir, home)
+}
+
+func TestOpenclawHandlerGetSourceWorkspace(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ workspace, err := handler.GetSourceWorkspace()
+ require.NoError(t, err)
+ assert.Equal(t, filepath.Join(tmpDir, "workspace"), workspace)
+}
+
+func TestOpenclawHandlerGetSourceConfigFile(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ configFile, err := handler.GetSourceConfigFile()
+ require.NoError(t, err)
+ assert.Equal(t, configPath, configFile)
+}
+
+func TestOpenclawHandlerGetSourceConfigFileWithConfigJson(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "config.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ configFile, err := handler.GetSourceConfigFile()
+ require.NoError(t, err)
+ assert.Equal(t, configPath, configFile)
+}
+
+func TestOpenclawHandlerGetMigrateableFiles(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ files := handler.GetMigrateableFiles()
+ assert.NotEmpty(t, files)
+ assert.Contains(t, files, "AGENTS.md")
+ assert.Contains(t, files, "SOUL.md")
+ assert.Contains(t, files, "USER.md")
+}
+
+func TestOpenclawHandlerGetMigrateableDirs(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ handler, err := NewOpenclawHandler(Options{
+ SourceHome: tmpDir,
+ })
+ require.NoError(t, err)
+
+ dirs := handler.GetMigrateableDirs()
+ assert.NotEmpty(t, dirs)
+ assert.Contains(t, dirs, "memory")
+ assert.Contains(t, dirs, "skills")
+}
+
+func TestResolveSourceHome(t *testing.T) {
+ result, err := resolveSourceHome("/custom/path")
+ require.NoError(t, err)
+ assert.Equal(t, "/custom/path", result)
+}
+
+func TestResolveSourceHomeWithEnvVar(t *testing.T) {
+ t.Setenv("OPENCLAW_HOME", "/env/path")
+
+ result, err := resolveSourceHome("")
+ require.NoError(t, err)
+ assert.Equal(t, "/env/path", result)
+}
+
+func TestResolveSourceHomeWithTilde(t *testing.T) {
+ home, err := os.UserHomeDir()
+ require.NoError(t, err)
+
+ result, err := resolveSourceHome("~/openclaw")
+ require.NoError(t, err)
+ assert.Equal(t, filepath.Join(home, "openclaw"), result)
+}
+
+func TestFindSourceConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "openclaw.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ result, err := findSourceConfig(tmpDir)
+ require.NoError(t, err)
+ assert.Equal(t, configPath, result)
+}
+
+func TestFindSourceConfigWithConfigJson(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "config.json")
+ err := os.WriteFile(configPath, []byte("{}"), 0o644)
+ require.NoError(t, err)
+
+ result, err := findSourceConfig(tmpDir)
+ require.NoError(t, err)
+ assert.Equal(t, configPath, result)
+}
+
+func TestFindSourceConfigNotFound(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ _, err := findSourceConfig(tmpDir)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "no config file found")
+}
+
+func TestMapProvider(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"anthropic", "anthropic"},
+ {"claude", "anthropic"},
+ {"openai", "openai"},
+ {"gpt", "openai"},
+ {"groq", "groq"},
+ {"ollama", "ollama"},
+ {"openrouter", "openrouter"},
+ {"deepseek", "deepseek"},
+ {"together", "together"},
+ {"mistral", "mistral"},
+ {"fireworks", "fireworks"},
+ {"google", "google"},
+ {"gemini", "google"},
+ {"xai", "xai"},
+ {"grok", "xai"},
+ {"cerebras", "cerebras"},
+ {"sambanova", "sambanova"},
+ {"unknown", "unknown"},
+ {"", ""},
+ }
+
+ for _, tt := range tests {
+ result := mapProvider(tt.input)
+ assert.Equal(t, tt.expected, result, "mapProvider(%q)", tt.input)
+ }
+}
+
+func TestRewriteWorkspacePath(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"~/.openclaw/workspace", "~/.picoclaw/workspace"},
+ {"/home/user/.openclaw/workspace", "/home/user/.picoclaw/workspace"},
+ {"/path/without/openclaw/change", "/path/without/openclaw/change"},
+ {"", ""},
+ }
+
+ for _, tt := range tests {
+ result := rewriteWorkspacePath(tt.input)
+ assert.Equal(t, tt.expected, result, "rewriteWorkspacePath(%q)", tt.input)
+ }
+}
diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go
index 9162174c9..1bb15f771 100644
--- a/pkg/providers/anthropic/provider.go
+++ b/pkg/providers/anthropic/provider.go
@@ -212,14 +212,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
- var content string
+ var content strings.Builder
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
- content += tb.Text
+ content.WriteString(tb.Text)
case "tool_use":
tu := block.AsToolUse()
var args map[string]any
@@ -246,7 +246,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
}
return &LLMResponse{
- Content: content,
+ Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
@@ -264,8 +264,8 @@ func normalizeBaseURL(apiBase string) string {
}
base = strings.TrimRight(base, "/")
- if strings.HasSuffix(base, "/v1") {
- base = strings.TrimSuffix(base, "/v1")
+ if before, ok := strings.CutSuffix(base, "/v1"); ok {
+ base = before
}
if base == "" {
return defaultBaseURL
diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go
index 74ec33b98..6c4f6a767 100644
--- a/pkg/providers/claude_cli_provider.go
+++ b/pkg/providers/claude_cli_provider.go
@@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe
}
if len(tools) > 0 {
- parts = append(parts, p.buildToolsPrompt(tools))
+ parts = append(parts, buildCLIToolsPrompt(tools))
}
return strings.Join(parts, "\n\n")
}
-// buildToolsPrompt creates the tool definitions section for the system prompt.
-func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
- var sb strings.Builder
-
- sb.WriteString("## Available Tools\n\n")
- sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
- sb.WriteString("```json\n")
- sb.WriteString(
- `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
- )
- sb.WriteString("\n```\n\n")
- sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
- sb.WriteString("### Tool Definitions:\n\n")
-
- for _, tool := range tools {
- if tool.Type != "function" {
- continue
- }
- sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
- if tool.Function.Description != "" {
- sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
- }
- if len(tool.Function.Parameters) > 0 {
- paramsJSON, _ := json.Marshal(tool.Function.Parameters)
- sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
- }
- sb.WriteString("\n")
- }
-
- return sb.String()
-}
-
// parseClaudeCliResponse parses the JSON output from the claude CLI.
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
var resp claudeCliJSONResponse
diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go
index 3a3cafaca..d4d648f5a 100644
--- a/pkg/providers/claude_cli_provider_test.go
+++ b/pkg/providers/claude_cli_provider_test.go
@@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
// --- buildToolsPrompt tests ---
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "skip_me") {
t.Error("buildToolsPrompt() should skip non-function tools")
}
@@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
}
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if !strings.Contains(got, "bare_tool") {
t.Error("should include tool name")
}
@@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) {
}
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
- p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{
Name: "no_params_tool",
Description: "A tool with no parameters",
}},
}
- got := p.buildToolsPrompt(tools)
+ got := buildCLIToolsPrompt(tools)
if strings.Contains(got, "Parameters:") {
t.Error("should not include Parameters: section when nil")
}
diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go
index 4c783ece5..13f53ad9e 100644
--- a/pkg/providers/codex_cli_provider.go
+++ b/pkg/providers/codex_cli_provider.go
@@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
}
if len(tools) > 0 {
- sb.WriteString(p.buildToolsPrompt(tools))
+ sb.WriteString(buildCLIToolsPrompt(tools))
sb.WriteString("\n\n")
}
@@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio
return sb.String()
}
-// buildToolsPrompt creates a tool definitions section for the prompt.
-func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
- var sb strings.Builder
-
- sb.WriteString("## Available Tools\n\n")
- sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
- sb.WriteString("```json\n")
- sb.WriteString(
- `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
- )
- sb.WriteString("\n```\n\n")
- sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
- sb.WriteString("### Tool Definitions:\n\n")
-
- for _, tool := range tools {
- if tool.Type != "function" {
- continue
- }
- sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
- if tool.Function.Description != "" {
- sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
- }
- if len(tool.Function.Parameters) > 0 {
- paramsJSON, _ := json.Marshal(tool.Function.Parameters)
- sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
- }
- sb.WriteString("\n")
- }
-
- return sb.String()
-}
-
// codexEvent represents a single JSONL event from `codex exec --json`.
type codexEvent struct {
Type string `json:"type"`
diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go
index dcc740ba4..47618300a 100644
--- a/pkg/providers/codex_provider.go
+++ b/pkg/providers/codex_provider.go
@@ -163,8 +163,8 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "empty model"
}
- if strings.HasPrefix(m, "openai/") {
- m = strings.TrimPrefix(m, "openai/")
+ if after, ok := strings.CutPrefix(m, "openai/"); ok {
+ m = after
} else if strings.Contains(m, "/") {
return codexDefaultModel, "non-openai model namespace"
}
diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go
index 47f43ad5c..b517e7feb 100644
--- a/pkg/providers/cooldown_test.go
+++ b/pkg/providers/cooldown_test.go
@@ -138,7 +138,7 @@ func TestCooldown_FailureWindowReset(t *testing.T) {
ct, current := newTestTracker(now)
// 4 errors → 1h cooldown
- for i := 0; i < 4; i++ {
+ for range 4 {
ct.MarkFailure("openai", FailoverRateLimit)
*current = current.Add(2 * time.Second) // small advance between errors
}
@@ -230,7 +230,7 @@ func TestCooldown_ConcurrentAccess(t *testing.T) {
ct := NewCooldownTracker()
var wg sync.WaitGroup
- for i := 0; i < 100; i++ {
+ for range 100 {
wg.Add(3)
go func() {
defer wg.Done()
diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go
index a0f003006..fd9bf1e81 100644
--- a/pkg/providers/error_classifier.go
+++ b/pkg/providers/error_classifier.go
@@ -6,6 +6,13 @@ import (
"strings"
)
+// Common patterns in Go HTTP error messages
+var httpStatusPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`status[:\s]+(\d{3})`),
+ regexp.MustCompile(`http[/\s]+\d*\.?\d*\s+(\d{3})`),
+ regexp.MustCompile(`\b([3-5]\d{2})\b`),
+}
+
// errorPattern defines a single pattern (string or regex) for error classification.
type errorPattern struct {
substring string
@@ -198,20 +205,13 @@ func classifyByMessage(msg string) FailoverReason {
}
// extractHTTPStatus extracts an HTTP status code from an error message.
-// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
+// Looks for patterns like "status: 429", "status 429", "http/1.1 429", "http 429", or standalone "429".
func extractHTTPStatus(msg string) int {
- // Common patterns in Go HTTP error messages
- patterns := []*regexp.Regexp{
- regexp.MustCompile(`status[:\s]+(\d{3})`),
- regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
- }
-
- for _, p := range patterns {
+ for _, p := range httpStatusPatterns {
if m := p.FindStringSubmatch(msg); len(m) > 1 {
return parseDigits(m[1])
}
}
-
return 0
}
diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go
index 865aea57a..67d9af62b 100644
--- a/pkg/providers/error_classifier_test.go
+++ b/pkg/providers/error_classifier_test.go
@@ -305,7 +305,8 @@ func TestExtractHTTPStatus(t *testing.T) {
}{
{"status: 429 rate limited", 429},
{"status 401 unauthorized", 401},
- {"HTTP/1.1 502 Bad Gateway", 502},
+ {"http/1.1 502 bad gateway", 502},
+ {"error 429", 429},
{"no status code here", 0},
{"random number 12345", 0},
}
diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go
index 11af14da4..5b3e42b9e 100644
--- a/pkg/providers/factory.go
+++ b/pkg/providers/factory.go
@@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
+ case "litellm":
+ if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
+ sel.apiKey = cfg.Providers.LiteLLM.APIKey
+ sel.apiBase = cfg.Providers.LiteLLM.APIBase
+ sel.proxy = cfg.Providers.LiteLLM.Proxy
+ if sel.apiBase == "" {
+ sel.apiBase = "http://localhost:4000/v1"
+ }
+ }
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go
index 7d5566eef..155317a3b 100644
--- a/pkg/providers/factory_provider.go
+++ b/pkg/providers/factory_provider.go
@@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) {
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
-// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
+// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot
// Returns the provider, the model ID (without protocol prefix), and any error.
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
if cfg == nil {
@@ -84,9 +84,15 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
- return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
+ return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
+ cfg.APIKey,
+ apiBase,
+ cfg.Proxy,
+ cfg.MaxTokensField,
+ cfg.RequestTimeout,
+ ), modelID, nil
- case "openrouter", "groq", "zhipu", "gemini", "nvidia",
+ case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
@@ -97,7 +103,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if apiBase == "" {
apiBase = getDefaultAPIBase(protocol)
}
- return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
+ return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
+ cfg.APIKey,
+ apiBase,
+ cfg.Proxy,
+ cfg.MaxTokensField,
+ cfg.RequestTimeout,
+ ), modelID, nil
case "anthropic":
if cfg.AuthMethod == "oauth" || cfg.AuthMethod == "token" {
@@ -116,7 +128,13 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for anthropic protocol (model: %s)", cfg.Model)
}
- return NewHTTPProviderWithMaxTokensField(cfg.APIKey, apiBase, cfg.Proxy, cfg.MaxTokensField), modelID, nil
+ return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
+ cfg.APIKey,
+ apiBase,
+ cfg.Proxy,
+ cfg.MaxTokensField,
+ cfg.RequestTimeout,
+ ), modelID, nil
case "antigravity":
return NewAntigravityProvider(), modelID, nil
@@ -162,6 +180,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://api.openai.com/v1"
case "openrouter":
return "https://openrouter.ai/api/v1"
+ case "litellm":
+ return "http://localhost:4000/v1"
case "groq":
return "https://api.groq.com/openai/v1"
case "zhipu":
diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go
index 6b133101a..78389f331 100644
--- a/pkg/providers/factory_provider_test.go
+++ b/pkg/providers/factory_provider_test.go
@@ -6,7 +6,11 @@
package providers
import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
+ "time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -131,6 +135,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
}
}
+func TestGetDefaultAPIBase_LiteLLM(t *testing.T) {
+ if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" {
+ t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1")
+ }
+}
+
+func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "test-litellm",
+ Model: "litellm/my-proxy-alias",
+ APIKey: "test-key",
+ APIBase: "http://localhost:4000/v1",
+ }
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+ if modelID != "my-proxy-alias" {
+ t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias")
+ }
+}
+
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-anthropic",
@@ -247,3 +277,42 @@ func TestCreateProviderFromConfig_EmptyModel(t *testing.T) {
t.Fatal("CreateProviderFromConfig() expected error for empty model")
}
}
+
+func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(1500 * time.Millisecond)
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
+ }))
+ defer server.Close()
+
+ cfg := &config.ModelConfig{
+ ModelName: "test-timeout",
+ Model: "openai/gpt-4o",
+ APIBase: server.URL,
+ RequestTimeout: 1,
+ }
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if modelID != "gpt-4o" {
+ t.Fatalf("modelID = %q, want %q", modelID, "gpt-4o")
+ }
+
+ _, err = provider.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ modelID,
+ nil,
+ )
+ if err == nil {
+ t.Fatal("Chat() expected timeout error, got nil")
+ }
+ errMsg := err.Error()
+ if !strings.Contains(errMsg, "context deadline exceeded") && !strings.Contains(errMsg, "Client.Timeout exceeded") {
+ t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
+ }
+}
diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go
index 5680f23b3..f7a916d9e 100644
--- a/pkg/providers/factory_test.go
+++ b/pkg/providers/factory_test.go
@@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) {
wantProxy string
wantErrSubstr string
}{
+ {
+ name: "explicit litellm provider uses configured base",
+ setup: func(cfg *config.Config) {
+ cfg.Agents.Defaults.Provider = "litellm"
+ cfg.Providers.LiteLLM.APIKey = "litellm-key"
+ cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
+ cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
+ },
+ wantType: providerTypeHTTPCompat,
+ wantAPIBase: "http://localhost:4000/v1",
+ wantProxy: "http://127.0.0.1:7890",
+ },
+ {
+ name: "explicit litellm provider defaults base when only key is configured",
+ setup: func(cfg *config.Config) {
+ cfg.Agents.Defaults.Provider = "litellm"
+ cfg.Providers.LiteLLM.APIKey = "litellm-key"
+ },
+ wantType: providerTypeHTTPCompat,
+ wantAPIBase: "http://localhost:4000/v1",
+ },
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go
index ecd451ec9..7ba563b66 100644
--- a/pkg/providers/fallback.go
+++ b/pkg/providers/fallback.go
@@ -43,11 +43,26 @@ func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
// ResolveCandidates parses model config into a deduplicated candidate list.
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
+ return ResolveCandidatesWithLookup(cfg, defaultProvider, nil)
+}
+
+func ResolveCandidatesWithLookup(
+ cfg ModelConfig,
+ defaultProvider string,
+ lookup func(raw string) (resolved string, ok bool),
+) []FallbackCandidate {
seen := make(map[string]bool)
var candidates []FallbackCandidate
addCandidate := func(raw string) {
- ref := ParseModelRef(raw, defaultProvider)
+ candidateRaw := strings.TrimSpace(raw)
+ if lookup != nil {
+ if resolved, ok := lookup(candidateRaw); ok {
+ candidateRaw = resolved
+ }
+ }
+
+ ref := ParseModelRef(candidateRaw, defaultProvider)
if ref == nil {
return
}
diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go
index ebba054ef..1783ebcb5 100644
--- a/pkg/providers/fallback_test.go
+++ b/pkg/providers/fallback_test.go
@@ -453,6 +453,75 @@ func TestResolveCandidates_EmptyPrimary(t *testing.T) {
}
}
+func TestResolveCandidatesWithLookup_AliasResolvesToNestedModel(t *testing.T) {
+ cfg := ModelConfig{
+ Primary: "step-3.5-flash",
+ Fallbacks: nil,
+ }
+
+ lookup := func(raw string) (string, bool) {
+ if raw == "step-3.5-flash" {
+ return "openrouter/stepfun/step-3.5-flash:free", true
+ }
+ return "", false
+ }
+
+ candidates := ResolveCandidatesWithLookup(cfg, "", lookup)
+ if len(candidates) != 1 {
+ t.Fatalf("candidates = %d, want 1", len(candidates))
+ }
+ if candidates[0].Provider != "openrouter" {
+ t.Fatalf("provider = %q, want openrouter", candidates[0].Provider)
+ }
+ if candidates[0].Model != "stepfun/step-3.5-flash:free" {
+ t.Fatalf("model = %q, want stepfun/step-3.5-flash:free", candidates[0].Model)
+ }
+}
+
+func TestResolveCandidatesWithLookup_DeduplicateAfterLookup(t *testing.T) {
+ cfg := ModelConfig{
+ Primary: "step-3.5-flash",
+ Fallbacks: []string{"openrouter/stepfun/step-3.5-flash:free"},
+ }
+
+ lookup := func(raw string) (string, bool) {
+ if raw == "step-3.5-flash" {
+ return "openrouter/stepfun/step-3.5-flash:free", true
+ }
+ return "", false
+ }
+
+ candidates := ResolveCandidatesWithLookup(cfg, "", lookup)
+ if len(candidates) != 1 {
+ t.Fatalf("candidates = %d, want 1", len(candidates))
+ }
+}
+
+func TestResolveCandidatesWithLookup_AliasWithoutProtocolUsesDefaultProvider(t *testing.T) {
+ cfg := ModelConfig{
+ Primary: "glm-5",
+ Fallbacks: nil,
+ }
+
+ lookup := func(raw string) (string, bool) {
+ if raw == "glm-5" {
+ return "glm-5", true
+ }
+ return "", false
+ }
+
+ candidates := ResolveCandidatesWithLookup(cfg, "openai", lookup)
+ if len(candidates) != 1 {
+ t.Fatalf("candidates = %d, want 1", len(candidates))
+ }
+ if candidates[0].Provider != "openai" {
+ t.Fatalf("provider = %q, want openai", candidates[0].Provider)
+ }
+ if candidates[0].Model != "glm-5" {
+ t.Fatalf("model = %q, want glm-5", candidates[0].Model)
+ }
+}
+
func TestFallbackExhaustedError_Message(t *testing.T) {
e := &FallbackExhaustedError{
Attempts: []FallbackAttempt{
diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go
index 3fb15db2f..6d642b2b5 100644
--- a/pkg/providers/github_copilot_provider.go
+++ b/pkg/providers/github_copilot_provider.go
@@ -26,8 +26,9 @@ func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*Gi
switch connectMode {
case "stdio":
- // TODO:
- return nil, fmt.Errorf("stdio mode not implemented")
+ // TODO: Implement stdio mode for GitHub Copilot provider
+ // See https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md for details
+ return nil, fmt.Errorf("stdio mode not implemented for GitHub Copilot provider; please use 'grpc' mode instead")
case "grpc":
client := copilot.NewClient(&copilot.ClientOptions{
CLIUrl: uri,
@@ -100,9 +101,12 @@ func (p *GitHubCopilotProvider) Chat(
return nil, fmt.Errorf("provider closed")
}
- resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
+ resp, err := session.SendAndWait(ctx, copilot.MessageOptions{
Prompt: string(fullcontent),
})
+ if err != nil {
+ return nil, fmt.Errorf("failed to send message to copilot: %w", err)
+ }
if resp == nil {
return nil, fmt.Errorf("empty response from copilot")
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index d0c4344f3..5c328f418 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -8,6 +8,7 @@ package providers
import (
"context"
+ "time"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
@@ -23,8 +24,21 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
+ return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0)
+}
+
+func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
+ apiKey, apiBase, proxy, maxTokensField string,
+ requestTimeoutSeconds int,
+) *HTTPProvider {
return &HTTPProvider{
- delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField),
+ delegate: openai_compat.NewProvider(
+ apiKey,
+ apiBase,
+ proxy,
+ openai_compat.WithMaxTokensField(maxTokensField),
+ openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
+ ),
}
}
diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go
index 23f137538..26905159f 100644
--- a/pkg/providers/legacy_provider.go
+++ b/pkg/providers/legacy_provider.go
@@ -18,9 +18,21 @@ import (
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
model := cfg.Agents.Defaults.GetModelName()
- // Ensure model_list is populated (should be done by LoadConfig, but handle edge cases)
- if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {
- cfg.ModelList = config.ConvertProvidersToModelList(cfg)
+ // Ensure model_list is populated from providers config if needed
+ // This handles two cases:
+ // 1. ModelList is empty - convert all providers
+ // 2. ModelList has some entries but not all providers - merge missing ones
+ if cfg.HasProvidersConfig() {
+ providerModels := config.ConvertProvidersToModelList(cfg)
+ existingModelNames := make(map[string]bool)
+ for _, m := range cfg.ModelList {
+ existingModelNames[m.ModelName] = true
+ }
+ for _, pm := range providerModels {
+ if !existingModelNames[pm.ModelName] {
+ cfg.ModelList = append(cfg.ModelList, pm)
+ }
+ }
}
// Must have model_list at this point
diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index 087d3506e..ff9109e96 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -25,6 +25,7 @@ type (
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
+ ReasoningDetail = protocoltypes.ReasoningDetail
)
type Provider struct {
@@ -34,13 +35,27 @@ type Provider struct {
httpClient *http.Client
}
-func NewProvider(apiKey, apiBase, proxy string) *Provider {
- return NewProviderWithMaxTokensField(apiKey, apiBase, proxy, "")
+type Option func(*Provider)
+
+const defaultRequestTimeout = 120 * time.Second
+
+func WithMaxTokensField(maxTokensField string) Option {
+ return func(p *Provider) {
+ p.maxTokensField = maxTokensField
+ }
}
-func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
+func WithRequestTimeout(timeout time.Duration) Option {
+ return func(p *Provider) {
+ if timeout > 0 {
+ p.httpClient.Timeout = timeout
+ }
+ }
+}
+
+func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
- Timeout: 120 * time.Second,
+ Timeout: defaultRequestTimeout,
}
if proxy != "" {
@@ -54,12 +69,36 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string
}
}
- return &Provider{
- apiKey: apiKey,
- apiBase: strings.TrimRight(apiBase, "/"),
- maxTokensField: maxTokensField,
- httpClient: client,
+ p := &Provider{
+ apiKey: apiKey,
+ apiBase: strings.TrimRight(apiBase, "/"),
+ httpClient: client,
}
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(p)
+ }
+ }
+
+ return p
+}
+
+func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *Provider {
+ return NewProvider(apiKey, apiBase, proxy, WithMaxTokensField(maxTokensField))
+}
+
+func NewProviderWithMaxTokensFieldAndTimeout(
+ apiKey, apiBase, proxy, maxTokensField string,
+ requestTimeoutSeconds int,
+) *Provider {
+ return NewProvider(
+ apiKey,
+ apiBase,
+ proxy,
+ WithMaxTokensField(maxTokensField),
+ WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
+ )
}
func (p *Provider) Chat(
@@ -77,7 +116,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
- "messages": stripSystemParts(messages),
+ "messages": serializeMessages(messages),
}
if len(tools) > 0 {
@@ -160,8 +199,10 @@ func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
- Content string `json:"content"`
- ReasoningContent string `json:"reasoning_content"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content"`
+ Reasoning string `json:"reasoning"`
+ ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
@@ -236,6 +277,8 @@ func parseResponse(body []byte) (*LLMResponse, error) {
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
+ Reasoning: choice.Message.Reasoning,
+ ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
@@ -246,31 +289,69 @@ func parseResponse(body []byte) (*LLMResponse, error) {
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
}
-// stripSystemParts converts []Message to []openaiMessage, dropping the
-// SystemParts field so it doesn't leak into the JSON payload sent to
-// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
-func stripSystemParts(messages []Message) []openaiMessage {
- out := make([]openaiMessage, len(messages))
- for i, m := range messages {
- out[i] = openaiMessage{
- Role: m.Role,
- Content: m.Content,
- ToolCalls: m.ToolCalls,
- ToolCallID: m.ToolCallID,
+// serializeMessages converts internal Message structs to the OpenAI wire format.
+// - Strips SystemParts (unknown to third-party endpoints)
+// - Converts messages with Media to multipart content format (text + image_url parts)
+// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
+func serializeMessages(messages []Message) []any {
+ out := make([]any, 0, len(messages))
+ for _, m := range messages {
+ if len(m.Media) == 0 {
+ out = append(out, openaiMessage{
+ Role: m.Role,
+ Content: m.Content,
+ ReasoningContent: m.ReasoningContent,
+ ToolCalls: m.ToolCalls,
+ ToolCallID: m.ToolCallID,
+ })
+ continue
}
+
+ // Multipart content format for messages with media
+ parts := make([]map[string]any, 0, 1+len(m.Media))
+ if m.Content != "" {
+ parts = append(parts, map[string]any{
+ "type": "text",
+ "text": m.Content,
+ })
+ }
+ for _, mediaURL := range m.Media {
+ parts = append(parts, map[string]any{
+ "type": "image_url",
+ "image_url": map[string]any{
+ "url": mediaURL,
+ },
+ })
+ }
+
+ msg := map[string]any{
+ "role": m.Role,
+ "content": parts,
+ }
+ if m.ToolCallID != "" {
+ msg["tool_call_id"] = m.ToolCallID
+ }
+ if len(m.ToolCalls) > 0 {
+ msg["tool_calls"] = m.ToolCalls
+ }
+ if m.ReasoningContent != "" {
+ msg["reasoning_content"] = m.ReasoningContent
+ }
+ out = append(out, msg)
}
return out
}
func normalizeModel(model, apiBase string) string {
- idx := strings.Index(model, "/")
- if idx == -1 {
+ before, after, ok := strings.Cut(model, "/")
+ if !ok {
return model
}
@@ -278,10 +359,10 @@ func normalizeModel(model, apiBase string) string {
return model
}
- prefix := strings.ToLower(model[:idx])
+ prefix := strings.ToLower(before)
switch prefix {
- case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
- return model[idx+1:]
+ case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
+ return after
default:
return model
}
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 594a48213..174bcf00d 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -5,7 +5,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
@@ -145,6 +149,56 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) {
}
}
+func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) {
+ var requestBody map[string]any
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{"content": "ok"},
+ "finish_reason": "stop",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer server.Close()
+
+ p := NewProvider("key", server.URL, "")
+
+ // Simulate a multi-turn conversation where the assistant's previous
+ // reply included reasoning_content (e.g. from kimi-k2.5).
+ messages := []Message{
+ {Role: "user", Content: "What is 1+1?"},
+ {Role: "assistant", Content: "2", ReasoningContent: "Let me think... 1+1=2"},
+ {Role: "user", Content: "What about 2+2?"},
+ }
+
+ _, err := p.Chat(t.Context(), messages, nil, "kimi-k2.5", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ // Verify reasoning_content is preserved in the serialized request.
+ reqMessages, ok := requestBody["messages"].([]any)
+ if !ok {
+ t.Fatalf("messages is not []any: %T", requestBody["messages"])
+ }
+ assistantMsg, ok := reqMessages[1].(map[string]any)
+ if !ok {
+ t.Fatalf("assistant message is not map[string]any: %T", reqMessages[1])
+ }
+ if assistantMsg["reasoning_content"] != "Let me think... 1+1=2" {
+ t.Errorf("reasoning_content not preserved in request, got %v", assistantMsg["reasoning_content"])
+ }
+}
+
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
@@ -205,6 +259,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
input string
wantModel string
}{
+ {
+ name: "strips litellm prefix and preserves proxy model name",
+ input: "litellm/my-proxy-alias",
+ wantModel: "my-proxy-alias",
+ },
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
@@ -325,3 +384,132 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
+
+func TestProvider_RequestTimeoutDefault(t *testing.T) {
+ p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 0)
+ if p.httpClient.Timeout != defaultRequestTimeout {
+ t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
+ }
+}
+
+func TestProvider_RequestTimeoutOverride(t *testing.T) {
+ p := NewProviderWithMaxTokensFieldAndTimeout("key", "https://example.com/v1", "", "", 300)
+ if p.httpClient.Timeout != 300*time.Second {
+ t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
+ }
+}
+
+func TestProvider_FunctionalOptionMaxTokensField(t *testing.T) {
+ p := NewProvider("key", "https://example.com/v1", "", WithMaxTokensField("max_completion_tokens"))
+ if p.maxTokensField != "max_completion_tokens" {
+ t.Fatalf("maxTokensField = %q, want %q", p.maxTokensField, "max_completion_tokens")
+ }
+}
+
+func TestProvider_FunctionalOptionRequestTimeout(t *testing.T) {
+ p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(45*time.Second))
+ if p.httpClient.Timeout != 45*time.Second {
+ t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, 45*time.Second)
+ }
+}
+
+func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) {
+ p := NewProvider("key", "https://example.com/v1", "", WithRequestTimeout(-1*time.Second))
+ if p.httpClient.Timeout != defaultRequestTimeout {
+ t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
+ }
+}
+
+func TestSerializeMessages_PlainText(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
+ }
+ result := serializeMessages(messages)
+
+ data, err := json.Marshal(result)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["content"] != "hello" {
+ t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
+ }
+ if msgs[1]["reasoning_content"] != "thinking..." {
+ t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
+ }
+}
+
+func TestSerializeMessages_WithMedia(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
+ }
+ result := serializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ content, ok := msgs[0]["content"].([]any)
+ if !ok {
+ t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
+ }
+ if len(content) != 2 {
+ t.Fatalf("expected 2 content parts, got %d", len(content))
+ }
+
+ textPart := content[0].(map[string]any)
+ if textPart["type"] != "text" || textPart["text"] != "describe this" {
+ t.Fatalf("text part mismatch: %v", textPart)
+ }
+
+ imgPart := content[1].(map[string]any)
+ if imgPart["type"] != "image_url" {
+ t.Fatalf("expected image_url type, got %v", imgPart["type"])
+ }
+ imgURL := imgPart["image_url"].(map[string]any)
+ if imgURL["url"] != "data:image/png;base64,abc123" {
+ t.Fatalf("image url mismatch: %v", imgURL["url"])
+ }
+}
+
+func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
+ }
+ result := serializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ var msgs []map[string]any
+ json.Unmarshal(data, &msgs)
+
+ if msgs[0]["tool_call_id"] != "call_1" {
+ t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
+ }
+ // Content should be multipart array
+ if _, ok := msgs[0]["content"].([]any); !ok {
+ t.Fatalf("expected array content, got %T", msgs[0]["content"])
+ }
+}
+
+func TestSerializeMessages_StripsSystemParts(t *testing.T) {
+ messages := []protocoltypes.Message{
+ {
+ Role: "system",
+ Content: "you are helpful",
+ SystemParts: []protocoltypes.ContentBlock{
+ {Type: "text", Text: "you are helpful"},
+ },
+ },
+ }
+ result := serializeMessages(messages)
+
+ data, _ := json.Marshal(result)
+ raw := string(data)
+ if strings.Contains(raw, "system_parts") {
+ t.Fatal("system_parts should not appear in serialized output")
+ }
+}
diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go
index 33f052c5a..194c1aa6f 100644
--- a/pkg/providers/protocoltypes/types.go
+++ b/pkg/providers/protocoltypes/types.go
@@ -25,11 +25,20 @@ type FunctionCall struct {
}
type LLMResponse struct {
- Content string `json:"content"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- FinishReason string `json:"finish_reason"`
- Usage *UsageInfo `json:"usage,omitempty"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ FinishReason string `json:"finish_reason"`
+ Usage *UsageInfo `json:"usage,omitempty"`
+ Reasoning string `json:"reasoning"`
+ ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
+}
+
+type ReasoningDetail struct {
+ Format string `json:"format"`
+ Index int `json:"index"`
+ Type string `json:"type"`
+ Text string `json:"text"`
}
type UsageInfo struct {
@@ -56,6 +65,7 @@ type ContentBlock struct {
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
+ Media []string `json:"media,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go
index 49218b1b1..a33e1eb5c 100644
--- a/pkg/providers/toolcall_utils.go
+++ b/pkg/providers/toolcall_utils.go
@@ -5,7 +5,43 @@
package providers
-import "encoding/json"
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt.
+func buildCLIToolsPrompt(tools []ToolDefinition) string {
+ var sb strings.Builder
+
+ sb.WriteString("## Available Tools\n\n")
+ sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
+ sb.WriteString("```json\n")
+ sb.WriteString(
+ `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
+ )
+ sb.WriteString("\n```\n\n")
+ sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
+ sb.WriteString("### Tool Definitions:\n\n")
+
+ for _, tool := range tools {
+ if tool.Type != "function" {
+ continue
+ }
+ sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
+ if tool.Function.Description != "" {
+ sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
+ }
+ if len(tool.Function.Parameters) > 0 {
+ paramsJSON, _ := json.Marshal(tool.Function.Parameters)
+ sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
+ }
+ sb.WriteString("\n")
+ }
+
+ return sb.String()
+}
// NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated.
// It handles cases where Name/Arguments might be in different locations (top-level vs Function)
diff --git a/pkg/routing/agent_id_test.go b/pkg/routing/agent_id_test.go
index 050fe0645..f9a65c969 100644
--- a/pkg/routing/agent_id_test.go
+++ b/pkg/routing/agent_id_test.go
@@ -1,6 +1,9 @@
package routing
-import "testing"
+import (
+ "strings"
+ "testing"
+)
func TestNormalizeAgentID_Empty(t *testing.T) {
if got := NormalizeAgentID(""); got != DefaultAgentID {
@@ -57,11 +60,11 @@ func TestNormalizeAgentID_AllInvalid(t *testing.T) {
}
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
- long := ""
- for i := 0; i < 100; i++ {
- long += "a"
+ var long strings.Builder
+ for range 100 {
+ long.WriteString("a")
}
- got := NormalizeAgentID(long)
+ got := NormalizeAgentID(long.String())
if len(got) > MaxAgentIDLength {
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
}
diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go
index e12f0d1d8..eab592bec 100644
--- a/pkg/routing/session_key.go
+++ b/pkg/routing/session_key.go
@@ -163,6 +163,15 @@ func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID stri
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
candidates[scopedCandidate] = true
}
+
+ // If peerID is already in canonical "platform:id" format, also add the
+ // bare ID part as a candidate for backward compatibility with identity_links
+ // that use raw IDs (e.g. "123" instead of "telegram:123").
+ if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
+ bareID := rawCandidate[idx+1:]
+ candidates[bareID] = true
+ }
+
if len(candidates) == 0 {
return ""
}
diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go
index 81e4ce018..ad7a1ca02 100644
--- a/pkg/routing/session_key_test.go
+++ b/pkg/routing/session_key_test.go
@@ -115,6 +115,51 @@ func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
}
}
+func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) {
+ // When peerID is already in canonical "platform:id" format,
+ // it should match identity_links that use the bare ID.
+ links := map[string][]string{
+ "john": {"123"},
+ }
+ got := resolveLinkedPeerID(links, "telegram", "telegram:123")
+ if got != "john" {
+ t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john")
+ }
+}
+
+func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) {
+ // When identity_links contain canonical IDs and peerID is canonical too
+ links := map[string][]string{
+ "john": {"telegram:123", "discord:456"},
+ }
+ got := resolveLinkedPeerID(links, "telegram", "telegram:123")
+ if got != "john" {
+ t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john")
+ }
+}
+
+func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) {
+ // When peerID is bare "123" and links have "telegram:123",
+ // the scoped candidate "telegram:123" should match.
+ links := map[string][]string{
+ "john": {"telegram:123"},
+ }
+ got := resolveLinkedPeerID(links, "telegram", "123")
+ if got != "john" {
+ t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john")
+ }
+}
+
+func TestResolveLinkedPeerID_NoMatch(t *testing.T) {
+ links := map[string][]string{
+ "john": {"telegram:123"},
+ }
+ got := resolveLinkedPeerID(links, "discord", "999")
+ if got != "" {
+ t.Errorf("resolveLinkedPeerID no match = %q, want empty", got)
+ }
+}
+
func TestParseAgentSessionKey_Valid(t *testing.T) {
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
if parsed == nil {
diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go
index 3210509df..c9f19f25d 100644
--- a/pkg/skills/installer.go
+++ b/pkg/skills/installer.go
@@ -2,27 +2,21 @@ package skills
import (
"context"
- "encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
+ "github.com/sipeed/picoclaw/pkg/utils"
)
type SkillInstaller struct {
workspace string
}
-type AvailableSkill struct {
- Name string `json:"name"`
- Repository string `json:"repository"`
- Description string `json:"description"`
- Author string `json:"author"`
- Tags []string `json:"tags"`
-}
-
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
@@ -44,7 +38,7 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er
return fmt.Errorf("failed to create request: %w", err)
}
- resp, err := client.Do(req)
+ resp, err := utils.DoRequestWithRetry(client, req)
if err != nil {
return fmt.Errorf("failed to fetch skill: %w", err)
}
@@ -64,7 +58,9 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er
}
skillPath := filepath.Join(skillDir, "SKILL.md")
- if err := os.WriteFile(skillPath, body, 0o644); err != nil {
+
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ if err := fileutil.WriteFileAtomic(skillPath, body, 0o600); err != nil {
return fmt.Errorf("failed to write skill file: %w", err)
}
@@ -84,35 +80,3 @@ func (si *SkillInstaller) Uninstall(skillName string) error {
return nil
}
-
-func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableSkill, error) {
- url := "https://raw.githubusercontent.com/sipeed/picoclaw-skills/main/skills.json"
-
- client := &http.Client{Timeout: 15 * time.Second}
- req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- resp, err := client.Do(req)
- if err != nil {
- return nil, fmt.Errorf("failed to fetch skills list: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- return nil, fmt.Errorf("failed to fetch skills list: HTTP %d", resp.StatusCode)
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("failed to read response: %w", err)
- }
-
- var skills []AvailableSkill
- if err := json.Unmarshal(body, &skills); err != nil {
- return nil, fmt.Errorf("failed to parse skills list: %w", err)
- }
-
- return skills, nil
-}
diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go
index 5749d8983..30d84635a 100644
--- a/pkg/skills/loader.go
+++ b/pkg/skills/loader.go
@@ -13,7 +13,11 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
)
-var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
+var (
+ namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
+ reFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`)
+ reStripFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
+)
const (
MaxNameLength = 64
@@ -60,6 +64,29 @@ type SkillsLoader struct {
builtinSkills string // builtin skills
}
+// SkillRoots returns all unique skill root directories used by this loader.
+// The order follows resolution priority: workspace > global > builtin.
+func (sl *SkillsLoader) SkillRoots() []string {
+ roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills}
+ seen := make(map[string]struct{}, len(roots))
+ out := make([]string, 0, len(roots))
+
+ for _, root := range roots {
+ trimmed := strings.TrimSpace(root)
+ if trimmed == "" {
+ continue
+ }
+ clean := filepath.Clean(trimmed)
+ if _, ok := seen[clean]; ok {
+ continue
+ }
+ seen[clean] = struct{}{}
+ out = append(out, clean)
+ }
+
+ return out
+}
+
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
return &SkillsLoader{
workspace: workspace,
@@ -236,7 +263,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
normalized := strings.ReplaceAll(content, "\r\n", "\n")
normalized = strings.ReplaceAll(normalized, "\r", "\n")
- for _, line := range strings.Split(normalized, "\n") {
+ for line := range strings.SplitSeq(normalized, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -257,10 +284,7 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
func (sl *SkillsLoader) extractFrontmatter(content string) string {
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
- // (?s) enables DOTALL so . matches newlines;
- // ^--- at start, then ... --- at start of line, honoring all three line ending types
- re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`)
- match := re.FindStringSubmatch(content)
+ match := reFrontmatter.FindStringSubmatch(content)
if len(match) > 1 {
return match[1]
}
@@ -268,12 +292,7 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string {
}
func (sl *SkillsLoader) stripFrontmatter(content string) string {
- // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
- // (?s) enables DOTALL so . matches newlines;
- // ^--- at start, then ... --- at start of line, honoring all three line ending types
- // Match zero or more trailing line endings after closing --- (handles both with and without blank lines)
- re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
- return re.ReplaceAllString(content, "")
+ return reStripFrontmatter.ReplaceAllString(content, "")
}
func escapeXML(s string) string {
diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go
index 9428bea62..31619f9c2 100644
--- a/pkg/skills/loader_test.go
+++ b/pkg/skills/loader_test.go
@@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) {
})
}
}
+
+func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) {
+ tmp := t.TempDir()
+ workspace := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+ builtin := filepath.Join(tmp, "builtin")
+
+ sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n")
+ roots := sl.SkillRoots()
+
+ assert.Equal(t, []string{
+ filepath.Join(workspace, "skills"),
+ global,
+ builtin,
+ }, roots)
+}
diff --git a/pkg/skills/search_cache.go b/pkg/skills/search_cache.go
index 5d7d2797e..1686e3f98 100644
--- a/pkg/skills/search_cache.go
+++ b/pkg/skills/search_cache.go
@@ -1,7 +1,7 @@
package skills
import (
- "sort"
+ "slices"
"strings"
"sync"
"time"
@@ -183,7 +183,7 @@ func buildTrigrams(s string) []uint32 {
}
// Sort and Deduplication
- sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
+ slices.Sort(trigrams)
n := 1
for i := 1; i < len(trigrams); i++ {
if trigrams[i] != trigrams[i-1] {
diff --git a/pkg/skills/search_cache_test.go b/pkg/skills/search_cache_test.go
index 816bdfb93..6bbb0e6eb 100644
--- a/pkg/skills/search_cache_test.go
+++ b/pkg/skills/search_cache_test.go
@@ -153,7 +153,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent writes
go func() {
- for i := 0; i < 100; i++ {
+ for i := range 100 {
cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
}
done <- struct{}{}
@@ -161,7 +161,7 @@ func TestSearchCacheConcurrency(t *testing.T) {
// Concurrent reads
go func() {
- for i := 0; i < 100; i++ {
+ for range 100 {
cache.Get("query-write-a")
}
done <- struct{}{}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 1a92f82ed..57f371f12 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -8,6 +8,8 @@ import (
"path/filepath"
"sync"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
// State represents the persistent state for a workspace.
@@ -38,7 +40,9 @@ func NewManager(workspace string) *Manager {
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
- os.MkdirAll(stateDir, 0o755)
+ if err := os.MkdirAll(stateDir, 0o755); err != nil {
+ log.Fatalf("[FATAL] state: failed to create state directory: %v", err)
+ }
sm := &Manager{
workspace: workspace,
@@ -52,13 +56,17 @@ func NewManager(workspace string) *Manager {
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
- sm.saveAtomic()
+ if err := sm.saveAtomic(); err != nil {
+ log.Printf("[WARN] state: failed to save state: %v", err)
+ }
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
- sm.load()
+ if err := sm.load(); err != nil {
+ log.Printf("[WARN] state: failed to load state: %v", err)
+ }
}
return sm
@@ -124,33 +132,20 @@ func (sm *Manager) GetTimestamp() time.Time {
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
-// 2. Rename temp file to target (atomic on POSIX systems)
-// 3. If rename fails, cleanup the temp file
+// 2. Sync to disk (critical for SD cards/flash storage)
+// 3. Rename temp file to target (atomic on POSIX systems)
+// 4. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
- // Create temp file in the same directory as the target
- tempFile := sm.stateFile + ".tmp"
-
- // Marshal state to JSON
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ // Using 0o600 (owner read/write only) for secure default permissions.
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
- // Write to temp file
- if err := os.WriteFile(tempFile, data, 0o644); err != nil {
- return fmt.Errorf("failed to write temp file: %w", err)
- }
-
- // Atomic rename from temp to target
- if err := os.Rename(tempFile, sm.stateFile); err != nil {
- // Cleanup temp file if rename fails
- os.Remove(tempFile)
- return fmt.Errorf("failed to rename temp file: %w", err)
- }
-
- return nil
+ return fileutil.WriteFileAtomic(sm.stateFile, data, 0o600)
}
// load loads the state from disk.
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index f717a5bb4..e5e116ef6 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -2,8 +2,10 @@ package state
import (
"encoding/json"
+ "errors"
"fmt"
"os"
+ "os/exec"
"path/filepath"
"testing"
)
@@ -135,7 +137,7 @@ func TestConcurrentAccess(t *testing.T) {
// Test concurrent writes
done := make(chan bool, 10)
- for i := 0; i < 10; i++ {
+ for i := range 10 {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
@@ -144,7 +146,7 @@ func TestConcurrentAccess(t *testing.T) {
}
// Wait for all goroutines to complete
- for i := 0; i < 10; i++ {
+ for range 10 {
<-done
}
@@ -214,3 +216,39 @@ func TestNewManager_EmptyWorkspace(t *testing.T) {
t.Error("Expected zero timestamp for new state")
}
}
+
+func TestNewManager_MkdirFailureCrashes(t *testing.T) {
+ // Since log.Fatalf calls os.Exit(1), we cannot test it normally
+ // Otherwise, the test suite would stop altogether.
+ // We use the standard pattern of Go: rerun this test in a subprocess.
+ if os.Getenv("BE_CRASHER") == "1" {
+ tmpDir := os.Getenv("CRASH_DIR")
+
+ statePath := filepath.Join(tmpDir, "state")
+ if err := os.WriteFile(statePath, []byte("I'm a file, not a folder"), 0o644); err != nil {
+ fmt.Printf("setup failed: %v", err)
+ os.Exit(0)
+ }
+
+ NewManager(tmpDir)
+ os.Exit(0)
+ }
+
+ tmpDir, err := os.MkdirTemp("", "state-crash-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes")
+ cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir)
+
+ err = cmd.Run()
+
+ var e *exec.ExitError
+ if errors.As(err, &e) && !e.Success() {
+ return
+ }
+
+ t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). Err: %v", err)
+}
diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go
index 562fffc84..6888d1326 100644
--- a/pkg/tools/cron.go
+++ b/pkg/tools/cron.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "strings"
"sync"
"time"
@@ -33,15 +34,19 @@ type CronTool struct {
func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
-) *CronTool {
- execTool := NewExecToolWithConfig(workspace, restrict, config)
+) (*CronTool, error) {
+ execTool, err := NewExecToolWithConfig(workspace, restrict, config)
+ if err != nil {
+ return nil, fmt.Errorf("unable to configure exec tool: %w", err)
+ }
+
execTool.SetTimeout(execTimeout)
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
- }
+ }, nil
}
// Name returns the tool name
@@ -218,7 +223,8 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult("No scheduled jobs")
}
- result := "Scheduled jobs:\n"
+ var result strings.Builder
+ result.WriteString("Scheduled jobs:\n")
for _, j := range jobs {
var scheduleInfo string
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
@@ -230,10 +236,10 @@ func (t *CronTool) listJobs() *ToolResult {
} else {
scheduleInfo = "unknown"
}
- result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
+ result.WriteString(fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo))
}
- return SilentResult(result)
+ return SilentResult(result.String())
}
func (t *CronTool) removeJob(args map[string]any) *ToolResult {
@@ -294,7 +300,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
}
- t.msgBus.PublishOutbound(bus.OutboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
@@ -304,7 +312,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// If deliver=true, send message directly without agent processing
if job.Payload.Deliver {
- t.msgBus.PublishOutbound(bus.OutboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: job.Payload.Message,
diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go
index d3ab267bf..d5bebf4a2 100644
--- a/pkg/tools/edit.go
+++ b/pkg/tools/edit.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io/fs"
+ "regexp"
"strings"
)
@@ -15,14 +16,12 @@ type EditFileTool struct {
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
-func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewEditFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *EditFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &EditFileTool{fs: fs}
+ return &EditFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *EditFileTool) Name() string {
@@ -80,14 +79,12 @@ type AppendFileTool struct {
fs fileSystem
}
-func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewAppendFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *AppendFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &AppendFileTool{fs: fs}
+ return &AppendFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *AppendFileTool) Name() string {
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go
index 37db8b4ae..cd8da3195 100644
--- a/pkg/tools/filesystem.go
+++ b/pkg/tools/filesystem.go
@@ -6,8 +6,11 @@ import (
"io/fs"
"os"
"path/filepath"
+ "regexp"
"strings"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/fileutil"
)
// validatePath ensures the given path is within the workspace if restrict is true.
@@ -85,14 +88,12 @@ type ReadFileTool struct {
fs fileSystem
}
-func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewReadFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ReadFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &ReadFileTool{fs: fs}
+ return &ReadFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ReadFileTool) Name() string {
@@ -133,14 +134,12 @@ type WriteFileTool struct {
fs fileSystem
}
-func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &WriteFileTool{fs: fs}
+ return &WriteFileTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *WriteFileTool) Name() string {
@@ -190,14 +189,12 @@ type ListDirTool struct {
fs fileSystem
}
-func NewListDirTool(workspace string, restrict bool) *ListDirTool {
- var fs fileSystem
- if restrict {
- fs = &sandboxFs{workspace: workspace}
- } else {
- fs = &hostFs{}
+func NewListDirTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *ListDirTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
}
- return &ListDirTool{fs: fs}
+ return &ListDirTool{fs: buildFs(workspace, restrict, patterns)}
}
func (t *ListDirTool) Name() string {
@@ -276,25 +273,9 @@ func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) {
}
func (h *hostFs) WriteFile(path string, data []byte) error {
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0o755); err != nil {
- return fmt.Errorf("failed to create parent directories: %w", err)
- }
-
- // We use a "write-then-rename" pattern here to ensure an atomic write.
- // This prevents the target file from being left in a truncated or partial state
- // if the operation is interrupted, as the rename operation is atomic on Linux.
- tmpPath := fmt.Sprintf("%s.%d.tmp", path, time.Now().UnixNano())
- if err := os.WriteFile(tmpPath, data, 0o644); err != nil {
- os.Remove(tmpPath) // Ensure cleanup of partial/empty temp file
- return fmt.Errorf("failed to write temp file: %w", err)
- }
-
- if err := os.Rename(tmpPath, path); err != nil {
- os.Remove(tmpPath)
- return fmt.Errorf("failed to replace original file: %w", err)
- }
- return nil
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ // Using 0o600 (owner read/write only) for secure default permissions.
+ return fileutil.WriteFileAtomic(path, data, 0o600)
}
// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
@@ -351,20 +332,46 @@ func (r *sandboxFs) WriteFile(path string, data []byte) error {
}
}
- // We use a "write-then-rename" pattern here to ensure an atomic write.
- // This prevents the target file from being left in a truncated or partial state
- // if the operation is interrupted, as the rename operation is atomic on Linux.
- tmpRelPath := fmt.Sprintf("%s.%d.tmp", relPath, time.Now().UnixNano())
+ // Use atomic write pattern with explicit sync for flash storage reliability.
+ // Using 0o600 (owner read/write only) for secure default permissions.
+ tmpRelPath := fmt.Sprintf(".tmp-%d-%d", os.Getpid(), time.Now().UnixNano())
- if err := root.WriteFile(tmpRelPath, data, 0o644); err != nil {
- root.Remove(tmpRelPath) // Ensure cleanup of partial/empty temp file
- return fmt.Errorf("failed to write to temp file: %w", err)
+ tmpFile, err := root.OpenFile(tmpRelPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o600)
+ if err != nil {
+ root.Remove(tmpRelPath)
+ return fmt.Errorf("failed to open temp file: %w", err)
+ }
+
+ if _, err := tmpFile.Write(data); err != nil {
+ tmpFile.Close()
+ root.Remove(tmpRelPath)
+ return fmt.Errorf("failed to write temp file: %w", err)
+ }
+
+ // CRITICAL: Force sync to storage medium before rename.
+ // This ensures data is physically written to disk, not just cached.
+ if err := tmpFile.Sync(); err != nil {
+ tmpFile.Close()
+ root.Remove(tmpRelPath)
+ return fmt.Errorf("failed to sync temp file: %w", err)
+ }
+
+ if err := tmpFile.Close(); err != nil {
+ root.Remove(tmpRelPath)
+ return fmt.Errorf("failed to close temp file: %w", err)
}
if err := root.Rename(tmpRelPath, relPath); err != nil {
root.Remove(tmpRelPath)
return fmt.Errorf("failed to rename temp file over target: %w", err)
}
+
+ // Sync directory to ensure rename is durable
+ if dirFile, err := root.Open("."); err == nil {
+ _ = dirFile.Sync()
+ dirFile.Close()
+ }
+
return nil
})
}
@@ -382,6 +389,57 @@ func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
return entries, err
}
+// whitelistFs wraps a sandboxFs and allows access to specific paths outside
+// the workspace when they match any of the provided patterns.
+type whitelistFs struct {
+ sandbox *sandboxFs
+ host hostFs
+ patterns []*regexp.Regexp
+}
+
+func (w *whitelistFs) matches(path string) bool {
+ for _, p := range w.patterns {
+ if p.MatchString(path) {
+ return true
+ }
+ }
+ return false
+}
+
+func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
+ if w.matches(path) {
+ return w.host.ReadFile(path)
+ }
+ return w.sandbox.ReadFile(path)
+}
+
+func (w *whitelistFs) WriteFile(path string, data []byte) error {
+ if w.matches(path) {
+ return w.host.WriteFile(path, data)
+ }
+ return w.sandbox.WriteFile(path, data)
+}
+
+func (w *whitelistFs) ReadDir(path string) ([]os.DirEntry, error) {
+ if w.matches(path) {
+ return w.host.ReadDir(path)
+ }
+ return w.sandbox.ReadDir(path)
+}
+
+// buildFs returns the appropriate fileSystem implementation based on restriction
+// settings and optional path whitelist patterns.
+func buildFs(workspace string, restrict bool, patterns []*regexp.Regexp) fileSystem {
+ if !restrict {
+ return &hostFs{}
+ }
+ sandbox := &sandboxFs{workspace: workspace}
+ if len(patterns) > 0 {
+ return &whitelistFs{sandbox: sandbox, patterns: patterns}
+ }
+ return sandbox
+}
+
// Helper to get a safe relative path for os.Root usage
func getSafeRelPath(workspace, path string) (string, error) {
if workspace == "" {
diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go
index 6f896e22d..666004cd4 100644
--- a/pkg/tools/filesystem_test.go
+++ b/pkg/tools/filesystem_test.go
@@ -5,6 +5,7 @@ import (
"io"
"os"
"path/filepath"
+ "regexp"
"strings"
"testing"
@@ -486,3 +487,36 @@ func TestRootRW_Write(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, newData, content)
}
+
+// TestWhitelistFs_AllowsMatchingPaths verifies that whitelistFs allows access to
+// paths matching the whitelist patterns while blocking non-matching paths.
+func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
+ workspace := t.TempDir()
+ outsideDir := t.TempDir()
+ outsideFile := filepath.Join(outsideDir, "allowed.txt")
+ os.WriteFile(outsideFile, []byte("outside content"), 0o644)
+
+ // Pattern allows access to the outsideDir.
+ patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(outsideDir))}
+
+ tool := NewReadFileTool(workspace, true, patterns)
+
+ // Read from whitelisted path should succeed.
+ result := tool.Execute(context.Background(), map[string]any{"path": outsideFile})
+ if result.IsError {
+ t.Errorf("expected whitelisted path to be readable, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "outside content") {
+ t.Errorf("expected file content, got: %s", result.ForLLM)
+ }
+
+ // Read from non-whitelisted path outside workspace should fail.
+ otherDir := t.TempDir()
+ otherFile := filepath.Join(otherDir, "blocked.txt")
+ os.WriteFile(otherFile, []byte("blocked"), 0o644)
+
+ result = tool.Execute(context.Background(), map[string]any{"path": otherFile})
+ if !result.IsError {
+ t.Errorf("expected non-whitelisted path to be blocked, got: %s", result.ForLLM)
+ }
+}
diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go
new file mode 100644
index 000000000..6e53cf354
--- /dev/null
+++ b/pkg/tools/mcp_tool.go
@@ -0,0 +1,246 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "hash/fnv"
+ "strings"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// MCPManager defines the interface for MCP manager operations
+// This allows for easier testing with mock implementations
+type MCPManager interface {
+ CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+ ) (*mcp.CallToolResult, error)
+}
+
+// MCPTool wraps an MCP tool to implement the Tool interface
+type MCPTool struct {
+ manager MCPManager
+ serverName string
+ tool *mcp.Tool
+}
+
+// NewMCPTool creates a new MCP tool wrapper
+func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
+ return &MCPTool{
+ manager: manager,
+ serverName: serverName,
+ tool: tool,
+ }
+}
+
+// sanitizeIdentifierComponent normalizes a string so it can be safely used
+// as part of a tool/function identifier for downstream providers.
+// It:
+// - lowercases the string
+// - replaces any character not in [a-z0-9_-] with '_'
+// - collapses multiple consecutive '_' into a single '_'
+// - trims leading/trailing '_'
+// - falls back to "unnamed" if the result is empty
+// - truncates overly long components to a reasonable length
+func sanitizeIdentifierComponent(s string) string {
+ const maxLen = 64
+
+ s = strings.ToLower(s)
+ var b strings.Builder
+ b.Grow(len(s))
+
+ prevUnderscore := false
+ for _, r := range s {
+ isAllowed := (r >= 'a' && r <= 'z') ||
+ (r >= '0' && r <= '9') ||
+ r == '_' || r == '-'
+
+ if !isAllowed {
+ // Normalize any disallowed character to '_'
+ if !prevUnderscore {
+ b.WriteRune('_')
+ prevUnderscore = true
+ }
+ continue
+ }
+
+ if r == '_' {
+ if prevUnderscore {
+ continue
+ }
+ prevUnderscore = true
+ } else {
+ prevUnderscore = false
+ }
+
+ b.WriteRune(r)
+ }
+
+ result := strings.Trim(b.String(), "_")
+ if result == "" {
+ result = "unnamed"
+ }
+
+ if len(result) > maxLen {
+ result = result[:maxLen]
+ }
+
+ return result
+}
+
+// Name returns the tool name, prefixed with the server name.
+// The total length is capped at 64 characters (OpenAI-compatible API limit).
+// A short hash of the original (unsanitized) server and tool names is appended
+// whenever sanitization is lossy or the name is truncated, ensuring that two
+// names which differ only in disallowed characters remain distinct after sanitization.
+func (t *MCPTool) Name() string {
+ // Prefix with server name to avoid conflicts, and sanitize components
+ sanitizedServer := sanitizeIdentifierComponent(t.serverName)
+ sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
+ full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
+
+ // Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
+ lossless := strings.ToLower(t.serverName) == sanitizedServer &&
+ strings.ToLower(t.tool.Name) == sanitizedTool
+
+ const maxTotal = 64
+ if lossless && len(full) <= maxTotal {
+ return full
+ }
+
+ // Sanitization was lossy or name too long: append hash of the ORIGINAL names
+ // (not the sanitized names) so different originals always yield different hashes.
+ h := fnv.New32a()
+ _, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
+ suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
+
+ base := full
+ if len(base) > maxTotal-9 {
+ base = strings.TrimRight(full[:maxTotal-9], "_")
+ }
+ return base + "_" + suffix
+}
+
+// Description returns the tool description
+func (t *MCPTool) Description() string {
+ desc := t.tool.Description
+ if desc == "" {
+ desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
+ }
+ // Add server info to description
+ return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
+}
+
+// Parameters returns the tool parameters schema
+func (t *MCPTool) Parameters() map[string]any {
+ // The InputSchema is already a JSON Schema object
+ schema := t.tool.InputSchema
+
+ // Handle nil schema
+ if schema == nil {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ // Try direct conversion first (fast path)
+ if schemaMap, ok := schema.(map[string]any); ok {
+ return schemaMap
+ }
+
+ // Handle json.RawMessage and []byte - unmarshal directly
+ var jsonData []byte
+ if rawMsg, ok := schema.(json.RawMessage); ok {
+ jsonData = rawMsg
+ } else if bytes, ok := schema.([]byte); ok {
+ jsonData = bytes
+ }
+
+ if jsonData != nil {
+ var result map[string]any
+ if err := json.Unmarshal(jsonData, &result); err == nil {
+ return result
+ }
+ // Fallback on error
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ // For other types (structs, etc.), convert via JSON marshal/unmarshal
+ var err error
+ jsonData, err = json.Marshal(schema)
+ if err != nil {
+ // Fallback to empty schema if marshaling fails
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ var result map[string]any
+ if err := json.Unmarshal(jsonData, &result); err != nil {
+ // Fallback to empty schema if unmarshaling fails
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ "required": []string{},
+ }
+ }
+
+ return result
+}
+
+// Execute executes the MCP tool
+func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
+ }
+
+ if result == nil {
+ nilErr := fmt.Errorf("MCP tool returned nil result without error")
+ return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
+ }
+
+ // Handle error result from server
+ if result.IsError {
+ errMsg := extractContentText(result.Content)
+ return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
+ WithError(fmt.Errorf("MCP tool error: %s", errMsg))
+ }
+
+ // Extract text content from result
+ output := extractContentText(result.Content)
+
+ return &ToolResult{
+ ForLLM: output,
+ IsError: false,
+ }
+}
+
+// extractContentText extracts text from MCP content array
+func extractContentText(content []mcp.Content) string {
+ var parts []string
+ for _, c := range content {
+ switch v := c.(type) {
+ case *mcp.TextContent:
+ parts = append(parts, v.Text)
+ case *mcp.ImageContent:
+ // For images, just indicate that an image was returned
+ parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
+ default:
+ // For other content types, use string representation
+ parts = append(parts, fmt.Sprintf("[Content: %T]", v))
+ }
+ }
+ return strings.Join(parts, "\n")
+}
diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go
new file mode 100644
index 000000000..95bb0f992
--- /dev/null
+++ b/pkg/tools/mcp_tool_test.go
@@ -0,0 +1,492 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// MockMCPManager is a mock implementation of MCPManager interface for testing
+type MockMCPManager struct {
+ callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
+}
+
+func (m *MockMCPManager) CallTool(
+ ctx context.Context,
+ serverName, toolName string,
+ arguments map[string]any,
+) (*mcp.CallToolResult, error) {
+ if m.callToolFunc != nil {
+ return m.callToolFunc(ctx, serverName, toolName, arguments)
+ }
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "mock result"},
+ },
+ IsError: false,
+ }, nil
+}
+
+// TestNewMCPTool verifies MCP tool creation
+func TestNewMCPTool(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ Description: "A test tool",
+ InputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "input": map[string]any{
+ "type": "string",
+ "description": "Test input",
+ },
+ },
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ if mcpTool == nil {
+ t.Fatal("NewMCPTool should not return nil")
+ }
+ // Verify tool properties we can access
+ if mcpTool.Name() != "mcp_test_server_test_tool" {
+ t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
+ }
+}
+
+// TestMCPTool_Name verifies tool name with server prefix
+func TestMCPTool_Name(t *testing.T) {
+ tests := []struct {
+ name string
+ serverName string
+ toolName string
+ expected string
+ }{
+ {
+ name: "simple name",
+ serverName: "github",
+ toolName: "create_issue",
+ expected: "mcp_github_create_issue",
+ },
+ {
+ name: "filesystem server",
+ serverName: "filesystem",
+ toolName: "read_file",
+ expected: "mcp_filesystem_read_file",
+ },
+ {
+ name: "remote server",
+ serverName: "remote-api",
+ toolName: "fetch_data",
+ expected: "mcp_remote-api_fetch_data",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{Name: tt.toolName}
+ mcpTool := NewMCPTool(manager, tt.serverName, tool)
+
+ result := mcpTool.Name()
+ if result != tt.expected {
+ t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
+ }
+ })
+ }
+}
+
+// TestMCPTool_Description verifies tool description generation
+func TestMCPTool_Description(t *testing.T) {
+ tests := []struct {
+ name string
+ serverName string
+ toolDescription string
+ expectContains []string
+ }{
+ {
+ name: "with description",
+ serverName: "github",
+ toolDescription: "Create a GitHub issue",
+ expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
+ },
+ {
+ name: "empty description",
+ serverName: "filesystem",
+ toolDescription: "",
+ expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ Description: tt.toolDescription,
+ }
+ mcpTool := NewMCPTool(manager, tt.serverName, tool)
+
+ result := mcpTool.Description()
+
+ for _, expected := range tt.expectContains {
+ if !strings.Contains(result, expected) {
+ t.Errorf("Description should contain '%s', got: %s", expected, result)
+ }
+ }
+ })
+ }
+}
+
+// TestMCPTool_Parameters verifies parameter schema conversion
+func TestMCPTool_Parameters(t *testing.T) {
+ tests := []struct {
+ name string
+ inputSchema any
+ expectType string
+ checkProperty string
+ expectProperty bool
+ }{
+ {
+ name: "map schema",
+ inputSchema: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "query": map[string]any{
+ "type": "string",
+ "description": "Search query",
+ },
+ },
+ "required": []string{"query"},
+ },
+ expectType: "object",
+ checkProperty: "query",
+ expectProperty: true,
+ },
+ {
+ name: "nil schema",
+ inputSchema: nil,
+ expectType: "object",
+ expectProperty: false,
+ },
+ {
+ name: "json.RawMessage schema",
+ inputSchema: []byte(`{
+ "type": "object",
+ "properties": {
+ "repo": {
+ "type": "string",
+ "description": "Repository name"
+ },
+ "stars": {
+ "type": "integer",
+ "description": "Minimum stars"
+ }
+ },
+ "required": ["repo"]
+ }`),
+ expectType: "object",
+ checkProperty: "repo",
+ expectProperty: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ InputSchema: tt.inputSchema,
+ }
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ params := mcpTool.Parameters()
+
+ if params == nil {
+ t.Fatal("Parameters should not be nil")
+ }
+
+ if params["type"] != tt.expectType {
+ t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
+ }
+
+ // Check if property exists when expected
+ if tt.checkProperty != "" {
+ properties, ok := params["properties"].(map[string]any)
+ if !ok && tt.expectProperty {
+ t.Errorf("Expected properties to be a map")
+ return
+ }
+ if ok {
+ _, hasProperty := properties[tt.checkProperty]
+ if hasProperty != tt.expectProperty {
+ t.Errorf("Expected property '%s' existence: %v, got: %v",
+ tt.checkProperty, tt.expectProperty, hasProperty)
+ }
+ }
+ }
+ })
+ }
+}
+
+// TestMCPTool_Execute_Success tests successful tool execution
+func TestMCPTool_Execute_Success(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ // Verify correct parameters passed
+ if serverName != "github" {
+ t.Errorf("Expected serverName 'github', got '%s'", serverName)
+ }
+ if toolName != "search_repos" {
+ t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
+ }
+
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "Found 3 repositories"},
+ },
+ IsError: false,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{
+ Name: "search_repos",
+ Description: "Search GitHub repositories",
+ }
+ mcpTool := NewMCPTool(manager, "github", tool)
+
+ ctx := context.Background()
+ args := map[string]any{
+ "query": "golang mcp",
+ }
+
+ result := mcpTool.Execute(ctx, args)
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if result.IsError {
+ t.Errorf("Expected no error, got error: %s", result.ForLLM)
+ }
+ if result.ForLLM != "Found 3 repositories" {
+ t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_ManagerError tests execution when manager returns error
+func TestMCPTool_Execute_ManagerError(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return nil, fmt.Errorf("connection failed")
+ },
+ }
+
+ tool := &mcp.Tool{Name: "test_tool"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if !result.IsError {
+ t.Error("Expected IsError to be true")
+ }
+ if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
+ t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "connection failed") {
+ t.Errorf("Error message should include original error, got: %s", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_ServerError tests execution when server returns error
+func TestMCPTool_Execute_ServerError(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "Invalid API key"},
+ },
+ IsError: true,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{Name: "test_tool"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if !result.IsError {
+ t.Error("Expected IsError to be true")
+ }
+ if !strings.Contains(result.ForLLM, "MCP tool returned error") {
+ t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "Invalid API key") {
+ t.Errorf("Error message should include server message, got: %s", result.ForLLM)
+ }
+}
+
+// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
+func TestMCPTool_Execute_MultipleContent(t *testing.T) {
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: "First line"},
+ &mcp.TextContent{Text: "Second line"},
+ &mcp.TextContent{Text: "Third line"},
+ },
+ IsError: false,
+ }, nil
+ },
+ }
+
+ tool := &mcp.Tool{Name: "multi_output"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ ctx := context.Background()
+ result := mcpTool.Execute(ctx, map[string]any{})
+
+ if result.IsError {
+ t.Errorf("Expected no error, got: %s", result.ForLLM)
+ }
+
+ expected := "First line\nSecond line\nThird line"
+ if result.ForLLM != expected {
+ t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
+ }
+}
+
+// TestExtractContentText_TextContent tests text content extraction
+func TestExtractContentText_TextContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.TextContent{Text: "Hello World"},
+ &mcp.TextContent{Text: "Second message"},
+ }
+
+ result := extractContentText(content)
+ expected := "Hello World\nSecond message"
+
+ if result != expected {
+ t.Errorf("Expected '%s', got '%s'", expected, result)
+ }
+}
+
+// TestExtractContentText_ImageContent tests image content extraction
+func TestExtractContentText_ImageContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.ImageContent{
+ Data: []byte("base64data"),
+ MIMEType: "image/png",
+ },
+ }
+
+ result := extractContentText(content)
+
+ if !strings.Contains(result, "[Image:") {
+ t.Errorf("Expected image indicator, got: %s", result)
+ }
+ if !strings.Contains(result, "image/png") {
+ t.Errorf("Expected MIME type in output, got: %s", result)
+ }
+}
+
+// TestExtractContentText_MixedContent tests mixed content types
+func TestExtractContentText_MixedContent(t *testing.T) {
+ content := []mcp.Content{
+ &mcp.TextContent{Text: "Description"},
+ &mcp.ImageContent{
+ Data: []byte("data"),
+ MIMEType: "image/jpeg",
+ },
+ &mcp.TextContent{Text: "More text"},
+ }
+
+ result := extractContentText(content)
+
+ if !strings.Contains(result, "Description") {
+ t.Errorf("Should contain text content, got: %s", result)
+ }
+ if !strings.Contains(result, "[Image:") {
+ t.Errorf("Should contain image indicator, got: %s", result)
+ }
+ if !strings.Contains(result, "More text") {
+ t.Errorf("Should contain second text, got: %s", result)
+ }
+}
+
+// TestExtractContentText_EmptyContent tests empty content array
+func TestExtractContentText_EmptyContent(t *testing.T) {
+ content := []mcp.Content{}
+
+ result := extractContentText(content)
+
+ if result != "" {
+ t.Errorf("Expected empty string for empty content, got: %s", result)
+ }
+}
+
+// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
+func TestMCPTool_InterfaceCompliance(t *testing.T) {
+ manager := &MockMCPManager{}
+ tool := &mcp.Tool{Name: "test"}
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ // Verify it implements Tool interface
+ var _ Tool = mcpTool
+}
+
+// TestMCPTool_Parameters_MapSchema tests schema that's already a map
+func TestMCPTool_Parameters_MapSchema(t *testing.T) {
+ manager := &MockMCPManager{}
+ schema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "description": "The name parameter",
+ },
+ },
+ "required": []string{"name"},
+ }
+
+ tool := &mcp.Tool{
+ Name: "test_tool",
+ InputSchema: schema,
+ }
+ mcpTool := NewMCPTool(manager, "test_server", tool)
+
+ params := mcpTool.Parameters()
+
+ // Should return the schema as-is when it's already a map
+ if params["type"] != "object" {
+ t.Errorf("Expected type 'object', got '%v'", params["type"])
+ }
+
+ props, ok := params["properties"].(map[string]any)
+ if !ok {
+ t.Error("Properties should be a map")
+ }
+
+ nameParam, ok := props["name"].(map[string]any)
+ if !ok {
+ t.Error("Name parameter should exist")
+ }
+
+ if nameParam["type"] != "string" {
+ t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
+ }
+}
diff --git a/pkg/tools/message.go b/pkg/tools/message.go
index 15ef4ff73..d1e4a373e 100644
--- a/pkg/tools/message.go
+++ b/pkg/tools/message.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "sync/atomic"
)
type SendCallback func(channel, chatID, content string) error
@@ -11,7 +12,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
- sentInRound bool // Tracks whether a message was sent in the current processing round
+ sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
- t.sentInRound = false // Reset send tracking for new processing round
+ t.sentInRound.Store(false) // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
- return t.sentInRound
+ return t.sentInRound.Load()
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
@@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes
}
}
- t.sentInRound = true
+ t.sentInRound.Store(true)
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go
index d37a093a8..0ba983e02 100644
--- a/pkg/tools/registry.go
+++ b/pkg/tools/registry.go
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
func (r *ToolRegistry) Register(tool Tool) {
r.mu.Lock()
defer r.mu.Unlock()
- r.tools[tool.Name()] = tool
+ name := tool.Name()
+ if _, exists := r.tools[name]; exists {
+ logger.WarnCF("tools", "Tool registration overwrites existing tool",
+ map[string]any{"name": name})
+ }
+ r.tools[name] = tool
}
func (r *ToolRegistry) Get(name string) (Tool, bool) {
diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go
index 8ae13b20c..8fe88ca78 100644
--- a/pkg/tools/registry_test.go
+++ b/pkg/tools/registry_test.go
@@ -329,7 +329,7 @@ func TestToolRegistry_ConcurrentAccess(t *testing.T) {
r := NewToolRegistry()
var wg sync.WaitGroup
- for i := 0; i < 50; i++ {
+ for i := range 50 {
wg.Add(1)
go func(n int) {
defer wg.Done()
diff --git a/pkg/tools/result.go b/pkg/tools/result.go
index b13055b1c..cab833284 100644
--- a/pkg/tools/result.go
+++ b/pkg/tools/result.go
@@ -30,6 +30,10 @@ type ToolResult struct {
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
+
+ // Media contains media store refs produced by this tool.
+ // When non-empty, the agent will publish these as OutboundMediaMessage.
+ Media []string `json:"media,omitempty"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
@@ -120,6 +124,19 @@ func UserResult(content string) *ToolResult {
}
}
+// MediaResult creates a ToolResult with media refs for the user.
+// The agent will publish these refs as OutboundMediaMessage.
+//
+// Example:
+//
+// result := MediaResult("Image generated successfully", []string{"media://abc123"})
+func MediaResult(forLLM string, mediaRefs []string) *ToolResult {
+ return &ToolResult{
+ ForLLM: forLLM,
+ Media: mediaRefs,
+ }
+}
+
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go
index ad1664b5b..a0c83eb1e 100644
--- a/pkg/tools/shell.go
+++ b/pkg/tools/shell.go
@@ -21,60 +21,85 @@ type ExecTool struct {
timeout time.Duration
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
+ customAllowPatterns []*regexp.Regexp
restrictToWorkspace bool
}
-var defaultDenyPatterns = []*regexp.Regexp{
- regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
- regexp.MustCompile(`\bdel\s+/[fq]\b`),
- regexp.MustCompile(`\brmdir\s+/s\b`),
- regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
- regexp.MustCompile(`\bdd\s+if=`),
- regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
- regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
- regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
- regexp.MustCompile(`\$\([^)]+\)`),
- regexp.MustCompile(`\$\{[^}]+\}`),
- regexp.MustCompile("`[^`]+`"),
- regexp.MustCompile(`\|\s*sh\b`),
- regexp.MustCompile(`\|\s*bash\b`),
- regexp.MustCompile(`;\s*rm\s+-[rf]`),
- regexp.MustCompile(`&&\s*rm\s+-[rf]`),
- regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
- regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
- regexp.MustCompile(`<<\s*EOF`),
- regexp.MustCompile(`\$\(\s*cat\s+`),
- regexp.MustCompile(`\$\(\s*curl\s+`),
- regexp.MustCompile(`\$\(\s*wget\s+`),
- regexp.MustCompile(`\$\(\s*which\s+`),
- regexp.MustCompile(`\bsudo\b`),
- regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
- regexp.MustCompile(`\bchown\b`),
- regexp.MustCompile(`\bpkill\b`),
- regexp.MustCompile(`\bkillall\b`),
- regexp.MustCompile(`\bkill\s+-[9]\b`),
- regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
- regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
- regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
- regexp.MustCompile(`\bpip\s+install\s+--user\b`),
- regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
- regexp.MustCompile(`\byum\s+(install|remove)\b`),
- regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
- regexp.MustCompile(`\bdocker\s+run\b`),
- regexp.MustCompile(`\bdocker\s+exec\b`),
- regexp.MustCompile(`\bgit\s+push\b`),
- regexp.MustCompile(`\bgit\s+force\b`),
- regexp.MustCompile(`\bssh\b.*@`),
- regexp.MustCompile(`\beval\b`),
- regexp.MustCompile(`\bsource\s+.*\.sh\b`),
-}
+var (
+ defaultDenyPatterns = []*regexp.Regexp{
+ regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
+ regexp.MustCompile(`\bdel\s+/[fq]\b`),
+ regexp.MustCompile(`\brmdir\s+/s\b`),
+ // Match disk wiping commands (must be followed by space/args)
+ regexp.MustCompile(
+ `\b(format|mkfs|diskpart)\b\s`,
+ ),
+ regexp.MustCompile(`\bdd\s+if=`),
+ // Block writes to block devices (all common naming schemes).
+ regexp.MustCompile(
+ `>\s*/dev/(sd[a-z]|hd[a-z]|vd[a-z]|xvd[a-z]|nvme\d|mmcblk\d|loop\d|dm-\d|md\d|sr\d|nbd\d)`,
+ ),
+ regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
+ regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
+ regexp.MustCompile(`\$\([^)]+\)`),
+ regexp.MustCompile(`\$\{[^}]+\}`),
+ regexp.MustCompile("`[^`]+`"),
+ regexp.MustCompile(`\|\s*sh\b`),
+ regexp.MustCompile(`\|\s*bash\b`),
+ regexp.MustCompile(`;\s*rm\s+-[rf]`),
+ regexp.MustCompile(`&&\s*rm\s+-[rf]`),
+ regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
+ regexp.MustCompile(`<<\s*EOF`),
+ regexp.MustCompile(`\$\(\s*cat\s+`),
+ regexp.MustCompile(`\$\(\s*curl\s+`),
+ regexp.MustCompile(`\$\(\s*wget\s+`),
+ regexp.MustCompile(`\$\(\s*which\s+`),
+ regexp.MustCompile(`\bsudo\b`),
+ regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
+ regexp.MustCompile(`\bchown\b`),
+ regexp.MustCompile(`\bpkill\b`),
+ regexp.MustCompile(`\bkillall\b`),
+ regexp.MustCompile(`\bkill\s+-[9]\b`),
+ regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
+ regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
+ regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
+ regexp.MustCompile(`\bpip\s+install\s+--user\b`),
+ regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
+ regexp.MustCompile(`\byum\s+(install|remove)\b`),
+ regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
+ regexp.MustCompile(`\bdocker\s+run\b`),
+ regexp.MustCompile(`\bdocker\s+exec\b`),
+ regexp.MustCompile(`\bgit\s+push\b`),
+ regexp.MustCompile(`\bgit\s+force\b`),
+ regexp.MustCompile(`\bssh\b.*@`),
+ regexp.MustCompile(`\beval\b`),
+ regexp.MustCompile(`\bsource\s+.*\.sh\b`),
+ }
-func NewExecTool(workingDir string, restrict bool) *ExecTool {
+ // absolutePathPattern matches absolute file paths in commands (Unix and Windows).
+ absolutePathPattern = regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
+
+ // safePaths are kernel pseudo-devices that are always safe to reference in
+ // commands, regardless of workspace restriction. They contain no user data
+ // and cannot cause destructive writes.
+ safePaths = map[string]bool{
+ "/dev/null": true,
+ "/dev/zero": true,
+ "/dev/random": true,
+ "/dev/urandom": true,
+ "/dev/stdin": true,
+ "/dev/stdout": true,
+ "/dev/stderr": true,
+ }
+)
+
+func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
}
-func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
+func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
+ customAllowPatterns := make([]*regexp.Regexp, 0)
if config != nil {
execConfig := config.Tools.Exec
@@ -86,8 +111,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
for _, pattern := range execConfig.CustomDenyPatterns {
re, err := regexp.Compile(pattern)
if err != nil {
- fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
- continue
+ return nil, fmt.Errorf("invalid custom deny pattern %q: %w", pattern, err)
}
denyPatterns = append(denyPatterns, re)
}
@@ -96,6 +120,13 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
}
+ for _, pattern := range execConfig.CustomAllowPatterns {
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ return nil, fmt.Errorf("invalid custom allow pattern %q: %w", pattern, err)
+ }
+ customAllowPatterns = append(customAllowPatterns, re)
+ }
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
@@ -105,8 +136,9 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
timeout: 60 * time.Second,
denyPatterns: denyPatterns,
allowPatterns: nil,
+ customAllowPatterns: customAllowPatterns,
restrictToWorkspace: restrict,
- }
+ }, nil
}
func (t *ExecTool) Name() string {
@@ -259,9 +291,20 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
cmd := strings.TrimSpace(command)
lower := strings.ToLower(cmd)
- for _, pattern := range t.denyPatterns {
+ // Custom allow patterns exempt a command from deny checks.
+ explicitlyAllowed := false
+ for _, pattern := range t.customAllowPatterns {
if pattern.MatchString(lower) {
- return "Command blocked by safety guard (dangerous pattern detected)"
+ explicitlyAllowed = true
+ break
+ }
+ }
+
+ if !explicitlyAllowed {
+ for _, pattern := range t.denyPatterns {
+ if pattern.MatchString(lower) {
+ return "Command blocked by safety guard (dangerous pattern detected)"
+ }
}
}
@@ -288,8 +331,7 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
return ""
}
- pathPattern := regexp.MustCompile(`[A-Za-z]:\\[^\\\"']+|/[^\s\"']+`)
- matches := pathPattern.FindAllString(cmd, -1)
+ matches := absolutePathPattern.FindAllString(cmd, -1)
for _, raw := range matches {
p, err := filepath.Abs(raw)
@@ -297,6 +339,10 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
continue
}
+ if safePaths[p] {
+ continue
+ }
+
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
continue
diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go
index 6d35815e8..a6abca8ea 100644
--- a/pkg/tools/shell_test.go
+++ b/pkg/tools/shell_test.go
@@ -7,11 +7,16 @@ import (
"strings"
"testing"
"time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{
@@ -38,7 +43,10 @@ func TestShellTool_Success(t *testing.T) {
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{
@@ -65,7 +73,11 @@ func TestShellTool_Failure(t *testing.T) {
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
+
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
@@ -93,7 +105,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0o644)
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{
@@ -114,7 +129,10 @@ func TestShellTool_WorkingDir(t *testing.T) {
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{
@@ -135,7 +153,10 @@ func TestShellTool_DangerousCommand(t *testing.T) {
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{}
@@ -150,7 +171,10 @@ func TestShellTool_MissingCommand(t *testing.T) {
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
args := map[string]any{
@@ -170,7 +194,10 @@ func TestShellTool_StderrCapture(t *testing.T) {
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
- tool := NewExecTool("", false)
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
ctx := context.Background()
// Generate long output (>10000 chars)
@@ -198,7 +225,11 @@ func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
t.Fatalf("failed to create outside dir: %v", err)
}
- tool := NewExecTool(workspace, true)
+ tool, err := NewExecTool(workspace, true)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
+
result := tool.Execute(context.Background(), map[string]any{
"command": "pwd",
"working_dir": outsideDir,
@@ -232,7 +263,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
t.Skipf("symlinks not supported in this environment: %v", err)
}
- tool := NewExecTool(workspace, true)
+ tool, err := NewExecTool(workspace, true)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
+
result := tool.Execute(context.Background(), map[string]any{
"command": "cat secret.txt",
"working_dir": link,
@@ -249,7 +284,11 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
- tool := NewExecTool(tmpDir, false)
+ tool, err := NewExecTool(tmpDir, false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
+
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
@@ -272,3 +311,115 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
)
}
}
+
+// TestShellTool_DevNullAllowed verifies that /dev/null redirections are not blocked (issue #964).
+func TestShellTool_DevNullAllowed(t *testing.T) {
+ tmpDir := t.TempDir()
+ tool, err := NewExecTool(tmpDir, true)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ commands := []string{
+ "echo hello 2>/dev/null",
+ "echo hello >/dev/null",
+ "echo hello > /dev/null",
+ "echo hello 2> /dev/null",
+ "echo hello >/dev/null 2>&1",
+ "find " + tmpDir + " -name '*.go' 2>/dev/null",
+ }
+
+ for _, cmd := range commands {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if result.IsError && strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("command should not be blocked: %s\n error: %s", cmd, result.ForLLM)
+ }
+ }
+}
+
+// TestShellTool_BlockDevices verifies that writes to block devices are blocked (issue #965).
+func TestShellTool_BlockDevices(t *testing.T) {
+ tool, err := NewExecTool("", false)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ blocked := []string{
+ "echo x > /dev/sda",
+ "echo x > /dev/hda",
+ "echo x > /dev/vda",
+ "echo x > /dev/xvda",
+ "echo x > /dev/nvme0n1",
+ "echo x > /dev/mmcblk0",
+ "echo x > /dev/loop0",
+ "echo x > /dev/dm-0",
+ "echo x > /dev/md0",
+ "echo x > /dev/sr0",
+ "echo x > /dev/nbd0",
+ }
+
+ for _, cmd := range blocked {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if !result.IsError {
+ t.Errorf("expected block device write to be blocked: %s", cmd)
+ }
+ }
+}
+
+// TestShellTool_SafePathsInWorkspaceRestriction verifies that safe kernel pseudo-devices
+// are allowed even when workspace restriction is active.
+func TestShellTool_SafePathsInWorkspaceRestriction(t *testing.T) {
+ tmpDir := t.TempDir()
+ tool, err := NewExecTool(tmpDir, true)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ // These reference paths outside workspace but should be allowed via safePaths.
+ commands := []string{
+ "cat /dev/urandom | head -c 16 | od",
+ "echo test > /dev/null",
+ "dd if=/dev/zero bs=1 count=1",
+ }
+
+ for _, cmd := range commands {
+ result := tool.Execute(context.Background(), map[string]any{"command": cmd})
+ if result.IsError && strings.Contains(result.ForLLM, "path outside working dir") {
+ t.Errorf("safe path should not be blocked by workspace check: %s\n error: %s", cmd, result.ForLLM)
+ }
+ }
+}
+
+// TestShellTool_CustomAllowPatterns verifies that custom allow patterns exempt
+// commands from deny pattern checks.
+func TestShellTool_CustomAllowPatterns(t *testing.T) {
+ cfg := &config.Config{
+ Tools: config.ToolsConfig{
+ Exec: config.ExecConfig{
+ EnableDenyPatterns: true,
+ CustomAllowPatterns: []string{`\bgit\s+push\s+origin\b`},
+ },
+ },
+ }
+
+ tool, err := NewExecToolWithConfig("", false, cfg)
+ if err != nil {
+ t.Fatalf("unable to configure exec tool: %s", err)
+ }
+
+ // "git push origin main" should be allowed by custom allow pattern.
+ result := tool.Execute(context.Background(), map[string]any{
+ "command": "git push origin main",
+ })
+ if result.IsError && strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("custom allow pattern should exempt 'git push origin main', got: %s", result.ForLLM)
+ }
+
+ // "git push upstream main" should still be blocked (does not match allow pattern).
+ result = tool.Execute(context.Background(), map[string]any{
+ "command": "git push upstream main",
+ })
+ if !result.IsError {
+ t.Errorf("'git push upstream main' should still be blocked by deny pattern")
+ }
+}
diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/shell_timeout_unix_test.go
index 04ef8e441..357e1276e 100644
--- a/pkg/tools/shell_timeout_unix_test.go
+++ b/pkg/tools/shell_timeout_unix_test.go
@@ -22,7 +22,11 @@ func processExists(pid int) bool {
}
func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
- tool := NewExecTool(t.TempDir(), false)
+ tool, err := NewExecTool(t.TempDir(), false)
+ if err != nil {
+ t.Errorf("unable to configure exec tool: %s", err)
+ }
+
tool.SetTimeout(500 * time.Millisecond)
args := map[string]any{
diff --git a/pkg/tools/skills_install.go b/pkg/tools/skills_install.go
index 55c0b678d..71bfe730b 100644
--- a/pkg/tools/skills_install.go
+++ b/pkg/tools/skills_install.go
@@ -9,6 +9,7 @@ import (
"sync"
"time"
+ "github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
@@ -197,5 +198,6 @@ func writeOriginMeta(targetDir, registryName, slug, version string) error {
return err
}
- return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0o644)
+ // Use unified atomic write utility with explicit sync for flash storage reliability.
+ return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
}
diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go
index ad371a649..69f1a49a2 100644
--- a/pkg/tools/subagent.go
+++ b/pkg/tools/subagent.go
@@ -218,7 +218,9 @@ After completing the task, provide a clear summary of what was done.`
// Send announce message back to main agent
if sm.bus != nil {
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
- sm.bus.PublishInbound(bus.InboundMessage{
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer pubCancel()
+ sm.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("subagent:%s", task.ID),
// Format: "original_channel:original_chat_id" for routing back
diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go
index cdfe0d6ce..244f0d4a2 100644
--- a/pkg/tools/toolloop.go
+++ b/pkg/tools/toolloop.go
@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
+ "sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -121,37 +122,53 @@ func RunToolLoop(
}
messages = append(messages, assistantMsg)
- // 7. Execute tool calls
- for _, tc := range normalizedToolCalls {
- argsJSON, _ := json.Marshal(tc.Arguments)
- argsPreview := utils.Truncate(string(argsJSON), 200)
- logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
- map[string]any{
- "tool": tc.Name,
- "iteration": iteration,
- })
+ // 7. Execute tool calls in parallel
+ type indexedResult struct {
+ result *ToolResult
+ tc providers.ToolCall
+ }
- // Execute tool (no async callback for subagents - they run independently)
- var toolResult *ToolResult
- if config.Tools != nil {
- toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
- } else {
- toolResult = ErrorResult("No tools available")
+ results := make([]indexedResult, len(normalizedToolCalls))
+ var wg sync.WaitGroup
+
+ for i, tc := range normalizedToolCalls {
+ results[i].tc = tc
+
+ wg.Add(1)
+ go func(idx int, tc providers.ToolCall) {
+ defer wg.Done()
+
+ argsJSON, _ := json.Marshal(tc.Arguments)
+ argsPreview := utils.Truncate(string(argsJSON), 200)
+ logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
+ map[string]any{
+ "tool": tc.Name,
+ "iteration": iteration,
+ })
+
+ var toolResult *ToolResult
+ if config.Tools != nil {
+ toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
+ } else {
+ toolResult = ErrorResult("No tools available")
+ }
+ results[idx].result = toolResult
+ }(i, tc)
+ }
+ wg.Wait()
+
+ // Append results in original order
+ for _, r := range results {
+ contentForLLM := r.result.ForLLM
+ if contentForLLM == "" && r.result.Err != nil {
+ contentForLLM = r.result.Err.Error()
}
- // Determine content for LLM
- contentForLLM := toolResult.ForLLM
- if contentForLLM == "" && toolResult.Err != nil {
- contentForLLM = toolResult.Err.Error()
- }
-
- // Add tool result message
- toolResultMsg := providers.Message{
+ messages = append(messages, providers.Message{
Role: "tool",
Content: contentForLLM,
- ToolCallID: tc.ID,
- }
- messages = append(messages, toolResultMsg)
+ ToolCallID: r.tc.ID,
+ })
}
}
diff --git a/pkg/tools/web.go b/pkg/tools/web.go
index e95185599..eeceabd98 100644
--- a/pkg/tools/web.go
+++ b/pkg/tools/web.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -15,6 +16,27 @@ import (
const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
+
+ // HTTP client timeouts for web tool providers.
+ searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo
+ perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower)
+ fetchTimeout = 60 * time.Second // WebFetchTool
+
+ defaultMaxChars = 50000
+ maxRedirects = 5
+)
+
+// Pre-compiled regexes for HTML text extraction
+var (
+ reScript = regexp.MustCompile(``)
- result := re.ReplaceAllLiteralString(htmlContent, "")
- re = regexp.MustCompile(`