diff --git a/.env.example b/.env.example index 06d43070c..bc68456d6 100644 --- a/.env.example +++ b/.env.example @@ -17,4 +17,4 @@ # BRAVE_SEARCH_API_KEY=BSA... # ── Timezone ────────────────────────────── -TZ=Asia/Tokyo +TZ=Asia/Shanghai diff --git a/.golangci.yaml b/.golangci.yaml index d0ba90716..ea3107ec8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -7,7 +7,6 @@ linters: - containedctx - cyclop - depguard - - dupl - dupword - err113 - exhaustruct diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 88227f493..ceff723d2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -269,8 +269,8 @@ Once your PR is submitted, you can reach out to the assigned reviewers listed in |Function| Reviewer| |--- |--- | |Provider|@yinwm | -|Channel |@yinwm | -|Agent |@lxowalle| +|Channel |@yinwm/@alexhoshina | +|Agent |@lxowalle/@Zhaoyikaiii| |Tools |@lxowalle| |SKill || |MCP || diff --git a/CONTRIBUTING.zh.md b/CONTRIBUTING.zh.md index 01a1abfd5..196aecc65 100644 --- a/CONTRIBUTING.zh.md +++ b/CONTRIBUTING.zh.md @@ -268,8 +268,8 @@ Release 分支的保护级别高于 `main`,在任何情况下均不允许直 |Function| Reviewer| |--- |--- | |Provider|@yinwm | -|Channel |@yinwm | -|Agent |@lxowalle| +|Channel |@yinwm/@alexhoshina | +|Agent |@lxowalle/@Zhaoyikaiii| |Tools |@lxowalle| |SKill || |MCP || diff --git a/Makefile b/Makefile index c7375a544..afc76a6ad 100644 --- a/Makefile +++ b/Makefile @@ -168,11 +168,11 @@ clean: @echo "Clean complete" ## vet: Run go vet for static analysis -vet: +vet: generate @$(GO) vet ./... ## test: Test Go code -test: +test: generate @$(GO) test ./... ## fmt: Format Go code @@ -204,6 +204,44 @@ check: deps fmt vet test run: build @$(BUILD_DIR)/$(BINARY_NAME) $(ARGS) +## docker-build: Build Docker image (minimal Alpine-based) +docker-build: + @echo "Building minimal Docker image (Alpine-based)..." 
+ docker compose -f docker/docker-compose.yml build picoclaw-agent picoclaw-gateway + +## docker-build-full: Build Docker image with full MCP support (Node.js 24) +docker-build-full: + @echo "Building full-featured Docker image (Node.js 24)..." + docker compose -f docker/docker-compose.full.yml build picoclaw-agent picoclaw-gateway + +## docker-test: Test MCP tools in Docker container +docker-test: + @echo "Testing MCP tools in Docker..." + @chmod +x scripts/test-docker-mcp.sh + @./scripts/test-docker-mcp.sh + +## docker-run: Run picoclaw gateway in Docker (Alpine-based) +docker-run: + docker compose -f docker/docker-compose.yml --profile gateway up + +## docker-run-full: Run picoclaw gateway in Docker (full-featured) +docker-run-full: + docker compose -f docker/docker-compose.full.yml --profile gateway up + +## docker-run-agent: Run picoclaw agent in Docker (interactive, Alpine-based) +docker-run-agent: + docker compose -f docker/docker-compose.yml run --rm picoclaw-agent + +## docker-run-agent-full: Run picoclaw agent in Docker (interactive, full-featured) +docker-run-agent-full: + docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent + +## docker-clean: Clean Docker images and volumes +docker-clean: + docker compose -f docker/docker-compose.yml down -v + docker compose -f docker/docker-compose.full.yml down -v + docker rmi picoclaw:latest picoclaw:full 2>/dev/null || true + ## help: Show this help message help: @echo "picoclaw Makefile" @@ -219,6 +257,8 @@ help: @echo " make install # Install to ~/.local/bin" @echo " make uninstall # Remove from /usr/local/bin" @echo " make install-skills # Install skills to workspace" + @echo " make docker-build # Build minimal Docker image" + @echo " make docker-test # Test MCP tools in Docker" @echo "" @echo "Environment Variables:" @echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)" diff --git a/README.fr.md b/README.fr.md index 2bec768fc..320aa9e22 100644 --- a/README.fr.md +++ 
b/README.fr.md @@ -288,7 +288,7 @@ Discutez avec votre PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom | **QQ** | Facile (AppID + AppSecret) | | **DingTalk** | Moyen (identifiants de l'application) | | **LINE** | Moyen (identifiants + URL de webhook) | -| **WeCom** | Moyen (CorpID + configuration webhook) | +| **WeCom AI Bot** | Moyen (Token + clé AES) |
Telegram (Recommandé) @@ -491,12 +491,13 @@ picoclaw gateway
WeCom (WeChat Work) -PicoClaw prend en charge deux types d'intégration WeCom : +PicoClaw prend en charge trois types d'intégration WeCom : -**Option 1 : WeCom Bot (Robot Intelligent)** - Configuration plus facile, prend en charge les discussions de groupe -**Option 2 : WeCom App (Application Personnalisée)** - Plus de fonctionnalités, messagerie proactive +**Option 1 : WeCom Bot (Robot)** - Configuration plus facile, prend en charge les discussions de groupe +**Option 2 : WeCom App (Application Personnalisée)** - Plus de fonctionnalités, messagerie proactive, chat privé uniquement +**Option 3 : WeCom AI Bot (Bot Intelligent)** - Bot IA officiel, réponses en streaming, prend en charge groupe et privé -Voir le [Guide de Configuration WeCom App](docs/wecom-app-configuration.md) pour des instructions détaillées. +Voir le [Guide de Configuration WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) pour des instructions détaillées. **Configuration Rapide - WeCom Bot :** @@ -563,6 +564,39 @@ picoclaw gateway > **Note** : Les callbacks webhook WeCom App sont servis par le serveur Gateway partagé (par défaut `127.0.0.1:18790`). Assurez-vous que le port `18790` est accessible ou utilisez un proxy inverse HTTPS en production. +**Configuration Rapide - WeCom AI Bot :** + +**1. Créer un AI Bot** + +* Accédez à la Console d'Administration WeCom → Gestion des Applications → AI Bot +* Configurez l'URL de callback : `http://your-server:18791/webhook/wecom-aibot` +* Copiez le **Token** et générez l'**EncodingAESKey** + +**2. Configurer** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "Bonjour ! Comment puis-je vous aider ?" + } + } +} +``` + +**3. Lancer** + +```bash +picoclaw gateway +``` + +> **Note** : WeCom AI Bot utilise le protocole pull en streaming — pas de problème de timeout. 
Les tâches longues (>5,5 min) basculent automatiquement vers la livraison via `response_url`. +
## ClawdChat Rejoignez le Réseau Social d'Agents @@ -793,7 +827,7 @@ Le sous-agent a accès aux outils (message, web_search, etc.) et peut communique ### Fournisseurs > [!NOTE] -> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages vocaux Telegram seront automatiquement transcrits. +> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages audio de n'importe quel canal seront automatiquement transcrits au niveau de l'agent. | Fournisseur | Utilisation | Obtenir une Clé API | | ------------------------ | ---------------------------------------- | ------------------------------------------------------ | diff --git a/README.ja.md b/README.ja.md index 15ed1f649..ea6bc7e72 100644 --- a/README.ja.md +++ b/README.ja.md @@ -257,7 +257,7 @@ Telegram、Discord、QQ、DingTalk、LINE、WeCom で PicoClaw と会話でき | **QQ** | 簡単(AppID + AppSecret) | | **DingTalk** | 普通(アプリ認証情報) | | **LINE** | 普通(認証情報 + Webhook URL) | -| **WeCom** | 普通(CorpID + Webhook設定) | +| **WeCom AI Bot** | 普通(Token + AES キー) |
Telegram(推奨) @@ -456,12 +456,13 @@ picoclaw gateway
WeCom (企業微信) -PicoClaw は2種類の WeCom 統合をサポートしています: +PicoClaw は3種類の WeCom 統合をサポートしています: -**オプション1: WeCom Bot (智能ロボット)** - 簡単な設定、グループチャット対応 -**オプション2: WeCom App (自作アプリ)** - より多機能、アクティブメッセージング対応 +**オプション1: WeCom Bot (ロボット)** - 簡単な設定、グループチャット対応 +**オプション2: WeCom App (カスタムアプリ)** - より多機能、アクティブメッセージング対応、プライベートチャットのみ +**オプション3: WeCom AI Bot (スマートボット)** - 公式 AI Bot、ストリーミング返信、グループ・プライベート両対応 -詳細な設定手順は [WeCom App Configuration Guide](docs/wecom-app-configuration.md) を参照してください。 +詳細な設定手順は [WeCom AI Bot Configuration Guide](docs/channels/wecom/wecom_aibot/README.zh.md) を参照してください。 **クイックセットアップ - WeCom Bot:** @@ -530,6 +531,39 @@ picoclaw gateway > **注意**: WeCom App の Webhook コールバックは共有の Gateway HTTP サーバー(デフォルト: `127.0.0.1:18790`)で提供されます。ホストからアクセスする場合は HTTPS 用のリバースプロキシを設定してください。 +**クイックセットアップ - WeCom AI Bot:** + +**1. AI Bot を作成** + +* WeCom 管理コンソール → アプリ管理 → AI Bot +* コールバック URL を設定: `http://your-server:18791/webhook/wecom-aibot` +* **Token** をコピーし、**EncodingAESKey** を生成 + +**2. 設定** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "こんにちは!何かお手伝いできますか?" + } + } +} +``` + +**3. 起動** + +```bash +picoclaw gateway +``` + +> **注意**: WeCom AI Bot はストリーミングプルプロトコルを使用 — 返信タイムアウトの心配なし。長時間タスク(>30秒)は自動的に `response_url` によるプッシュ配信に切り替わります。 +
## ⚙️ 設定 @@ -751,7 +785,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る ### プロバイダー > [!NOTE] -> Groq は Whisper による無料の音声文字起こしを提供しています。設定すると、Telegram の音声メッセージが自動的に文字起こしされます。 +> Groq は Whisper による無料の音声文字起こしを提供しています。設定すると、あらゆるチャンネルからの音声メッセージがエージェントレベルで自動的に文字起こしされます。 | プロバイダー | 用途 | API キー取得先 | | --- | --- | --- | diff --git a/README.md b/README.md index 97eb47773..204a1af81 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ ## 📢 News -2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we can’t wait to have you on board! +2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](ROADMAP.md) —we can’t wait to have you on board! 2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. 🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. @@ -305,7 +305,7 @@ Talk to your picoclaw through Telegram, Discord, WhatsApp, DingTalk, LINE, or We | **QQ** | Easy (AppID + AppSecret) | | **DingTalk** | Medium (app credentials) | | **LINE** | Medium (credentials + webhook URL) | -| **WeCom** | Medium (CorpID + webhook setup) | +| **WeCom AI Bot** | Medium (Token + AES key) |
Telegram (Recommended) @@ -557,12 +557,13 @@ picoclaw gateway
WeCom (企业微信) -PicoClaw supports two types of WeCom integration: +PicoClaw supports three types of WeCom integration: -**Option 1: WeCom Bot (智能机器人)** - Easier setup, supports group chats -**Option 2: WeCom App (自建应用)** - More features, proactive messaging +**Option 1: WeCom Bot (Bot)** - Easier setup, supports group chats +**Option 2: WeCom App (Custom App)** - More features, proactive messaging, private chat only +**Option 3: WeCom AI Bot (AI Bot)** - Official AI Bot, streaming replies, supports group & private chat -See [WeCom App Configuration Guide](docs/wecom-app-configuration.md) for detailed setup instructions. +See [WeCom AI Bot Configuration Guide](docs/channels/wecom/wecom_aibot/README.zh.md) for detailed setup instructions. **Quick Setup - WeCom Bot:** @@ -631,6 +632,39 @@ picoclaw gateway > **Note**: WeCom webhook callbacks are served on the Gateway port (default 18790). Use a reverse proxy for HTTPS. +**Quick Setup - WeCom AI Bot:** + +**1. Create an AI Bot** + +* Go to WeCom Admin Console → App Management → AI Bot +* In the AI Bot settings, configure callback URL: `http://your-server:18791/webhook/wecom-aibot` +* Copy **Token** and click "Random Generate" for **EncodingAESKey** + +**2. Configure** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "Hello! How can I help you?" + } + } +} +``` + +**3. Run** + +```bash +picoclaw gateway +``` + +> **Note**: WeCom AI Bot uses streaming pull protocol — no reply timeout concerns. Long tasks (>30 seconds) automatically switch to `response_url` push delivery. +
## ClawdChat Join the Agent Social Network @@ -687,6 +721,20 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa └── USER.md # User preferences ``` +### Skill Sources + +By default, skills are loaded from: + +1. `~/.picoclaw/workspace/skills` (workspace) +2. `~/.picoclaw/skills` (global) +3. `/skills` (builtin) + +For advanced/test setups, you can override the builtin skills root with: + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. @@ -863,7 +911,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate ### Providers > [!NOTE] -> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. +> Groq provides free voice transcription via Whisper. If configured, audio messages from any channel will be automatically transcribed at the agent level. 
| Provider | Purpose | Get API Key | | -------------------------- | --------------------------------------- | -------------------------------------------------------------------- | @@ -892,7 +940,7 @@ This design also enables **multi-agent support** with flexible provider selectio #### 📋 All Supported Vendors | Vendor | `model` Prefix | Default API Base | Protocol | API Key | -| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- | +| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- | | **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) | | **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) | | **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | @@ -904,6 +952,7 @@ This design also enables **multi-agent support** with flexible provider selectio | **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) | | **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) | | **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) | +| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1` | OpenAI | Your LiteLLM proxy key | | **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) | | **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) | @@ -1006,6 +1055,19 @@ This design also enables **multi-agent 
support** with flexible provider selectio } ``` +**LiteLLM Proxy** + +```json +{ + "model_name": "lite-gpt4", + "model": "litellm/lite-gpt4", + "api_base": "http://localhost:4000/v1", + "api_key": "sk-..." +} +``` + +PicoClaw strips only the outer `litellm/` prefix before sending the request, so proxy aliases like `litellm/lite-gpt4` send `lite-gpt4`, while `litellm/openai/gpt-4o` sends `openai/gpt-4o`. + #### Load Balancing Configure multiple endpoints for the same model name—PicoClaw will automatically round-robin between them: diff --git a/README.pt-br.md b/README.pt-br.md index 611a61281..67ce9e0d3 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -282,7 +282,7 @@ Converse com seu PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom. | **QQ** | Fácil (AppID + AppSecret) | | **DingTalk** | Médio (credenciais do app) | | **LINE** | Médio (credenciais + webhook URL) | -| **WeCom** | Médio (CorpID + configuração webhook) | +| **WeCom AI Bot** | Médio (Token + chave AES) |
Telegram (Recomendado) @@ -485,12 +485,13 @@ picoclaw gateway
WeCom (WeChat Work) -O PicoClaw suporta dois tipos de integração WeCom: +O PicoClaw suporta três tipos de integração WeCom: -**Opção 1: WeCom Bot (Robô Inteligente)** - Configuração mais fácil, suporta chats em grupo -**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas +**Opção 1: WeCom Bot (Robô)** - Configuração mais fácil, suporta chats em grupo +**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas, somente chat privado +**Opção 3: WeCom AI Bot (Robô Inteligente)** - Bot IA oficial, respostas em streaming, suporta grupo e privado -Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para instruções detalhadas. +Veja o [Guia de Configuração WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) para instruções detalhadas. **Configuração Rápida - WeCom Bot:** @@ -559,6 +560,39 @@ picoclaw gateway > **Nota**: O WeCom App (callbacks de webhook) é servido pelo Gateway compartilhado (padrão 127.0.0.1:18790). Em produção use um proxy reverso HTTPS para expor a porta do Gateway, ou atualize `PICOCLAW_GATEWAY_HOST` para `0.0.0.0` se necessário. +**Configuração Rápida - WeCom AI Bot:** + +**1. Criar um AI Bot** + +* Acesse o Console de Administração WeCom → Gerenciamento de Aplicativos → AI Bot +* Configure a URL de callback: `http://your-server:18791/webhook/wecom-aibot` +* Copie o **Token** e gere o **EncodingAESKey** + +**2. Configurar** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "Olá! Como posso ajudá-lo?" + } + } +} +``` + +**3. Executar** + +```bash +picoclaw gateway +``` + +> **Nota**: O WeCom AI Bot usa protocolo de pull em streaming — sem preocupações com timeout de resposta. Tarefas longas (>5,5 min) alternam automaticamente para entrega via `response_url`. +
## ClawdChat Junte-se a Rede Social de Agentes @@ -789,7 +823,7 @@ O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se com ### Provedores > [!NOTE] -> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas. +> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de áudio de qualquer canal serão automaticamente transcritas no nível do agente. | Provedor | Finalidade | Obter API Key | | --- | --- | --- | diff --git a/README.vi.md b/README.vi.md index e836e30f0..5755896ed 100644 --- a/README.vi.md +++ b/README.vi.md @@ -256,7 +256,7 @@ Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk, LINE hoặc WeCom. | **QQ** | Dễ (AppID + AppSecret) | | **DingTalk** | Trung bình (app credentials) | | **LINE** | Trung bình (credentials + webhook URL) | -| **WeCom** | Trung bình (CorpID + cấu hình webhook) | +| **WeCom AI Bot** | Trung bình (Token + khóa AES) |
Telegram (Khuyên dùng) @@ -457,12 +457,13 @@ picoclaw gateway
WeCom (WeChat Work) -PicoClaw hỗ trợ hai loại tích hợp WeCom: +PicoClaw hỗ trợ ba loại tích hợp WeCom: -**Tùy chọn 1: WeCom Bot (Robot Thông minh)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm -**Tùy chọn 2: WeCom App (Ứng dụng Tự xây dựng)** - Nhiều tính năng hơn, nhắn tin chủ động +**Tùy chọn 1: WeCom Bot (Robot)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm +**Tùy chọn 2: WeCom App (Ứng dụng Tùy chỉnh)** - Nhiều tính năng hơn, nhắn tin chủ động, chỉ chat riêng tư +**Tùy chọn 3: WeCom AI Bot (Bot Thông Minh)** - Bot AI chính thức, phản hồi streaming, hỗ trợ nhóm và riêng tư -Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) để biết hướng dẫn chi tiết. +Xem [Hướng dẫn Cấu hình WeCom AI Bot](docs/channels/wecom/wecom_aibot/README.zh.md) để biết hướng dẫn chi tiết. **Thiết lập Nhanh - WeCom Bot:** @@ -531,6 +532,39 @@ picoclaw gateway > **Lưu ý**: WeCom App callback webhook được phục vụ bởi Gateway HTTP chung (mặc định 127.0.0.1:18790). Sử dụng proxy ngược để cung cấp HTTPS trong môi trường production nếu cần. +**Thiết lập Nhanh - WeCom AI Bot:** + +**1. Tạo AI Bot** + +* Truy cập Bảng điều khiển Quản trị WeCom → Quản lý Ứng dụng → AI Bot +* Cấu hình URL callback: `http://your-server:18791/webhook/wecom-aibot` +* Sao chép **Token** và tạo **EncodingAESKey** + +**2. Cấu hình** + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "Xin chào! Tôi có thể giúp gì cho bạn?" + } + } +} +``` + +**3. Chạy** + +```bash +picoclaw gateway +``` + +> **Lưu ý**: WeCom AI Bot sử dụng giao thức pull streaming — không lo timeout phản hồi. Tác vụ dài (>5,5 phút) tự động chuyển sang gửi qua `response_url`. +
## ClawdChat Tham gia Mạng xã hội Agent @@ -761,7 +795,7 @@ Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và ### Nhà cung cấp (Providers) > [!NOTE] -> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản. +> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn âm thanh từ bất kỳ kênh nào sẽ được tự động chuyển thành văn bản ở cấp độ agent. | Nhà cung cấp | Mục đích | Lấy API Key | | --- | --- | --- | diff --git a/README.zh.md b/README.zh.md index 95984bbdf..bd90173f9 100644 --- a/README.zh.md +++ b/README.zh.md @@ -301,7 +301,7 @@ PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方 | **Slack** | ⭐ 简单 | **Socket Mode** (无需公网 IP),企业级支持 | [查看文档](docs/channels/slack/README.zh.md) | | **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](docs/channels/qq/README.zh.md) | | **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](docs/channels/dingtalk/README.zh.md) | -| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)和自建应用(API) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) | +| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)、自建应用(API)和智能机器人(AI Bot) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) / [AI Bot 文档](docs/channels/wecom/wecom_aibot/README.zh.md) | | **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](docs/channels/feishu/README.zh.md) | | **Line** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](docs/channels/line/README.zh.md) | | **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) | @@ -362,6 +362,20 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work ``` +### 技能来源 (Skill Sources) + +默认情况下,技能会按以下顺序加载: + +1. `~/.picoclaw/workspace/skills`(工作区) +2. `~/.picoclaw/skills`(全局) +3. 
`/skills`(内置) + +在高级/测试场景下,可通过以下环境变量覆盖内置技能目录: + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### 心跳 / 周期性任务 (Heartbeat) PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: @@ -445,7 +459,7 @@ Agent 读取 HEARTBEAT.md ### 提供商 (Providers) > [!NOTE] -> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。 +> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,任意渠道的音频消息都将在 Agent 层面自动转录为文字。 | 提供商 | 用途 | 获取 API Key | | -------------------- | ---------------------------- | -------------------------------------------------------------------- | diff --git a/assets/wechat.png b/assets/wechat.png index 1c0b88295..32998c122 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw-launcher-tui/internal/ui/channel.go b/cmd/picoclaw-launcher-tui/internal/ui/channel.go index ad9171424..49a6ccc5d 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/channel.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/channel.go @@ -10,8 +10,8 @@ import ( picoclawconfig "github.com/sipeed/picoclaw/pkg/config" ) -func (s *appState) channelMenu() tview.Primitive { - items := []MenuItem{ +func (s *appState) buildChannelMenuItems() []MenuItem { + return []MenuItem{ {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }}, channelItem( "Telegram", @@ -86,8 +86,10 @@ func (s *appState) channelMenu() tview.Primitive { func() { s.push("channel-wecomapp", s.wecomAppForm()) }, ), } +} - menu := NewMenu("Channels", items) +func (s *appState) channelMenu() tview.Primitive { + menu := NewMenu("Channels", s.buildChannelMenuItems()) menu.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { if event.Key() == tcell.KeyEsc { s.pop() @@ -103,199 +105,72 @@ func (s *appState) channelMenu() tview.Primitive { } func refreshChannelMenuFromState(menu *Menu, s *appState) { - items := []MenuItem{ - {Label: "Back", Description: "Return to main menu", Action: func() { s.pop() }}, - channelItem( - "Telegram", - "Telegram bot settings", - 
s.config.Channels.Telegram.Enabled, - func() { s.push("channel-telegram", s.telegramForm()) }, - ), - channelItem( - "Discord", - "Discord bot settings", - s.config.Channels.Discord.Enabled, - func() { s.push("channel-discord", s.discordForm()) }, - ), - channelItem( - "QQ", - "QQ bot settings", - s.config.Channels.QQ.Enabled, - func() { s.push("channel-qq", s.qqForm()) }, - ), - channelItem( - "MaixCam", - "MaixCam gateway", - s.config.Channels.MaixCam.Enabled, - func() { s.push("channel-maixcam", s.maixcamForm()) }, - ), - channelItem( - "WhatsApp", - "WhatsApp bridge", - s.config.Channels.WhatsApp.Enabled, - func() { s.push("channel-whatsapp", s.whatsappForm()) }, - ), - channelItem( - "Feishu", - "Feishu bot settings", - s.config.Channels.Feishu.Enabled, - func() { s.push("channel-feishu", s.feishuForm()) }, - ), - channelItem( - "DingTalk", - "DingTalk bot settings", - s.config.Channels.DingTalk.Enabled, - func() { s.push("channel-dingtalk", s.dingtalkForm()) }, - ), - channelItem( - "Slack", - "Slack bot settings", - s.config.Channels.Slack.Enabled, - func() { s.push("channel-slack", s.slackForm()) }, - ), - channelItem( - "LINE", - "LINE bot settings", - s.config.Channels.LINE.Enabled, - func() { s.push("channel-line", s.lineForm()) }, - ), - channelItem( - "OneBot", - "OneBot settings", - s.config.Channels.OneBot.Enabled, - func() { s.push("channel-onebot", s.onebotForm()) }, - ), - channelItem( - "WeCom", - "WeCom bot settings", - s.config.Channels.WeCom.Enabled, - func() { s.push("channel-wecom", s.wecomForm()) }, - ), - channelItem( - "WeCom App", - "WeCom App settings", - s.config.Channels.WeComApp.Enabled, - func() { s.push("channel-wecomapp", s.wecomAppForm()) }, - ), - } - menu.applyItems(items) + menu.applyItems(s.buildChannelMenuItems()) } func (s *appState) telegramForm() tview.Primitive { cfg := &s.config.Channels.Telegram - form := baseChannelForm("Telegram", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - 
refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Telegram", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) form.AddInputField("Proxy", cfg.Proxy, 128, nil, func(text string) { cfg.Proxy = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) discordForm() tview.Primitive { cfg := &s.config.Channels.Discord - form := baseChannelForm("Discord", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Discord", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) form.AddCheckbox("Mention Only", cfg.MentionOnly, func(checked bool) { cfg.MentionOnly = checked }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) qqForm() tview.Primitive { cfg := &s.config.Channels.QQ - form := baseChannelForm("QQ", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("QQ", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) { cfg.AppID = strings.TrimSpace(text) }) form.AddInputField("App Secret", cfg.AppSecret, 128, nil, func(text 
string) { cfg.AppSecret = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) maixcamForm() tview.Primitive { cfg := &s.config.Channels.MaixCam - form := baseChannelForm("MaixCam", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("MaixCam", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Host", cfg.Host, 64, nil, func(text string) { cfg.Host = strings.TrimSpace(text) }) addIntField(form, "Port", cfg.Port, func(value int) { cfg.Port = value }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) whatsappForm() tview.Primitive { cfg := &s.config.Channels.WhatsApp - form := baseChannelForm("WhatsApp", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WhatsApp", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Bridge URL", cfg.BridgeURL, 128, nil, func(text string) { cfg.BridgeURL = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) feishuForm() tview.Primitive { cfg := &s.config.Channels.Feishu - form := baseChannelForm("Feishu", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok 
:= s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Feishu", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("App ID", cfg.AppID, 64, nil, func(text string) { cfg.AppID = strings.TrimSpace(text) }) @@ -308,66 +183,39 @@ func (s *appState) feishuForm() tview.Primitive { form.AddInputField("Verification Token", cfg.VerificationToken, 128, nil, func(text string) { cfg.VerificationToken = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) dingtalkForm() tview.Primitive { cfg := &s.config.Channels.DingTalk - form := baseChannelForm("DingTalk", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("DingTalk", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Client ID", cfg.ClientID, 64, nil, func(text string) { cfg.ClientID = strings.TrimSpace(text) }) form.AddInputField("Client Secret", cfg.ClientSecret, 128, nil, func(text string) { cfg.ClientSecret = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) slackForm() tview.Primitive { cfg := &s.config.Channels.Slack - form := baseChannelForm("Slack", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("Slack", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Bot Token", cfg.BotToken, 128, nil, 
func(text string) { cfg.BotToken = strings.TrimSpace(text) }) form.AddInputField("App Token", cfg.AppToken, 128, nil, func(text string) { cfg.AppToken = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) lineForm() tview.Primitive { cfg := &s.config.Channels.LINE - form := baseChannelForm("LINE", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("LINE", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Channel Secret", cfg.ChannelSecret, 128, nil, func(text string) { cfg.ChannelSecret = strings.TrimSpace(text) }) @@ -381,22 +229,13 @@ func (s *appState) lineForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) onebotForm() tview.Primitive { cfg := &s.config.Channels.OneBot - form := baseChannelForm("OneBot", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("OneBot", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("WS URL", cfg.WSUrl, 128, nil, func(text string) { cfg.WSUrl = strings.TrimSpace(text) }) @@ -418,22 +257,13 @@ func (s *appState) onebotForm() tview.Primitive { cfg.GroupTriggerPrefix = splitCSV(text) }, ) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, 
nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) return wrapWithBack(form, s) } func (s *appState) wecomForm() tview.Primitive { cfg := &s.config.Channels.WeCom - form := baseChannelForm("WeCom", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WeCom", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Token", cfg.Token, 128, nil, func(text string) { cfg.Token = strings.TrimSpace(text) }) @@ -450,9 +280,7 @@ func (s *appState) wecomForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) addIntField( form, "Reply Timeout", @@ -464,14 +292,7 @@ func (s *appState) wecomForm() tview.Primitive { func (s *appState) wecomAppForm() tview.Primitive { cfg := &s.config.Channels.WeComApp - form := baseChannelForm("WeCom App", cfg.Enabled, func(v bool) { - cfg.Enabled = v - s.dirty = true - refreshMainMenuIfPresent(s) - if menu, ok := s.menus["channel"]; ok { - refreshChannelMenuFromState(menu, s) - } - }) + form := baseChannelForm("WeCom App", cfg.Enabled, s.makeChannelOnEnabled(&cfg.Enabled)) form.AddInputField("Corp ID", cfg.CorpID, 64, nil, func(text string) { cfg.CorpID = strings.TrimSpace(text) }) @@ -492,9 +313,7 @@ func (s *appState) wecomAppForm() tview.Primitive { form.AddInputField("Webhook Path", cfg.WebhookPath, 64, nil, func(text string) { cfg.WebhookPath = strings.TrimSpace(text) }) - form.AddInputField("Allow From", strings.Join(cfg.AllowFrom, ","), 128, nil, func(text string) { - cfg.AllowFrom = splitCSV(text) - }) + addAllowFromField(form, &cfg.AllowFrom) 
addIntField( form, "Reply Timeout", @@ -504,6 +323,23 @@ func (s *appState) wecomAppForm() tview.Primitive { return wrapWithBack(form, s) } +func (s *appState) makeChannelOnEnabled(enabledPtr *bool) func(bool) { + return func(v bool) { + *enabledPtr = v + s.dirty = true + refreshMainMenuIfPresent(s) + if menu, ok := s.menus["channel"]; ok { + refreshChannelMenuFromState(menu, s) + } + } +} + +func addAllowFromField(form *tview.Form, allowFrom *picoclawconfig.FlexibleStringSlice) { + form.AddInputField("Allow From", strings.Join(*allowFrom, ","), 128, nil, func(text string) { + *allowFrom = splitCSV(text) + }) +} + func baseChannelForm(title string, enabled bool, onEnabled func(bool)) *tview.Form { form := tview.NewForm() form.SetBorder(true).SetTitle(fmt.Sprintf("Channel: %s", title)) diff --git a/cmd/picoclaw-launcher-tui/internal/ui/style.go b/cmd/picoclaw-launcher-tui/internal/ui/style.go index ff4f8b1a8..68cdd60b9 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/style.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/style.go @@ -5,6 +5,19 @@ import ( "github.com/rivo/tview" ) +const ( + colorBlue = "[#3e5db9]" + colorRed = "[#d54646]" + banner = "\r\n[::b]" + + colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" + + colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" + + colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" + + colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" + + colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" + + colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " + + "[:]" +) + func applyStyles() { tview.Styles.PrimitiveBackgroundColor = tcell.NewRGBColor(12, 13, 22) tview.Styles.ContrastBackgroundColor = tcell.NewRGBColor(34, 19, 53) @@ -24,14 +37,7 @@ func bannerView() *tview.TextView { text.SetDynamicColors(true) 
text.SetTextAlign(tview.AlignCenter) text.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - text.SetText( - "[::b][#84aaff]██████╗ ██╗ ██████╗ ██████╗ ██████╗██╗ █████╗ ██╗ ██╗\n" + - "[#84aaff]██╔══██╗██║██╔════╝██╔═══██╗██╔════╝██║ ██╔══██╗██║ ██║\n" + - "[#84aaff]██████╔╝██║██║ ██║ ██║██║ ██║ ███████║██║ █╗ ██║\n" + - "[#84aaff]██╔═══╝ ██║██║ ██║ ██║██║ ██║ ██╔══██║██║███╗██║\n" + - "[#84aaff]██║ ██║╚██████╗╚██████╔╝╚██████╗███████╗██║ ██║╚███╔███╔╝\n" + - "[#84aaff]╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝", - ) + text.SetText(banner) text.SetBorder(false) return text } diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 747f7d44e..5225340c7 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -36,6 +36,7 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/voice" ) func gatewayCmd(debug bool) error { @@ -134,6 +135,12 @@ func gatewayCmd(debug bool) error { agentLoop.SetChannelManager(channelManager) agentLoop.SetMediaStore(mediaStore) + // Wire up voice transcription if a supported provider is configured. 
+ if transcriber := voice.DetectTranscriber(cfg); transcriber != nil { + agentLoop.SetTranscriber(transcriber) + logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) + } + enabledChannels := channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 6db69c990..d9263462e 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -48,7 +48,21 @@ func NewPicoclawCommand() *cobra.Command { return cmd } +const ( + colorBlue = "\033[1;38;2;62;93;185m" + colorRed = "\033[1;38;2;213;70;70m" + banner = "\r\n" + + colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" + + colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" + + colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" + + colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" + + colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" + + colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " + + "\033[0m\r\n" +) + func main() { + fmt.Printf("%s", banner) cmd := NewPicoclawCommand() if err := cmd.Execute(); err != nil { os.Exit(1) diff --git a/config/config.example.json b/config/config.example.json index d885ef94b..f46f6a670 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -6,7 +6,9 @@ "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, - "max_tool_iterations": 20 + "max_tool_iterations": 20, + "summarize_message_threshold": 20, + "summarize_token_percent": 75 } }, "model_list": [ @@ -49,6 +51,7 @@ "telegram": { "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", + "base_url": "", "proxy": "", "allow_from": [ "YOUR_USER_ID" @@ -58,6 +61,7 @@ "discord": { "enabled": false, "token": 
"YOUR_DISCORD_BOT_TOKEN", + "proxy": "", "allow_from": [], "group_trigger": { "mention_only": false @@ -127,7 +131,7 @@ "reasoning_channel_id": "" }, "wecom": { - "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats", + "_comment": "WeCom Bot - Easier setup, supports group chats", "enabled": false, "token": "YOUR_TOKEN", "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", @@ -138,7 +142,7 @@ "reasoning_channel_id": "" }, "wecom_app": { - "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md", + "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only.", "enabled": false, "corp_id": "YOUR_CORP_ID", "corp_secret": "YOUR_CORP_SECRET", @@ -149,6 +153,16 @@ "allow_from": [], "reply_timeout": 5, "reasoning_channel_id": "" + }, + "wecom_aibot": { + "_comment": "WeCom AI Bot (智能机器人) - Official WeCom AI Bot integration, supports proactive messaging and private chats.", + "enabled": false, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "max_steps": 10, + "welcome_message": "Hello! I'm your AI assistant. 
How can I help you today?", + "reasoning_channel_id": "" } }, "providers": { @@ -233,6 +247,71 @@ "cron": { "exec_timeout_minutes": 5 }, + "mcp": { + "enabled": false, + "servers": { + "context7": { + "enabled": false, + "type": "http", + "url": "https://mcp.context7.com/mcp", + "headers": { + "CONTEXT7_API_KEY": "ctx7sk-xx" + } + }, + "filesystem": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp" + ] + }, + "github": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN" + } + }, + "brave-search": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-brave-search" + ], + "env": { + "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY" + } + }, + "postgres": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-postgres", + "postgresql://user:password@localhost/dbname" + ] + }, + "slack": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-slack" + ], + "env": { + "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN", + "SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID" + } + } + } + }, "exec": { "enable_deny_patterns": false, "custom_deny_patterns": [] diff --git a/docker/Dockerfile.full b/docker/Dockerfile.full new file mode 100644 index 000000000..30e1680d5 --- /dev/null +++ b/docker/Dockerfile.full @@ -0,0 +1,44 @@ +# ============================================================ +# Stage 1: Build the picoclaw binary +# ============================================================ +FROM golang:1.26.0-alpine AS builder + +RUN apk add --no-cache git make + +WORKDIR /src + +# Cache dependencies +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source and build +COPY . . 
+RUN make build + +# ============================================================ +# Stage 2: Node.js-based runtime with full MCP support +# ============================================================ +FROM node:24-alpine3.23 + +# Install runtime dependencies +RUN apk add --no-cache \ + ca-certificates \ + curl \ + git \ + python3 \ + py3-pip + +# Install uv and symlink to system path +RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \ + ln -s /root/.local/bin/uv /usr/local/bin/uv && \ + ln -s /root/.local/bin/uvx /usr/local/bin/uvx && \ + uv --version + +# Copy binary +COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw + +# Create picoclaw home directory +RUN /usr/local/bin/picoclaw onboard + +ENTRYPOINT ["picoclaw"] +CMD ["gateway"] diff --git a/docker/docker-compose.full.yml b/docker/docker-compose.full.yml new file mode 100644 index 000000000..6f34448c4 --- /dev/null +++ b/docker/docker-compose.full.yml @@ -0,0 +1,44 @@ +services: + # ───────────────────────────────────────────── + # PicoClaw Agent (one-shot query) - Full MCP Support + # docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent -m "Hello" + # ───────────────────────────────────────────── + picoclaw-agent: + build: + context: .. + dockerfile: docker/Dockerfile.full + container_name: picoclaw-agent-full + profiles: + - agent + volumes: + - ../config/config.json:/root/.picoclaw/config.json:ro + - picoclaw-workspace:/root/.picoclaw/workspace + - picoclaw-npm-cache:/root/.npm # npm cache for faster MCP server installs + entrypoint: ["picoclaw", "agent"] + stdin_open: true + tty: true + + # ───────────────────────────────────────────── + # PicoClaw Gateway (Long-running Bot) - Full MCP Support + # docker compose -f docker/docker-compose.full.yml --profile gateway up + # ───────────────────────────────────────────── + picoclaw-gateway: + build: + context: .. 
+ dockerfile: docker/Dockerfile.full + container_name: picoclaw-gateway-full + restart: unless-stopped + profiles: + - gateway + volumes: + # Configuration file + - ../config/config.json:/root/.picoclaw/config.json:ro + # Persistent workspace (sessions, memory, logs) + - picoclaw-workspace:/root/.picoclaw/workspace + # NPM cache for faster MCP server installs + - picoclaw-npm-cache:/root/.npm + command: ["gateway"] + +volumes: + picoclaw-workspace: + picoclaw-npm-cache: # Cache npm packages to speed up MCP server installations diff --git a/docs/channels/wecom/wecom_aibot/README.zh.md b/docs/channels/wecom/wecom_aibot/README.zh.md new file mode 100644 index 000000000..d210528af --- /dev/null +++ b/docs/channels/wecom/wecom_aibot/README.zh.md @@ -0,0 +1,116 @@ +# 企业微信智能机器人 (AI Bot) + +企业微信智能机器人(AI Bot)是企业微信官方提供的 AI 对话接入方式,支持私聊与群聊,内置流式响应协议,并支持超时后通过 `response_url` 主动推送最终回复。 + +## 与其他 WeCom 通道的对比 + +| 特性 | WeCom Bot | WeCom App | **WeCom AI Bot** | +|------|-----------|-----------|-----------------| +| 私聊 | ✅ | ✅ | ✅ | +| 群聊 | ✅ | ❌ | ✅ | +| 流式输出 | ❌ | ❌ | ✅ | +| 超时主动推送 | ❌ | ✅ | ✅ | +| 配置复杂度 | 低 | 高 | 中 | + +## 配置 + +```json +{ + "channels": { + "wecom_aibot": { + "enabled": true, + "token": "YOUR_TOKEN", + "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY", + "webhook_path": "/webhook/wecom-aibot", + "allow_from": [], + "welcome_message": "你好!有什么可以帮助你的吗?", + "max_steps": 10 + } + } +} +``` + +| 字段 | 类型 | 必填 | 描述 | +| ---------------- | ------ | ---- | -------------------------------------------------- | +| token | string | 是 | 回调验证令牌,在 AI Bot 管理页面配置 | +| encoding_aes_key | string | 是 | 43 字符 AES 密钥,在 AI Bot 管理页面随机生成 | +| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-aibot) | +| allow_from | array | 否 | 用户 ID 白名单,空数组表示允许所有用户 | +| welcome_message | string | 否 | 用户进入聊天时发送的欢迎语,留空则不发送 | +| reply_timeout | int | 否 | 回复超时时间(秒,默认:5) | +| max_steps | int | 否 | Agent 最大执行步骤数(默认:10) | + +## 设置流程 + +1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin) +2. 
进入"应用管理" → "智能机器人",创建或选择一个 AI Bot +3. 在 AI Bot 配置页面,填写"消息接收"信息: + - **URL**:`http://your-server:18791/webhook/wecom-aibot` + - **Token**:随机生成或自定义 + - **EncodingAESKey**:点击"随机生成",得到 43 字符密钥 +4. 将 Token 和 EncodingAESKey 填入 PicoClaw 配置文件,启动服务后回到管理后台保存(企业微信会发送验证请求) + +> [!TIP] +> 服务器需要能被企业微信服务器访问。如在内网/本地开发,可使用 [ngrok](https://ngrok.com) 或 frp 做内网穿透。 + +## 流式响应协议 + +WeCom AI Bot 使用"流式拉取"协议,区别于普通 Webhook 的一次性回复: + +``` +用户发消息 + │ + ▼ +PicoClaw 立即返回 {finish: false}(Agent 开始处理) + │ + ▼ +企业微信每隔约 1 秒拉取一次 {msgtype: "stream", stream: {id: "..."}} + │ + ├─ Agent 未完成 → 返回 {finish: false}(继续等待) + │ + └─ Agent 完成 → 返回 {finish: true, content: "回答内容"} +``` + +**超时处理**(任务超过 30 秒): + +若 Agent 处理时间超过约 30 秒(企业微信最大轮询窗口为 6 分钟),PicoClaw 会: + +1. 立即关闭流,向用户显示「⏳ 正在处理中,请稍候,结果将稍后发送。」 +2. Agent 继续在后台运行 +3. Agent 完成后,通过消息中携带的 `response_url` 将最终回复主动推送给用户 + +> `response_url` 由企业微信颁发,有效期 1 小时,只可使用一次,无需加密,直接 POST markdown 消息体即可。 + +## 欢迎语 + +配置 `welcome_message` 后,当用户打开与 AI Bot 的聊天窗口时(`enter_chat` 事件),PicoClaw 会自动回复该欢迎语。留空则静默忽略。 + +```json +"welcome_message": "你好!我是 PicoClaw AI 助手,有什么可以帮你?" 
+``` + +## 常见问题 + +### 回调 URL 验证失败 + +- 确认服务器防火墙已开放对应端口(默认 18791) +- 确认 `token` 与 `encoding_aes_key` 填写正确 +- 检查 PicoClaw 日志是否收到了来自企业微信的 GET 请求 + +### 消息没有回复 + +- 检查 `allow_from` 是否意外限制了发送者 +- 查看日志中是否出现 `context canceled` 或 Agent 错误 +- 确认 Agent 配置(`model_name` 等)正确 + +### 超长任务没有收到最终推送 + +- 确认消息回调中携带了 `response_url`(仅企业微信新版 AI Bot 支持) +- 确认服务器能主动访问外网(需向 `response_url` POST 请求) +- 查看日志关键词 `response_url mode` 和 `Sending reply via response_url` + +## 参考文档 + +- [企业微信 AI Bot 接入文档](https://developer.work.weixin.qq.com/document/path/100719) +- [流式响应协议说明](https://developer.work.weixin.qq.com/document/path/100719) +- [response_url 主动回复](https://developer.work.weixin.qq.com/document/path/101138) diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index 8aba1aa91..6204fb0c8 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -8,6 +8,7 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`. { "tools": { "web": { ... }, + "mcp": { ... }, "exec": { ... }, "cron": { ... }, "skills": { ... } @@ -21,35 +22,35 @@ Web tools are used for web search and fetching. 
### Brave -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | false | Enable Brave search | -| `api_key` | string | - | Brave Search API key | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ------ | ------- | ------------------------- | +| `enabled` | bool | false | Enable Brave search | +| `api_key` | string | - | Brave Search API key | +| `max_results` | int | 5 | Maximum number of results | ### DuckDuckGo -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | true | Enable DuckDuckGo search | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ---- | ------- | ------------------------- | +| `enabled` | bool | true | Enable DuckDuckGo search | +| `max_results` | int | 5 | Maximum number of results | ### Perplexity -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | false | Enable Perplexity search | -| `api_key` | string | - | Perplexity API key | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ------ | ------- | ------------------------- | +| `enabled` | bool | false | Enable Perplexity search | +| `api_key` | string | - | Perplexity API key | +| `max_results` | int | 5 | Maximum number of results | ## Exec Tool The exec tool is used to execute shell commands. 
-| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | -| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | +| Config | Type | Default | Description | +| ---------------------- | ----- | ------- | ------------------------------------------ | +| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | +| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | ### Functionality @@ -80,10 +81,7 @@ By default, PicoClaw blocks the following dangerous commands: "tools": { "exec": { "enable_deny_patterns": true, - "custom_deny_patterns": [ - "\\brm\\s+-r\\b", - "\\bkillall\\s+python" - ] + "custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"] } } } @@ -93,9 +91,84 @@ By default, PicoClaw blocks the following dangerous commands: The cron tool is used for scheduling periodic tasks. -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | +| Config | Type | Default | Description | +| ---------------------- | ---- | ------- | ---------------------------------------------- | +| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | + +## MCP Tool + +The MCP tool enables integration with external Model Context Protocol servers. 
+ +### Global Config + +| Config | Type | Default | Description | +| --------- | ------ | ------- | ----------------------------------- | +| `enabled` | bool | false | Enable MCP integration globally | +| `servers` | object | `{}` | Map of server name to server config | + +### Per-Server Config + +| Config | Type | Required | Description | +| ---------- | ------ | -------- | ------------------------------------------ | +| `enabled` | bool | yes | Enable this MCP server | +| `type` | string | no | Transport type: `stdio`, `sse`, `http` | +| `command` | string | stdio | Executable command for stdio transport | +| `args` | array | no | Command arguments for stdio transport | +| `env` | object | no | Environment variables for stdio process | +| `env_file` | string | no | Path to environment file for stdio process | +| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport | +| `headers` | object | no | HTTP headers for `sse`/`http` transport | + +### Transport Behavior + +- If `type` is omitted, transport is auto-detected: + - `url` is set → `sse` + - `command` is set → `stdio` +- `http` and `sse` both use `url` + optional `headers`. +- `env` and `env_file` are only applied to `stdio` servers. 
+ +### Configuration Examples + +#### 1) Stdio MCP server + +```json +{ + "tools": { + "mcp": { + "enabled": true, + "servers": { + "filesystem": { + "enabled": true, + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } + } + } +} +``` + +#### 2) Remote SSE/HTTP MCP server + +```json +{ + "tools": { + "mcp": { + "enabled": true, + "servers": { + "remote-mcp": { + "enabled": true, + "type": "sse", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "Bearer YOUR_TOKEN" + } + } + } + } + } +} +``` ## Skills Tool @@ -103,13 +176,13 @@ The skills tool configures skill discovery and installation via registries like ### Registries -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | -| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | -| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | -| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | -| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | +| Config | Type | Default | Description | +| ---------------------------------- | ------ | -------------------- | ----------------------- | +| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | +| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | +| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | +| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | ### Configuration Example @@ -136,8 +209,10 @@ The skills tool configures skill discovery and installation via registries like All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_
_`: For example: + - `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true` - `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false` - `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10` +- `PICOCLAW_TOOLS_MCP_ENABLED=true` -Note: Array-type environment variables are not currently supported and must be set via the config file. +Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than environment variables. diff --git a/docs/wecom-app-configuration.md b/docs/wecom-app-configuration.md deleted file mode 100644 index 3c720ecd1..000000000 --- a/docs/wecom-app-configuration.md +++ /dev/null @@ -1,115 +0,0 @@ -# 企业微信自建应用 (WeCom App) 配置指南 - -本文档介绍如何在 PicoClaw 中配置企业微信自建应用 (wecom-app) 通道。 - -## 功能特性 - -| 功能 | 支持状态 | -|------|---------| -| 被动接收消息 | ✅ | -| 主动发送消息 | ✅ | -| 私聊 | ✅ | -| 群聊 | ❌ | - -## 配置步骤 - -### 1. 企业微信后台配置 - -1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin) -2. 进入"应用管理" → 选择自建应用 -3. 记录以下信息: - - **AgentId**: 应用详情页显示 - - **Secret**: 点击"查看"获取 -4. 进入"我的企业"页面,记录 **企业ID** (CorpID) - -### 2. 接收消息配置 - -1. 在应用详情页,点击"接收消息"的"设置API接收" -2. 填写以下信息: - - **URL**: `http://your-server:18790/webhook/wecom-app` - - **Token**: 随机生成或自定义(用于签名验证) - - **EncodingAESKey**: 点击"随机生成"生成43字符的密钥 -3. 点击"保存"时,企业微信会发送验证请求 - -### 3. PicoClaw 配置 - -在 `config.json` 中添加以下配置: - -```json -{ - "channels": { - "wecom_app": { - "enabled": true, - "corp_id": "wwxxxxxxxxxxxxxxxx", // 企业ID - "corp_secret": "xxxxxxxxxxxxxxxxxxxxxxxx", // 应用Secret - "agent_id": 1000002, // 应用AgentId - "token": "your_token", // 接收消息配置的Token - "encoding_aes_key": "your_encoding_aes_key", // 接收消息配置的EncodingAESKey - "webhook_path": "/webhook/wecom-app", - "allow_from": [], - "reply_timeout": 5 - } - } -} -``` - -## 常见问题 - -### 1. 回调URL验证失败 - -**症状**: 企业微信保存API接收消息时提示验证失败 - -**检查项**: -- 确认服务器防火墙已开放 Gateway 端口(默认 18790) -- 确认 `corp_id`、`token`、`encoding_aes_key` 配置正确 -- 查看 PicoClaw 日志是否有请求到达 - -### 2. 
中文消息解密失败 - -**症状**: 发送中文消息时出现 `invalid padding size` 错误 - -**原因**: 企业微信使用非标准的 PKCS7 填充(32字节块大小) - -**解决**: 确保使用最新版本的 PicoClaw,已修复此问题。 - -### 3. 端口冲突 - -**症状**: 启动时提示端口已被占用 - -**解决**: 修改 `gateway.port` 为其他端口(所有 Webhook 渠道共享同一个 Gateway HTTP 服务器) - -## 技术细节 - -### 加密算法 - -- **算法**: AES-256-CBC -- **密钥**: EncodingAESKey Base64解码后的32字节 -- **IV**: AESKey的前16字节 -- **填充**: PKCS7(块大小为32字节,非标准16字节) -- **消息格式**: XML - -### 消息结构 - -解密后的消息格式: -``` -random(16B) + msg_len(4B) + msg + receiveid -``` - -其中 `receiveid` 对于自建应用是 `corp_id`。 - -## 调试 - -启用调试模式查看详细日志: - -```bash -picoclaw gateway --debug -``` - -关键日志标识: -- `wecom_app`: WeCom App 通道相关日志 -- `wecom_common`: 加密解密相关日志 - -## 参考文档 - -- [企业微信官方文档 - 接收消息](https://developer.work.weixin.qq.com/document/path/96211) -- [企业微信官方加解密库](https://github.com/sbzhu/weworkapi_golang) diff --git a/go.mod b/go.mod index 7892cade6..c1172937c 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,16 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 + github.com/modelcontextprotocol/go-sdk v1.3.0 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 + github.com/rivo/tview v0.42.0 github.com/slack-go/slack v0.17.3 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 @@ -35,6 +38,7 @@ require ( github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect github.com/gdamore/tcell/v2 v2.13.8 // indirect + github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -43,7 +47,6 @@ require ( github.com/petermattis/goid 
v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/tview v0.42.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect github.com/spf13/pflag v1.0.10 // indirect @@ -81,6 +84,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect diff --git a/go.sum b/go.sum index d1ee1d629..060594d06 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncV github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -96,6 +98,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= 
+github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= +github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -130,6 +134,8 @@ github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= +github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs= +github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= @@ -212,6 +218,8 @@ github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTd github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod 
h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 6fccbaf53..3aa903b3f 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -34,6 +34,11 @@ type ContextBuilder struct { // created (didn't exist at cache time, now exist) or deleted (existed at // cache time, now gone) — both of which should trigger a cache rebuild. existedAtCache map[string]bool + + // skillFilesAtCache snapshots the skill tree file set and mtimes at cache + // build time. This catches nested file creations/deletions/mtime changes + // that may not update the top-level skill root directory mtime. + skillFilesAtCache map[string]time.Time } func getGlobalConfigDir() string { @@ -47,8 +52,11 @@ func getGlobalConfigDir() string { func NewContextBuilder(workspace string) *ContextBuilder { // builtin skills: skills directory in current project // Use the skills/ directory under the current working directory - wd, _ := os.Getwd() - builtinSkillsDir := filepath.Join(wd, "skills") + builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS")) + if builtinSkillsDir == "" { + wd, _ := os.Getwd() + builtinSkillsDir = filepath.Join(wd, "skills") + } globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills") return &ContextBuilder{ @@ -148,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string { cb.cachedSystemPrompt = prompt cb.cachedAt = baseline.maxMtime cb.existedAtCache = baseline.existed + cb.skillFilesAtCache = baseline.skillFiles logger.DebugCF("agent", "System prompt cached", map[string]any{ @@ -167,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() { cb.cachedSystemPrompt = "" cb.cachedAt = time.Time{} cb.existedAtCache = nil + cb.skillFilesAtCache = nil logger.DebugCF("agent", "System prompt cache invalidated", nil) } -// sourcePaths returns the workspace source file paths tracked for cache -// 
invalidation (bootstrap files + memory). The skills directory is handled -// separately in sourceFilesChangedLocked because it requires both directory- -// level and recursive file-level mtime checks. +// sourcePaths returns non-skill workspace source files tracked for cache +// invalidation (bootstrap files + memory). Skill roots are handled separately +// because they require both directory-level and recursive file-level checks. func (cb *ContextBuilder) sourcePaths() []string { return []string{ filepath.Join(cb.workspace, "AGENTS.md"), @@ -185,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string { } } +// skillRoots returns all skill root directories that can affect +// BuildSkillsSummary output (workspace/global/builtin). +func (cb *ContextBuilder) skillRoots() []string { + if cb.skillsLoader == nil { + return []string{filepath.Join(cb.workspace, "skills")} + } + + roots := cb.skillsLoader.SkillRoots() + if len(roots) == 0 { + return []string{filepath.Join(cb.workspace, "skills")} + } + return roots +} + // cacheBaseline holds the file existence snapshot and the latest observed // mtime across all tracked paths. Used as the cache reference point. type cacheBaseline struct { - existed map[string]bool - maxMtime time.Time + existed map[string]bool + skillFiles map[string]time.Time + maxMtime time.Time } // buildCacheBaseline records which tracked paths currently exist and computes // the latest mtime across all tracked files + skills directory contents. // Called under write lock when the cache is built. func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { - skillsDir := filepath.Join(cb.workspace, "skills") + skillRoots := cb.skillRoots() - // All paths whose existence we track: source files + skills dir. - allPaths := append(cb.sourcePaths(), skillsDir) + // All paths whose existence we track: source files + all skill roots. + allPaths := append(cb.sourcePaths(), skillRoots...) 
existed := make(map[string]bool, len(allPaths)) + skillFiles := make(map[string]time.Time) var maxMtime time.Time for _, p := range allPaths { @@ -212,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { } } - // Walk skills files to capture their mtimes too. - // Use os.Stat (not d.Info) to match the stat method used in - // fileChangedSince / skillFilesModifiedSince for consistency. - _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) { - maxMtime = info.ModTime() + // Walk all skill roots recursively to snapshot skill files and mtimes. + // Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks. + for _, root := range skillRoots { + _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr == nil && !d.IsDir() { + if info, err := os.Stat(path); err == nil { + skillFiles[path] = info.ModTime() + if info.ModTime().After(maxMtime) { + maxMtime = info.ModTime() + } + } } - } - return nil - }) + return nil + }) + } // If no tracked files exist yet (empty workspace), maxMtime is zero. // Use a very old non-zero time so that: @@ -234,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { maxMtime = time.Unix(1, 0) } - return cacheBaseline{existed: existed, maxMtime: maxMtime} + return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime} } // sourceFilesChangedLocked checks whether any workspace source file has been @@ -254,21 +283,17 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool { return true } - // --- Skills directory (handled separately from sourcePaths) --- + // --- Skill roots (workspace/global/builtin) --- // - // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files. 
- skillsDir := filepath.Join(cb.workspace, "skills") - if cb.fileChangedSince(skillsDir) { - return true + // For each root: + // 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince. + // 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot. + for _, root := range cb.skillRoots() { + if cb.fileChangedSince(root) { + return true + } } - - // 2. Structural changes (add/remove entries inside the dir) are reflected - // in the directory's own mtime, which fileChangedSince already checks. - // - // 3. Content-only edits to files inside skills/ do NOT update the parent - // directory mtime on most filesystems, so we recursively walk to check - // individual file mtimes at any nesting depth. - if skillFilesModifiedSince(skillsDir, cb.cachedAt) { + if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) { return true } @@ -309,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool { // if the callback returned nil when its err parameter is non-nil. var errWalkStop = errors.New("walk stop") -// skillFilesModifiedSince recursively walks the skills directory and checks -// whether any file was modified after t. This catches content-only edits at -// any nesting depth (e.g. skills/name/docs/extra.md) that don't update -// parent directory mtimes. -func skillFilesModifiedSince(skillsDir string, t time.Time) bool { - changed := false - err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) { - changed = true - return errWalkStop // stop walking - } - } - return nil - }) - // errWalkStop is expected (early exit on first changed file). - // os.IsNotExist means the skills dir doesn't exist yet — not an error. - // Any other error is unexpected and worth logging. 
- if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { - logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) +// skillFilesChangedSince compares the current recursive skill file tree +// against the cache-time snapshot. Any create/delete/mtime drift invalidates +// the cache. +func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool { + // Defensive: if the snapshot was never initialized, force rebuild. + if filesAtCache == nil { + return true } - return changed + + // Check cached files still exist and keep the same mtime. + for path, cachedMtime := range filesAtCache { + info, err := os.Stat(path) + if err != nil { + // A previously tracked file disappeared (or became inaccessible): + // either way, cached skill summary may now be stale. + return true + } + if !info.ModTime().Equal(cachedMtime) { + return true + } + } + + // Check no new files appeared under any skill root. + changed := false + for _, root := range skillRoots { + if strings.TrimSpace(root) == "" { + continue + } + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + // Treat unexpected walk errors as changed to avoid stale cache. 
+ if !os.IsNotExist(walkErr) { + changed = true + return errWalkStop + } + return nil + } + if d.IsDir() { + return nil + } + if _, ok := filesAtCache[path]; !ok { + changed = true + return errWalkStop + } + return nil + }) + + if changed { + return true + } + if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { + logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) + return true + } + } + + return false } func (cb *ContextBuilder) LoadBootstrapFiles() string { @@ -466,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages( // Add current user message if strings.TrimSpace(currentMessage) != "" { - messages = append(messages, providers.Message{ + msg := providers.Message{ Role: "user", Content: currentMessage, - }) + } + if len(media) > 0 { + msg.Media = media + } + messages = append(messages, msg) } return messages diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go index 0905e8a46..707510820 100644 --- a/pkg/agent/context_cache_test.go +++ b/pkg/agent/context_cache_test.go @@ -383,6 +383,162 @@ Updated content.` } } +// TestGlobalSkillFileContentChange verifies that modifying a global skill +// (~/.picoclaw/skills) invalidates the cached system prompt. 
+func TestGlobalSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: global-skill +description: global-v1 +--- +# Global Skill v1` + if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "global-v1") { + t.Fatal("expected initial prompt to contain global skill description") + } + + v2 := `--- +name: global-skill +description: global-v2 +--- +# Global Skill v2` + if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(globalSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect global skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "global-v2") { + t.Error("rebuilt prompt should contain updated global skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when global skill file content changes") + } +} + +// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill +// invalidates the cached system prompt. 
+func TestBuiltinSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + builtinRoot := t.TempDir() + t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot) + + builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: builtin-skill +description: builtin-v1 +--- +# Builtin Skill v1` + if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "builtin-v1") { + t.Fatal("expected initial prompt to contain builtin skill description") + } + + v2 := `--- +name: builtin-skill +description: builtin-v2 +--- +# Builtin Skill v2` + if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(builtinSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "builtin-v2") { + t.Error("rebuilt prompt should contain updated builtin skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when builtin skill file content changes") + } +} + +// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill +// file invalidates the cached system prompt. 
+func TestSkillFileDeletionInvalidatesCache(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "skills/delete-me/SKILL.md": `--- +name: delete-me +description: delete-me-v1 +--- +# Delete Me`, + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "delete-me-v1") { + t.Fatal("expected initial prompt to contain skill description") + } + + skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md") + if err := os.Remove(skillPath); err != nil { + t.Fatal(err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect deleted skill file") + } + + sp2 := cb.BuildSystemPromptWithCache() + if strings.Contains(sp2, "delete-me-v1") { + t.Error("rebuilt prompt should not contain deleted skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when skill file is deleted") + } +} + // TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines // can safely call BuildSystemPromptWithCache concurrently without producing // empty results, panics, or data races. diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index ed438059f..ed25f537f 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -18,22 +18,24 @@ import ( // AgentInstance represents a fully configured agent with its own workspace, // session manager, context builder, and tool registry. 
type AgentInstance struct { - ID string - Name string - Model string - Fallbacks []string - Workspace string - MaxIterations int - MaxTokens int - Temperature float64 - ContextWindow int - Provider providers.LLMProvider - Sessions *session.SessionManager - ContextBuilder *ContextBuilder - Tools *tools.ToolRegistry - Subagents *config.SubagentsConfig - SkillsFilter []string - Candidates []providers.FallbackCandidate + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + MaxTokens int + Temperature float64 + ContextWindow int + SummarizeMessageThreshold int + SummarizeTokenPercent int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate } // NewAgentInstance creates an agent instance from config. @@ -101,6 +103,16 @@ func NewAgentInstance( temperature = *defaults.Temperature } + summarizeMessageThreshold := defaults.SummarizeMessageThreshold + if summarizeMessageThreshold == 0 { + summarizeMessageThreshold = 20 + } + + summarizeTokenPercent := defaults.SummarizeTokenPercent + if summarizeTokenPercent == 0 { + summarizeTokenPercent = 75 + } + // Resolve fallback candidates modelCfg := providers.ModelConfig{ Primary: model, @@ -149,22 +161,24 @@ func NewAgentInstance( candidates := providers.ResolveCandidatesWithLookup(modelCfg, defaults.Provider, resolveFromModelList) return &AgentInstance{ - ID: agentID, - Name: agentName, - Model: model, - Fallbacks: fallbacks, - Workspace: workspace, - MaxIterations: maxIter, - MaxTokens: maxTokens, - Temperature: temperature, - ContextWindow: maxTokens, - Provider: provider, - Sessions: sessionsManager, - ContextBuilder: contextBuilder, - Tools: toolsRegistry, - Subagents: subagents, - SkillsFilter: skillsFilter, - Candidates: candidates, + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: 
fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + MaxTokens: maxTokens, + Temperature: temperature, + ContextWindow: maxTokens, + SummarizeMessageThreshold: summarizeMessageThreshold, + SummarizeTokenPercent: summarizeTokenPercent, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, } } diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index af1bf2ead..4f41ecd1c 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -95,75 +95,68 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { } func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "step-3.5-flash", - }, + tests := []struct { + name string + aliasName string + modelName string + apiBase string + wantProvider string + wantModel string + }{ + { + name: "alias with provider prefix", + aliasName: "step-3.5-flash", + modelName: "openrouter/stepfun/step-3.5-flash:free", + apiBase: "https://openrouter.ai/api/v1", + wantProvider: "openrouter", + wantModel: "stepfun/step-3.5-flash:free", }, - ModelList: []config.ModelConfig{ - { - ModelName: "step-3.5-flash", - Model: "openrouter/stepfun/step-3.5-flash:free", - APIBase: "https://openrouter.ai/api/v1", - }, + { + name: "alias without provider prefix", + aliasName: "glm-5", + modelName: "glm-5", + apiBase: "https://api.z.ai/api/coding/paas/v4", + wantProvider: "openai", + wantModel: "glm-5", }, } - provider := &mockProvider{} - agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) - if len(agent.Candidates) != 1 { - t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) - } - if agent.Candidates[0].Provider != "openrouter" { - t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openrouter") - } - if agent.Candidates[0].Model != "stepfun/step-3.5-flash:free" { - t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "stepfun/step-3.5-flash:free") - } -} - -func TestNewAgentInstance_ResolveCandidatesFromModelListAliasWithoutProtocol(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-instance-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "glm-5", - }, - }, - ModelList: []config.ModelConfig{ - { - ModelName: "glm-5", - Model: "glm-5", - APIBase: "https://api.z.ai/api/coding/paas/v4", - }, - }, - } - - provider := &mockProvider{} - agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) - - if len(agent.Candidates) != 1 { - t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) - } - if agent.Candidates[0].Provider != "openai" { - t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, "openai") - } - if agent.Candidates[0].Model != "glm-5" { - t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, "glm-5") + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: tt.aliasName, + }, + }, + ModelList: []config.ModelConfig{ + { + ModelName: tt.aliasName, + Model: tt.modelName, + APIBase: tt.apiBase, + }, + }, + } + + provider := &mockProvider{} + agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, provider) + + if len(agent.Candidates) != 1 
{ + t.Fatalf("len(Candidates) = %d, want 1", len(agent.Candidates)) + } + if agent.Candidates[0].Provider != tt.wantProvider { + t.Fatalf("candidate provider = %q, want %q", agent.Candidates[0].Provider, tt.wantProvider) + } + if agent.Candidates[0].Model != tt.wantModel { + t.Fatalf("candidate model = %q, want %q", agent.Candidates[0].Model, tt.wantModel) + } + }) } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 00b0f096a..7ce2a37a6 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "path/filepath" + "regexp" "strings" "sync" "sync/atomic" @@ -23,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/mcp" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" @@ -30,6 +32,7 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" ) type AgentLoop struct { @@ -42,23 +45,29 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore + transcriber voice.Transcriber } // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string 
// Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." -func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { +func NewAgentLoop( + cfg *config.Config, + msgBus *bus.MessageBus, + provider providers.LLMProvider, +) *AgentLoop { registry := NewAgentRegistry(cfg, provider) // Register shared tools to all agents @@ -112,6 +121,11 @@ func registerSharedTools( PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, + GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, + GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, + GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, + GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, Proxy: cfg.Tools.Web.Proxy, }) if err != nil { @@ -170,6 +184,72 @@ func registerSharedTools( func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + // Initialize MCP servers for all agents + if al.cfg.Tools.MCP.Enabled { + mcpManager := mcp.NewManager() + // Ensure MCP connections are cleaned up on exit, regardless of initialization success + // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails + defer func() { + if err := mcpManager.Close(); err != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": err.Error(), + }) + } + }() + + defaultAgent := 
al.registry.GetDefaultAgent() + var workspacePath string + if defaultAgent != nil && defaultAgent.Workspace != "" { + workspacePath = defaultAgent.Workspace + } else { + workspacePath = al.cfg.WorkspacePath() + } + + if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil { + logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", + map[string]any{ + "error": err.Error(), + }) + } else { + // Register MCP tools for all agents + servers := mcpManager.GetServers() + uniqueTools := 0 + totalRegistrations := 0 + agentIDs := al.registry.ListAgentIDs() + agentCount := len(agentIDs) + + for serverName, conn := range servers { + uniqueTools += len(conn.Tools) + for _, tool := range conn.Tools { + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok { + continue + } + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) + agent.Tools.Register(mcpTool) + totalRegistrations++ + logger.DebugCF("agent", "Registered MCP tool", + map[string]any{ + "agent_id": agentID, + "server": serverName, + "tool": tool.Name, + "name": mcpTool.Name(), + }) + } + } + } + logger.InfoCF("agent", "MCP tools registered successfully", + map[string]any{ + "server_count": len(servers), + "unique_tools": uniqueTools, + "total_registrations": totalRegistrations, + "agent_count": agentCount, + }) + } + } + for al.running.Load() { select { case <-ctx.Done(): @@ -262,6 +342,64 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s } +// SetTranscriber injects a voice transcriber for agent-level audio transcription. +func (al *AgentLoop) SetTranscriber(t voice.Transcriber) { + al.transcriber = t +} + +var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) + +// transcribeAudioInMessage resolves audio media refs, transcribes them, and +// replaces audio annotations in msg.Content with the transcribed text. 
+func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) bus.InboundMessage { + if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 { + return msg + } + + // Transcribe each audio media ref in order. + var transcriptions []string + for _, ref := range msg.Media { + path, meta, err := al.mediaStore.ResolveWithMeta(ref) + if err != nil { + logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err}) + continue + } + if !utils.IsAudioFile(meta.Filename, meta.ContentType) { + continue + } + result, err := al.transcriber.Transcribe(ctx, path) + if err != nil { + logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err}) + transcriptions = append(transcriptions, "") + continue + } + transcriptions = append(transcriptions, result.Text) + } + + if len(transcriptions) == 0 { + return msg + } + + // Replace audio annotations sequentially with transcriptions. + idx := 0 + newContent := audioAnnotationRe.ReplaceAllStringFunc(msg.Content, func(match string) string { + if idx >= len(transcriptions) { + return match + } + text := transcriptions[idx] + idx++ + return "[voice: " + text + "]" + }) + + // Append any remaining transcriptions not matched by an annotation. + for ; idx < len(transcriptions); idx++ { + newContent += "\n[voice: " + transcriptions[idx] + "]" + } + + msg.Content = newContent + return msg +} + // inferMediaType determines the media type ("image", "audio", "video", "file") // from a filename and MIME content type. 
func inferMediaType(filename, contentType string) string { @@ -310,7 +448,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error { return al.state.SetLastChatID(chatID) } -func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { +func (al *AgentLoop) ProcessDirect( + ctx context.Context, + content, sessionKey string, +) (string, error) { return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } @@ -331,7 +472,10 @@ func (al *AgentLoop) ProcessDirectWithChannel( // ProcessHeartbeat processes a heartbeat request without session history. // Each heartbeat is independent and doesn't accumulate context. -func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { +func (al *AgentLoop) ProcessHeartbeat( + ctx context.Context, + content, channel, chatID string, +) (string, error) { agent := al.registry.GetDefaultAgent() if agent == nil { return "", fmt.Errorf("no default agent for heartbeat") @@ -356,13 +500,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } else { logContent = utils.Truncate(msg.Content, 80) } - logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), + logger.InfoCF( + "agent", + fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), map[string]any{ "channel": msg.Channel, "chat_id": msg.ChatID, "sender_id": msg.SenderID, "session_key": msg.SessionKey, - }) + }, + ) + + msg = al.transcribeAudioInMessage(ctx, msg) // Route system messages to processSystemMessage if msg.Channel == "system" { @@ -417,15 +566,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, + Media: msg.Media, DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, }) } -func (al *AgentLoop) processSystemMessage(ctx 
context.Context, msg bus.InboundMessage) (string, error) { +func (al *AgentLoop) processSystemMessage( + ctx context.Context, + msg bus.InboundMessage, +) (string, error) { if msg.Channel != "system" { - return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) + return "", fmt.Errorf( + "processSystemMessage called with non-system message channel: %s", + msg.Channel, + ) } logger.InfoCF("agent", "Processing system message", @@ -483,14 +639,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe } // runAgentLoop is the core message processing logic. -func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop( + ctx context.Context, + agent *AgentInstance, + opts processOptions, +) (string, error) { // 0. Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()}) + logger.WarnCF( + "agent", + "Failed to record last channel", + map[string]any{"error": err.Error()}, + ) } } } @@ -509,11 +673,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt history, summary, opts.UserMessage, - nil, + opts.Media, opts.Channel, opts.ChatID, ) + // Resolve media:// refs to base64 data URLs (streaming) + maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize() + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + // 3. 
Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) @@ -572,7 +740,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string return "" } -func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) { +func (al *AgentLoop) handleReasoning( + ctx context.Context, + reasoningContent, channelName, channelID string, +) { if reasoningContent == "" || channelName == "" || channelID == "" { return } @@ -665,22 +836,33 @@ func (al *AgentLoop) runLLMIteration( callLLM := func() (*providers.LLMResponse, error) { if len(agent.Candidates) > 1 && al.fallback != nil { - fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + fbResult, fbErr := al.fallback.Execute( + ctx, + agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, - }) + return agent.Provider.Chat( + ctx, + messages, + providerToolDefs, + model, + map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, + }, + ) }, ) if fbErr != nil { return nil, fbErr } if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { - logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", - fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), - map[string]any{"agent_id": agent.ID, "iteration": iteration}) + logger.InfoCF( + "agent", + fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]any{"agent_id": agent.ID, "iteration": iteration}, + ) } return fbResult.Response, nil } @@ -731,10 +913,14 @@ func (al *AgentLoop) runLLMIteration( } if isContextError && retry < maxRetries { - 
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{ - "error": err.Error(), - "retry": retry, - }) + logger.WarnCF( + "agent", + "Context window error detected, attempting compression", + map[string]any{ + "error": err.Error(), + "retry": retry, + }, + ) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ @@ -766,7 +952,12 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } - go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) + go al.handleReasoning( + ctx, + response.Reasoning, + opts.Channel, + al.targetReasoningChannelID(opts.Channel), + ) logger.DebugCF("agent", "LLM response", map[string]any{ @@ -841,62 +1032,76 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - // Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "agent_id": agent.ID, - "tool": tc.Name, - "iteration": iteration, - }) + // Execute tool calls in parallel + type indexedAgentResult struct { + result *tools.ToolResult + tc providers.ToolCall + } - // Create async callback for tools that implement AsyncTool - // NOTE: Following openclaw's design, async tools do NOT send results directly to users. - // Instead, they notify the agent via PublishInbound, and the agent decides - // whether to forward the result to the user (in processSystemMessage). 
- asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { - // Log the async completion but don't send directly to user - // The agent will handle user notification via processSystemMessage - if !result.Silent && result.ForUser != "" { - logger.InfoCF("agent", "Async tool completed, agent will handle notification", - map[string]any{ - "tool": tc.Name, - "content_len": len(result.ForUser), - }) + agentResults := make([]indexedAgentResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + agentResults[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "agent_id": agent.ID, + "tool": tc.Name, + "iteration": iteration, + }) + + // Create async callback for tools that implement AsyncTool + asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", "Async tool completed, agent will handle notification", + map[string]any{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } } - } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + agentResults[idx].result = toolResult + }(i, tc) + } + wg.Wait() + // Process results in original order (send to user, save to session) + for _, r := range agentResults { // Send ForUser content to user immediately if not Silent - if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ 
Channel: opts.Channel, ChatID: opts.ChatID, - Content: toolResult.ForUser, + Content: r.result.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": tc.Name, - "content_len": len(toolResult.ForUser), + "tool": r.tc.Name, + "content_len": len(r.result.ForUser), }) } // If tool returned media refs, publish them as outbound media - if len(toolResult.Media) > 0 && opts.SendResponse { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { + if len(r.result.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(r.result.Media)) + for _, ref := range r.result.Media { part := bus.MediaPart{Ref: ref} - // Populate metadata from MediaStore when available if al.mediaStore != nil { if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { part.Filename = meta.Filename @@ -914,15 +1119,15 @@ func (al *AgentLoop) runLLMIteration( } // Determine content for LLM based on tool result - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() + contentForLLM := r.result.ForLLM + if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, + ToolCallID: r.tc.ID, } messages = append(messages, toolResultMsg) @@ -958,9 +1163,9 @@ func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID st func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := agent.ContextWindow * 75 / 100 + threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100 - if len(newHistory) > 20 || tokenEstimate > threshold { + if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold { summarizeKey := agent.ID + ":" + 
sessionKey if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { @@ -1068,7 +1273,11 @@ func formatMessagesForLog(messages []providers.Message) string { for _, tc := range msg.ToolCalls { fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) if tc.Function != nil { - fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200)) + fmt.Fprintf( + &sb, + " Arguments: %s\n", + utils.Truncate(tc.Function.Arguments, 200), + ) } } } @@ -1097,7 +1306,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string { fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name) fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description) if len(tool.Function.Parameters) > 0 { - fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200)) + fmt.Fprintf( + &sb, + " Parameters: %s\n", + utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200), + ) } } sb.WriteString("]") @@ -1194,7 +1407,9 @@ func (al *AgentLoop) summarizeBatch( existingSummary string, ) (string, error) { var sb strings.Builder - sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n") + sb.WriteString( + "Provide a concise summary of this conversation segment, preserving core context and key points.\n", + ) if existingSummary != "" { sb.WriteString("Existing context: ") sb.WriteString(existingSummary) diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go new file mode 100644 index 000000000..82547a008 --- /dev/null +++ b/pkg/agent/loop_media.go @@ -0,0 +1,122 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "bytes" + "encoding/base64" + "io" + "os" + "strings" + + "github.com/h2non/filetype" + + 
"github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs. +// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding +// both raw bytes and encoded string in memory simultaneously. +// Returns a new slice; original messages are not mutated. +func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message { + if store == nil { + return messages + } + + result := make([]providers.Message, len(messages)) + copy(result, messages) + + for i, m := range result { + if len(m.Media) == 0 { + continue + } + + resolved := make([]string, 0, len(m.Media)) + for _, ref := range m.Media { + if !strings.HasPrefix(ref, "media://") { + resolved = append(resolved, ref) + continue + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{ + "ref": ref, + "error": err.Error(), + }) + continue + } + + info, err := os.Stat(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to stat media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + if info.Size() > int64(maxSize) { + logger.WarnCF("agent", "Media file too large, skipping", map[string]any{ + "path": localPath, + "size": info.Size(), + "max_size": maxSize, + }) + continue + } + + // Determine MIME type: prefer metadata, fallback to magic-bytes detection + mime := meta.ContentType + if mime == "" { + kind, ftErr := filetype.MatchFile(localPath) + if ftErr != nil || kind == filetype.Unknown { + logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{ + "path": localPath, + }) + continue + } + mime = kind.MIME.Value + } + + // Streaming base64: open file → base64 encoder → buffer + // Peak memory: ~1.33x file size (buffer only, no raw bytes copy) + f, err 
:= os.Open(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + prefix := "data:" + mime + ";base64," + encodedLen := base64.StdEncoding.EncodedLen(int(info.Size())) + var buf bytes.Buffer + buf.Grow(len(prefix) + encodedLen) + buf.WriteString(prefix) + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + if _, err := io.Copy(encoder, f); err != nil { + f.Close() + logger.WarnCF("agent", "Failed to encode media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + encoder.Close() + f.Close() + + resolved = append(resolved, buf.String()) + } + + result[i].Media = resolved + } + + return result +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 1034b06e8..023286f02 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -6,12 +6,14 @@ import ( "os" "path/filepath" "slices" + "strings" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -27,16 +29,15 @@ func (f *fakeChannel) IsAllowed(string) bool { func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } func (f *fakeChannel) ReasoningChannelID() string { return f.id } -func TestRecordLastChannel(t *testing.T) { - // Create temp workspace +func newTestAgentLoop( + t *testing.T, +) (al *AgentLoop, cfg *config.Config, msgBus *bus.MessageBus, provider *mockProvider, cleanup func()) { + t.Helper() tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tmpDir) - - // Create test config - cfg := &config.Config{ + cfg = &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, @@ -46,74 
+47,43 @@ func TestRecordLastChannel(t *testing.T) { }, }, } + msgBus = bus.NewMessageBus() + provider = &mockProvider{} + al = NewAgentLoop(cfg, msgBus, provider) + return al, cfg, msgBus, provider, func() { os.RemoveAll(tmpDir) } +} - // Create agent loop - msgBus := bus.NewMessageBus() - provider := &mockProvider{} - al := NewAgentLoop(cfg, msgBus, provider) +func TestRecordLastChannel(t *testing.T) { + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() - // Test RecordLastChannel testChannel := "test-channel" - err = al.RecordLastChannel(testChannel) - if err != nil { + if err := al.RecordLastChannel(testChannel); err != nil { t.Fatalf("RecordLastChannel failed: %v", err) } - - // Verify channel was saved - lastChannel := al.state.GetLastChannel() - if lastChannel != testChannel { - t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel) + if got := al.state.GetLastChannel(); got != testChannel { + t.Errorf("Expected channel '%s', got '%s'", testChannel, got) } - - // Verify persistence by creating a new agent loop al2 := NewAgentLoop(cfg, msgBus, provider) - if al2.state.GetLastChannel() != testChannel { - t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel()) + if got := al2.state.GetLastChannel(); got != testChannel { + t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, got) } } func TestRecordLastChatID(t *testing.T) { - // Create temp workspace - tmpDir, err := os.MkdirTemp("", "agent-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) + al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t) + defer cleanup() - // Create test config - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } - - // Create agent loop - msgBus := bus.NewMessageBus() - provider := 
&mockProvider{} - al := NewAgentLoop(cfg, msgBus, provider) - - // Test RecordLastChatID testChatID := "test-chat-id-123" - err = al.RecordLastChatID(testChatID) - if err != nil { + if err := al.RecordLastChatID(testChatID); err != nil { t.Fatalf("RecordLastChatID failed: %v", err) } - - // Verify chat ID was saved - lastChatID := al.state.GetLastChatID() - if lastChatID != testChatID { - t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID) + if got := al.state.GetLastChatID(); got != testChatID { + t.Errorf("Expected chat ID '%s', got '%s'", testChatID, got) } - - // Verify persistence by creating a new agent loop al2 := NewAgentLoop(cfg, msgBus, provider) - if al2.state.GetLastChatID() != testChatID { - t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID()) + if got := al2.state.GetLastChatID(); got != testChatID { + t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, got) } } @@ -840,3 +810,142 @@ func TestHandleReasoning(t *testing.T) { } }) } + +func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + // Create a minimal valid PNG (8-byte header is enough for filetype detection) + pngPath := filepath.Join(dir, "test.png") + // PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB + 0x00, 0x00, 0x00, // no interlace + 0x90, 0x77, 0x53, 0xDE, // CRC + } + if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil { + t.Fatal(err) + } + ref, err := store.Store(pngPath, media.MediaMeta{}, "test") + if err != nil { + t.Fatal(err) + } + + messages := []providers.Message{ + {Role: "user", Content: "describe this", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, 
config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") { + t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40]) + } +} + +func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + bigPath := filepath.Join(dir, "big.png") + // Write PNG header + padding to exceed limit + data := make([]byte, 1024+1) // 1KB + 1 byte + copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) + if err := os.WriteFile(bigPath, data, 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(bigPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + // Use a tiny limit (1KB) so the file is oversized + result := resolveMediaRefs(messages, store, 1024) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + txtPath := filepath.Join(dir, "readme.txt") + if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(txtPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) { + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}}, + } + result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 
|| result[0].Media[0] != "https://example.com/img.png" { + t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media) + } +} + +func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + pngPath := filepath.Join(dir, "test.png") + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, + 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, + 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, + } + os.WriteFile(pngPath, pngHeader, 0o644) + ref, _ := store.Store(pngPath, media.MediaMeta{}, "test") + + original := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + originalRef := original[0].Media[0] + + resolveMediaRefs(original, store, config.DefaultMaxMediaSize) + + if original[0].Media[0] != originalRef { + t.Fatal("resolveMediaRefs mutated original message slice") + } +} + +func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + // File with JPEG content but stored with explicit content type + jpegPath := filepath.Join(dir, "photo") + jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes + os.WriteFile(jpegPath, jpegHeader, 0o644) + ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") { + t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30]) + } +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index cd6a2560f..1de910c83 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go 
@@ -3,12 +3,15 @@ package discord import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } + if err := applyDiscordProxy(session, cfg.Proxy); err != nil { + return nil, err + } base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000), channels.WithGroupTrigger(cfg.GroupTrigger), @@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func() func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", + ProxyURL: c.config.Proxy, }) } +func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { + var proxyFunc func(*http.Request) (*url.URL, error) + if proxyAddr != "" { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err) + } + proxyFunc = http.ProxyURL(proxyURL) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + proxyFunc = http.ProxyFromEnvironment + } + + if proxyFunc == nil { + return nil + } + + transport := &http.Transport{Proxy: proxyFunc} + session.Client = &http.Client{ + Timeout: sendTimeout, + Transport: transport, + } + + if session.Dialer != nil { + dialerCopy := *session.Dialer + dialerCopy.Proxy = proxyFunc + session.Dialer = &dialerCopy + } else { + session.Dialer = &websocket.Dialer{Proxy: proxyFunc} + } + + return nil +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). 
func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go new file mode 100644 index 000000000..0cd5328f4 --- /dev/null +++ b/pkg/channels/discord/discord_test.go @@ -0,0 +1,91 @@ +package discord + +import ( + "net/http" + "net/url" + "testing" + + "github.com/bwmarrin/discordgo" +) + +func TestApplyDiscordProxy_CustomProxy(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + restProxy := session.Client.Transport.(*http.Transport).Proxy + restProxyURL, err := restProxy(req) + if err != nil { + t.Fatalf("rest proxy func error: %v", err) + } + if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("REST proxy = %q, want %q", got, want) + } + + wsProxyURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("WS proxy = %q, want %q", got, want) + } +} + +func TestApplyDiscordProxy_FromEnvironment(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, ""); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + 
req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + gotURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + + wantURL, err := url.Parse("http://127.0.0.1:8888") + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + if gotURL.String() != wantURL.String() { + t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String()) + } +} + +func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "://bad-proxy"); err == nil { + t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil") + } +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go index e8a057741..fbe085b73 100644 --- a/pkg/channels/feishu/common.go +++ b/pkg/channels/feishu/common.go @@ -1,5 +1,16 @@ package feishu +import ( + "encoding/json" + "regexp" + "strings" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions. +var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`) + // stringValue safely dereferences a *string pointer. func stringValue(v *string) string { if v == nil { @@ -7,3 +18,69 @@ func stringValue(v *string) string { } return *v } + +// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content. +// JSON 2.0 cards support full CommonMark standard markdown syntax. 
+func buildMarkdownCard(content string) (string, error) { + card := map[string]any{ + "schema": "2.0", + "body": map[string]any{ + "elements": []map[string]any{ + { + "tag": "markdown", + "content": content, + }, + }, + }, + } + data, err := json.Marshal(card) + if err != nil { + return "", err + } + return string(data), nil +} + +// extractJSONStringField unmarshals content as JSON and returns the value of the given string field. +// Returns "" if the content is invalid JSON or the field is missing/empty. +func extractJSONStringField(content, field string) string { + var m map[string]json.RawMessage + if err := json.Unmarshal([]byte(content), &m); err != nil { + return "" + } + raw, ok := m[field] + if !ok { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "" + } + return s +} + +// extractImageKey extracts the image_key from a Feishu image message content JSON. +// Format: {"image_key": "img_xxx"} +func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") } + +// extractFileKey extracts the file_key from a Feishu file/audio message content JSON. +// Format: {"file_key": "file_xxx", "file_name": "...", ...} +func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") } + +// extractFileName extracts the file_name from a Feishu file message content JSON. +func extractFileName(content string) string { return extractJSONStringField(content, "file_name") } + +// stripMentionPlaceholders removes @_user_N placeholders from the text content. +// These are inserted by Feishu when users @mention someone in a message. 
+func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string { + if len(mentions) == 0 { + return content + } + for _, m := range mentions { + if m.Key != nil && *m.Key != "" { + content = strings.ReplaceAll(content, *m.Key, "") + } + } + // Also clean up any remaining @_user_N patterns + content = mentionPlaceholderRegex.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/feishu/common_test.go b/pkg/channels/feishu/common_test.go new file mode 100644 index 000000000..fefc9f7c1 --- /dev/null +++ b/pkg/channels/feishu/common_test.go @@ -0,0 +1,292 @@ +package feishu + +import ( + "encoding/json" + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractJSONStringField(t *testing.T) { + tests := []struct { + name string + content string + field string + want string + }{ + { + name: "valid field", + content: `{"image_key": "img_v2_xxx"}`, + field: "image_key", + want: "img_v2_xxx", + }, + { + name: "missing field", + content: `{"image_key": "img_v2_xxx"}`, + field: "file_key", + want: "", + }, + { + name: "invalid JSON", + content: `not json at all`, + field: "image_key", + want: "", + }, + { + name: "empty content", + content: "", + field: "image_key", + want: "", + }, + { + name: "non-string field value", + content: `{"count": 42}`, + field: "count", + want: "", + }, + { + name: "empty string value", + content: `{"image_key": ""}`, + field: "image_key", + want: "", + }, + { + name: "multiple fields", + content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`, + field: "file_name", + want: "test.pdf", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractJSONStringField(tt.content, tt.field) + if got != tt.want { + t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want) + } + }) + } +} + +func TestExtractImageKey(t *testing.T) { + tests := []struct { + name string + content string 
+ want string + }{ + { + name: "normal", + content: `{"image_key": "img_v2_abc123"}`, + want: "img_v2_abc123", + }, + { + name: "missing key", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{broken`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractImageKey(tt.content) + if got != tt.want { + t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileKey(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`, + want: "file_v2_abc123", + }, + { + name: "missing key", + content: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `not json`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileKey(tt.content) + if got != tt.want { + t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileName(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + }, + { + name: "missing name", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{bad`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileName(tt.content) + if got != tt.want { + t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestBuildMarkdownCard(t *testing.T) { + tests := []struct { + name string + content string + }{ + { + name: "normal content", + content: "Hello **world**", + }, + { + name: "empty content", + content: "", + }, + { + name: "special characters", + content: `Code: "foo" & 'baz'`, + }, + } 
+ + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildMarkdownCard(tt.content) + if err != nil { + t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err) + } + + // Verify valid JSON + var parsed map[string]any + if err := json.Unmarshal([]byte(result), &parsed); err != nil { + t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err) + } + + // Verify schema + if parsed["schema"] != "2.0" { + t.Errorf("schema = %v, want %q", parsed["schema"], "2.0") + } + + // Verify body.elements[0].content == input + body, ok := parsed["body"].(map[string]any) + if !ok { + t.Fatal("missing body in card JSON") + } + elements, ok := body["elements"].([]any) + if !ok || len(elements) == 0 { + t.Fatal("missing or empty elements in card JSON") + } + elem, ok := elements[0].(map[string]any) + if !ok { + t.Fatal("first element is not an object") + } + if elem["tag"] != "markdown" { + t.Errorf("tag = %v, want %q", elem["tag"], "markdown") + } + if elem["content"] != tt.content { + t.Errorf("content = %v, want %q", elem["content"], tt.content) + } + }) + } +} + +func TestStripMentionPlaceholders(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + content string + mentions []*larkim.MentionEvent + want string + }{ + { + name: "no mentions", + content: "Hello world", + mentions: nil, + want: "Hello world", + }, + { + name: "single mention", + content: "@_user_1 hello", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + }, + want: "hello", + }, + { + name: "multiple mentions", + content: "@_user_1 @_user_2 hey", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + {Key: strPtr("@_user_2")}, + }, + want: "hey", + }, + { + name: "empty content", + content: "", + mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}}, + want: "", + }, + { + name: "empty mentions slice", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{}, 
+ want: "@_user_1 test", + }, + { + name: "mention with nil key", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{ + {Key: nil}, + }, + want: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripMentionPlaceholders(tt.content, tt.mentions) + if got != tt.want { + t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go index d0ec758c6..f5e3aa224 100644 --- a/pkg/channels/feishu/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -16,6 +16,8 @@ type FeishuChannel struct { *channels.BaseChannel } +var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures") + // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { return nil, errors.New( @@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan // Start is a stub method to satisfy the Channel interface func (c *FeishuChannel) Start(ctx context.Context) error { - return nil + return errUnsupported } // Stop is a stub method to satisfy the Channel interface func (c *FeishuChannel) Stop(ctx context.Context) error { - return nil + return errUnsupported } // Send is a stub method to satisfy the Channel interface func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - return errors.New("feishu channel is not supported on 32-bit architectures") + return errUnsupported +} + +// EditMessage is a stub method to satisfy MessageEditor +func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return errUnsupported +} + +// SendPlaceholder is a stub method to satisfy PlaceholderCapable +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) 
(string, error) { + return "", errUnsupported +} + +// ReactToMessage is a stub method to satisfy ReactionCapable +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + return func() {}, errUnsupported +} + +// SendMedia is a stub method to satisfy MediaSender +func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + return errUnsupported } diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 1db1bf669..00f73064d 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -6,10 +6,15 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "os" + "path/filepath" "sync" - "time" + "sync/atomic" lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" larkws "github.com/larksuite/oapi-sdk-go/v3/ws" @@ -19,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -28,6 +34,8 @@ type FeishuChannel struct { client *lark.Client wsClient *larkws.Client + botOpenID atomic.Value // stores string; populated lazily for @mention detection + mu sync.Mutex cancel context.CancelFunc } @@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) - return &FeishuChannel{ + ch := &FeishuChannel{ BaseChannel: base, config: cfg, client: lark.NewClient(cfg.AppID, cfg.AppSecret), - }, nil + } + ch.SetOwner(ch) + return ch, nil } func (c *FeishuChannel) Start(ctx context.Context) error { @@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error { return fmt.Errorf("feishu app_id or 
app_secret is empty") } + // Fetch bot open_id via API for reliable @mention detection. + if err := c.fetchBotOpenID(ctx); err != nil { + logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{ + "error": err.Error(), + }) + } + dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). OnP2MessageReceiveV1(c.handleMessageReceive) @@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { return nil } +// Send sends a message using Interactive Card format for markdown rendering. func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning } if msg.ChatID == "" { - return fmt.Errorf("chat ID is empty") + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) } - payload, err := json.Marshal(map[string]string{"text": msg.Content}) + // Build interactive card with markdown content + cardContent, err := buildMarkdownCard(msg.Content) if err != nil { - return fmt.Errorf("failed to marshal feishu content: %w", err) + return fmt.Errorf("feishu send: card build failed: %w", err) + } + return c.sendCard(ctx, msg.ChatID, cardContent) +} + +// EditMessage implements channels.MessageEditor. +// Uses Message.Patch to update an interactive card message. +func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + cardContent, err := buildMarkdownCard(content) + if err != nil { + return fmt.Errorf("feishu edit: card build failed: %w", err) + } + + req := larkim.NewPatchMessageReqBuilder(). + MessageId(messageID). + Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Patch(ctx, req) + if err != nil { + return fmt.Errorf("feishu edit: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// Sends an interactive card with placeholder text and returns its message ID. +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + if !c.config.Placeholder.Enabled { + logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{ + "chat_id": chatID, + }) + return "", nil + } + + text := c.config.Placeholder.Text + if text == "" { + text = "Thinking..." + } + + cardContent, err := buildMarkdownCard(text) + if err != nil { + return "", fmt.Errorf("feishu placeholder: card build failed: %w", err) } req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(larkim.ReceiveIdTypeChatId). Body(larkim.NewCreateMessageReqBodyBuilder(). - ReceiveId(msg.ChatID). - MsgType(larkim.MsgTypeText). - Content(string(payload)). - Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). + ReceiveId(chatID). + MsgType(larkim.MsgTypeInteractive). + Content(cardContent). Build()). Build() resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("feishu send: %w", channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder send: %w", err) } - if !resp.Success() { - return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg) } - logger.DebugCF("feishu", "Feishu message sent", map[string]any{ - "chat_id": msg.ChatID, - }) + if resp.Data != nil && resp.Data.MessageId != nil { + return *resp.Data.MessageId, nil + } + return "", nil +} + +// ReactToMessage implements channels.ReactionCapable. 
+// Adds an "Pin" reaction and returns an undo function to remove it. +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + req := larkim.NewCreateMessageReactionReqBuilder(). + MessageId(messageID). + Body(larkim.NewCreateMessageReactionReqBodyBuilder(). + ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()). + Build()). + Build() + + resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{ + "message_id": messageID, + "error": err.Error(), + }) + return func() {}, fmt.Errorf("feishu react: %w", err) + } + if !resp.Success() { + logger.ErrorCF("feishu", "Reaction API error", map[string]any{ + "message_id": messageID, + "code": resp.Code, + "msg": resp.Msg, + }) + return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + + var reactionID string + if resp.Data != nil && resp.Data.ReactionId != nil { + reactionID = *resp.Data.ReactionId + } + if reactionID == "" { + return func() {}, nil + } + + var undone atomic.Bool + undo := func() { + if !undone.CompareAndSwap(false, true) { + return + } + delReq := larkim.NewDeleteMessageReactionReqBuilder(). + MessageId(messageID). + ReactionId(reactionID). + Build() + _, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq) + } + return undo, nil +} + +// SendMedia implements channels.MediaSender. +// Uploads images/files via Feishu API then sends as messages. 
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + if msg.ChatID == "" { + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil { + return err + } + } return nil } +// sendMediaPart resolves and sends a single media part. +func (c *FeishuChannel) sendMediaPart( + ctx context.Context, + chatID string, + part bus.MediaPart, + store media.MediaStore, +) error { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + return nil // skip this part + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + return nil // skip this part + } + defer file.Close() + + switch part.Type { + case "image": + err = c.sendImage(ctx, chatID, file) + default: + filename := part.Filename + if filename == "" { + filename = "file" + } + err = c.sendFile(ctx, chatID, file, filename, part.Type) + } + + if err != nil { + logger.ErrorCF("feishu", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("feishu send media: %w", channels.ErrTemporary) + } + return nil +} + +// --- Inbound message handling --- + func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || event.Event.Message == nil { return nil @@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. 
senderID = "unknown" } - content := extractFeishuMessageContent(message) + messageType := stringValue(message.MessageType) + messageID := stringValue(message.MessageId) + rawContent := stringValue(message.Content) + + // Check allowlist early to avoid downloading media for rejected senders. + // BaseChannel.HandleMessage will check again, but this avoids wasted network I/O. + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + if !c.IsAllowedSender(senderInfo) { + return nil + } + + // Extract content based on message type + content := extractContent(messageType, rawContent) + + // Handle media messages (download and store) + var mediaRefs []string + if store := c.GetMediaStore(); store != nil && messageID != "" { + mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store) + } + + // Append media tags to content (like Telegram does) + content = appendMediaTags(content, messageType, mediaRefs) + if content == "" { content = "[empty message]" } metadata := map[string]string{} - messageID := "" - if mid := stringValue(message.MessageId); mid != "" { - messageID = mid + if messageID != "" { + metadata["message_id"] = messageID } - if messageType := stringValue(message.MessageType); messageType != "" { + if messageType != "" { metadata["message_type"] = messageType } - if chatType := stringValue(message.ChatType); chatType != "" { + chatType := stringValue(message.ChatType) + if chatType != "" { metadata["chat_type"] = chatType } if sender != nil && sender.TenantKey != nil { metadata["tenant_key"] = *sender.TenantKey } - chatType := stringValue(message.ChatType) var peer bus.Peer if chatType == "p2p" { peer = bus.Peer{Kind: "direct", ID: senderID} } else { peer = bus.Peer{Kind: "group", ID: chatID} + + // Check if bot was mentioned + isMentioned := c.isBotMentioned(message) + + // Strip mention placeholders from content before group trigger check 
+ if len(message.Mentions) > 0 { + content = stripMentionPlaceholders(content, message.Mentions) + } + // In group chats, apply unified group trigger filtering - respond, cleaned := c.ShouldRespondInGroup(false, content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) if !respond { return nil } @@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. } logger.InfoCF("feishu", "Feishu message received", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 80), + "sender_id": senderID, + "chat_id": chatID, + "message_id": messageID, + "preview": utils.Truncate(content, 80), }) - senderInfo := bus.SenderInfo{ - Platform: "feishu", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("feishu", senderID), + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo) + return nil +} + +// --- Internal helpers --- + +// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id. 
+func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error { + resp, err := c.client.Do(ctx, &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/bot/v3/info", + SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant}, + }) + if err != nil { + return fmt.Errorf("bot info request: %w", err) } - if !c.IsAllowedSender(senderInfo) { - return nil + var result struct { + Code int `json:"code"` + Bot struct { + OpenID string `json:"open_id"` + } `json:"bot"` + } + if err := json.Unmarshal(resp.RawBody, &result); err != nil { + return fmt.Errorf("bot info parse: %w", err) + } + if result.Code != 0 { + return fmt.Errorf("bot info api error (code=%d)", result.Code) + } + if result.Bot.OpenID == "" { + return fmt.Errorf("bot info: empty open_id") } - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) + c.botOpenID.Store(result.Bot.OpenID) + logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{ + "open_id": result.Bot.OpenID, + }) + return nil +} + +// isBotMentioned checks if the bot was @mentioned in the message. +func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool { + if message.Mentions == nil { + return false + } + + knownID, _ := c.botOpenID.Load().(string) + if knownID == "" { + logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil) + return false + } + + for _, m := range message.Mentions { + if m.Id == nil { + continue + } + if m.Id.OpenId != nil && *m.Id.OpenId == knownID { + return true + } + } + return false +} + +// extractContent extracts text content from different message types. 
+func extractContent(messageType, rawContent string) string { + if rawContent == "" { + return "" + } + + switch messageType { + case larkim.MsgTypeText: + var textPayload struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil { + return textPayload.Text + } + return rawContent + + case larkim.MsgTypePost: + // Pass raw JSON to LLM — structured rich text is more informative than flattened plain text + return rawContent + + case larkim.MsgTypeImage: + // Image messages don't have text content + return "" + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + // File/audio/video messages may have a filename + name := extractFileName(rawContent) + if name != "" { + return name + } + return "" + + default: + return rawContent + } +} + +// downloadInboundMedia downloads media from inbound messages and stores in MediaStore. +func (c *FeishuChannel) downloadInboundMedia( + ctx context.Context, + chatID, messageID, messageType, rawContent string, + store media.MediaStore, +) []string { + var refs []string + scope := channels.BuildMediaScope("feishu", chatID, messageID) + + switch messageType { + case larkim.MsgTypeImage: + imageKey := extractImageKey(rawContent) + if imageKey == "" { + return nil + } + ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope) + if ref != "" { + refs = append(refs, ref) + } + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + fileKey := extractFileKey(rawContent) + if fileKey == "" { + return nil + } + // Derive a fallback extension from the message type. 
+ var ext string + switch messageType { + case larkim.MsgTypeAudio: + ext = ".ogg" + case larkim.MsgTypeMedia: + ext = ".mp4" + default: + ext = "" // generic file — rely on resp.FileName + } + ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope) + if ref != "" { + refs = append(refs, ref) + } + } + + return refs +} + +// downloadResource downloads a message resource (image/file) from Feishu, +// writes it to the project media directory, and stores the reference in MediaStore. +// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension. +func (c *FeishuChannel) downloadResource( + ctx context.Context, + messageID, fileKey, resourceType, fallbackExt string, + store media.MediaStore, + scope string, +) string { + req := larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(fileKey). + Type(resourceType). + Build() + + resp, err := c.client.Im.V1.MessageResource.Get(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to download resource", map[string]any{ + "message_id": messageID, + "file_key": fileKey, + "error": err.Error(), + }) + return "" + } + if !resp.Success() { + logger.ErrorCF("feishu", "Resource download api error", map[string]any{ + "code": resp.Code, + "msg": resp.Msg, + }) + return "" + } + + if resp.File == nil { + return "" + } + // Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body). + if closer, ok := resp.File.(io.Closer); ok { + defer closer.Close() + } + + filename := resp.FileName + if filename == "" { + filename = fileKey + } + // If filename still has no extension, append the fallback (like Telegram's ext parameter). + if filepath.Ext(filename) == "" && fallbackExt != "" { + filename += fallbackExt + } + + // Write to the shared picoclaw_media directory using a unique name to avoid collisions. 
+ mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil { + logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{ + "error": mkdirErr.Error(), + }) + return "" + } + ext := filepath.Ext(filename) + localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext)) + + out, err := os.Create(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{ + "error": err.Error(), + }) + return "" + } + + if _, copyErr := io.Copy(out, resp.File); copyErr != nil { + out.Close() + os.Remove(localPath) + logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{ + "error": copyErr.Error(), + }) + return "" + } + out.Close() + + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "feishu", + }, scope) + if err != nil { + logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{ + "file_key": fileKey, + "error": err.Error(), + }) + os.Remove(localPath) + return "" + } + + return ref +} + +// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]"). +func appendMediaTags(content, messageType string, mediaRefs []string) string { + if len(mediaRefs) == 0 { + return content + } + + var tag string + switch messageType { + case larkim.MsgTypeImage: + tag = "[image: photo]" + case larkim.MsgTypeAudio: + tag = "[audio]" + case larkim.MsgTypeMedia: + tag = "[video]" + case larkim.MsgTypeFile: + tag = "[file]" + default: + tag = "[attachment]" + } + + if content == "" { + return tag + } + return content + " " + tag +} + +// sendCard sends an interactive card message to a chat. +func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error { + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). 
+ MsgType(larkim.MsgTypeInteractive). + Content(cardContent). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu send card: %w", channels.ErrTemporary) + } + + if !resp.Success() { + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + } + + logger.DebugCF("feishu", "Feishu card message sent", map[string]any{ + "chat_id": chatID, + }) + + return nil +} + +// sendImage uploads an image and sends it as a message. +func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error { + // Upload image to get image_key + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType("message"). + Image(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu image upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil { + return fmt.Errorf("feishu image upload: no image_key returned") + } + + imageKey := *uploadResp.Data.ImageKey + + // Send image message + content, _ := json.Marshal(map[string]string{"image_key": imageKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeImage). + Content(string(content)). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu image send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// sendFile uploads a file and sends it as a message. 
+func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error { + // Map part type to Feishu file type + feishuFileType := "stream" + switch fileType { + case "audio": + feishuFileType = "opus" + case "video": + feishuFileType = "mp4" + } + + // Upload file to get file_key + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(feishuFileType). + FileName(filename). + File(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu file upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.FileKey == nil { + return fmt.Errorf("feishu file upload: no file_key returned") + } + + fileKey := *uploadResp.Data.FileKey + + // Send file message + content, _ := json.Marshal(map[string]string{"file_key": fileKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeFile). + Content(string(content)). + Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu file send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } return nil } @@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string { return "" } - -func extractFeishuMessageContent(message *larkim.EventMessage) string { - if message == nil || message.Content == nil || *message.Content == "" { - return "" - } - - if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { - var textPayload struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { - return textPayload.Text - } - } - - return *message.Content -} diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go new file mode 100644 index 000000000..dc3eab2e7 --- /dev/null +++ b/pkg/channels/feishu/feishu_64_test.go @@ -0,0 +1,256 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + +package feishu + +import ( + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractContent(t *testing.T) { + tests := []struct { + name string + messageType string + rawContent string + want string + }{ + { + name: "text message", + messageType: "text", + rawContent: `{"text": "hello world"}`, + want: "hello world", + }, + { + name: "text message invalid JSON", + messageType: "text", + rawContent: `not json`, + want: "not json", + }, + { + name: "post message returns raw JSON", + messageType: "post", + rawContent: `{"title": "test post"}`, + want: `{"title": "test post"}`, + }, + { + name: "image message returns empty", + messageType: "image", + rawContent: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "file message with filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + 
}, + { + name: "file message without filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "audio message with filename", + messageType: "audio", + rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`, + want: "recording.ogg", + }, + { + name: "media message with filename", + messageType: "media", + rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`, + want: "video.mp4", + }, + { + name: "unknown message type returns raw", + messageType: "sticker", + rawContent: `{"sticker_id": "sticker_xxx"}`, + want: `{"sticker_id": "sticker_xxx"}`, + }, + { + name: "empty raw content", + messageType: "text", + rawContent: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractContent(tt.messageType, tt.rawContent) + if got != tt.want { + t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want) + } + }) + } +} + +func TestAppendMediaTags(t *testing.T) { + tests := []struct { + name string + content string + messageType string + mediaRefs []string + want string + }{ + { + name: "no refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: nil, + want: "hello", + }, + { + name: "empty refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: []string{}, + want: "hello", + }, + { + name: "image with content", + content: "check this", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "check this [image: photo]", + }, + { + name: "image empty content", + content: "", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "[image: photo]", + }, + { + name: "audio", + content: "listen", + messageType: "audio", + mediaRefs: []string{"ref1"}, + want: "listen [audio]", + }, + { + name: "media/video", + content: "watch", + messageType: "media", + mediaRefs: []string{"ref1"}, + want: "watch [video]", + }, + { + name: "file", + content: 
"report.pdf", + messageType: "file", + mediaRefs: []string{"ref1"}, + want: "report.pdf [file]", + }, + { + name: "unknown type", + content: "something", + messageType: "sticker", + mediaRefs: []string{"ref1"}, + want: "something [attachment]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs) + if got != tt.want { + t.Errorf( + "appendMediaTags(%q, %q, %v) = %q, want %q", + tt.content, + tt.messageType, + tt.mediaRefs, + got, + tt.want, + ) + } + }) + } +} + +func TestExtractFeishuSenderID(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + sender *larkim.EventSender + want string + }{ + { + name: "nil sender", + sender: nil, + want: "", + }, + { + name: "nil sender ID", + sender: &larkim.EventSender{SenderId: nil}, + want: "", + }, + { + name: "userId preferred", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr("u_abc123"), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "u_abc123", + }, + { + name: "openId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "ou_def456", + }, + { + name: "unionId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "on_ghi789", + }, + { + name: "all empty strings", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr(""), + }, + }, + want: "", + }, + { + name: "nil userId pointer falls through", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: nil, + OpenId: strPtr("ou_def456"), + UnionId: nil, + }, + }, + want: "ou_def456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := 
extractFeishuSenderID(tt.sender) + if got != tt.want { + t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 155e50b39..fdd6d0c1f 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -255,6 +255,10 @@ func (m *Manager) initChannels() error { m.initChannel("wecom", "WeCom") } + if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" { + m.initChannel("wecom_aibot", "WeCom AI Bot") + } + if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { m.initChannel("wecom_app", "WeCom App") } @@ -539,86 +543,88 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork }) } -func (m *Manager) dispatchOutbound(ctx context.Context) { - logger.InfoC("channels", "Outbound dispatcher started") +func dispatchLoop[M any]( + ctx context.Context, + m *Manager, + subscribe func(context.Context) (M, bool), + getChannel func(M) string, + enqueue func(context.Context, *channelWorker, M) bool, + startMsg, stopMsg, unknownMsg, noWorkerMsg string, +) { + logger.InfoC("channels", startMsg) for { - msg, ok := m.bus.SubscribeOutbound(ctx) + msg, ok := subscribe(ctx) if !ok { - logger.InfoC("channels", "Outbound dispatcher stopped") + logger.InfoC("channels", stopMsg) return } + channel := getChannel(msg) + // Silently skip internal channels - if constants.IsInternalChannel(msg.Channel) { + if constants.IsInternalChannel(channel) { continue } m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channel] + w, wExists := m.workers[channel] m.mu.RUnlock() if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ - "channel": msg.Channel, - }) + logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel}) continue } if wExists && w != nil { - select { - case w.queue <- msg: - case <-ctx.Done(): 
+ if !enqueue(ctx, w, msg) { return } } else if exists { - logger.WarnCF("channels", "Channel has no active worker, skipping message", map[string]any{ - "channel": msg.Channel, - }) + logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel}) } } } +func (m *Manager) dispatchOutbound(ctx context.Context) { + dispatchLoop( + ctx, m, + m.bus.SubscribeOutbound, + func(msg bus.OutboundMessage) string { return msg.Channel }, + func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { + select { + case w.queue <- msg: + return true + case <-ctx.Done(): + return false + } + }, + "Outbound dispatcher started", + "Outbound dispatcher stopped", + "Unknown channel for outbound message", + "Channel has no active worker, skipping message", + ) +} + func (m *Manager) dispatchOutboundMedia(ctx context.Context) { - logger.InfoC("channels", "Outbound media dispatcher started") - - for { - msg, ok := m.bus.SubscribeOutboundMedia(ctx) - if !ok { - logger.InfoC("channels", "Outbound media dispatcher stopped") - return - } - - // Silently skip internal channels - if constants.IsInternalChannel(msg.Channel) { - continue - } - - m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] - m.mu.RUnlock() - - if !exists { - logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{ - "channel": msg.Channel, - }) - continue - } - - if wExists && w != nil { + dispatchLoop( + ctx, m, + m.bus.SubscribeOutboundMedia, + func(msg bus.OutboundMediaMessage) string { return msg.Channel }, + func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { select { case w.mediaQueue <- msg: + return true case <-ctx.Done(): - return + return false } - } else if exists { - logger.WarnCF("channels", "Channel has no active worker, skipping media message", map[string]any{ - "channel": msg.Channel, - }) - } - } + }, + "Outbound media dispatcher started", + "Outbound media dispatcher stopped", + 
"Unknown channel for outbound media message", + "Channel has no active worker, skipping media message", + ) } // runMediaWorker processes outbound media messages for a single channel. diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index a11cf53b8..f328f32b8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,12 +7,12 @@ import ( "net/url" "os" "regexp" + "slices" "strconv" "strings" "time" "github.com/mymmrac/telego" - "github.com/mymmrac/telego/telegohandler" th "github.com/mymmrac/telego/telegohandler" tu "github.com/mymmrac/telego/telegoutil" @@ -41,7 +41,7 @@ var ( type TelegramChannel struct { *channels.BaseChannel bot *telego.Bot - bh *telegohandler.BotHandler + bh *th.BotHandler commands TelegramCommander config *config.Config chatIDs map[string]int64 @@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann })) } + if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" { + opts = append(opts, telego.WithAPIServer(baseURL)) + } + bot, err := telego.NewBot(telegramCfg.Token, opts...) 
if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) @@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) + if err := c.initBotCommands(c.ctx); err != nil { + logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{ + "error": err.Error(), + }) + } + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) @@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start long polling: %w", err) } - bh, err := telegohandler.NewBotHandler(c.bot, updates) + bh, err := th.NewBotHandler(c.bot, updates) if err != nil { c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } c.bh = bh - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - c.commands.Help(ctx, message) - return nil - }, th.CommandEqual("help")) bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.commands.Start(ctx, message) }, th.CommandEqual("start")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Help(ctx, message) + }, th.CommandEqual("help")) bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.commands.Show(ctx, message) @@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error { "username": c.bot.Username(), }) - go bh.Start() + go func() { + if err = bh.Start(); err != nil { + logger.ErrorCF("telegram", "Bot handler failed", map[string]any{ + "error": err.Error(), + }) + } + }() return nil } @@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { // Stop the bot handler if c.bh != nil { - c.bh.Stop() + _ = c.bh.StopWithContext(ctx) } // Cancel our context (stops long polling) @@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { return nil } +func (c *TelegramChannel) initBotCommands(ctx 
context.Context) error { + currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{ + Scope: tu.ScopeDefault(), + }) + if err != nil { + return fmt.Errorf("get commands: %w", err) + } + + commands := []telego.BotCommand{ + { + Command: "start", + Description: "Start the bot", + }, + { + Command: "help", + Description: "Show a help message", + }, + { + Command: "show", + Description: "Show current configuration", + }, + { + Command: "list", + Description: "List available options", + }, + } + + // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed + if !slices.Equal(currentCommands, commands) { + logger.InfoC("telegram", "Updating bot commands") + + err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ + Commands: commands, + Scope: tu.ScopeDefault(), + }) + if err != nil { + return fmt.Errorf("set commands: %w", err) + } + } else { + logger.DebugC("telegram", "Bot commands are up to date") + } + + return nil +} + func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go new file mode 100644 index 000000000..6c5aca40b --- /dev/null +++ b/pkg/channels/wecom/aibot.go @@ -0,0 +1,1014 @@ +package wecom + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// WeComAIBotChannel implements the Channel interface for WeCom AI Bot (企业微信智能机器人) +type WeComAIBotChannel struct { + *channels.BaseChannel + config config.WeComAIBotConfig + ctx context.Context + cancel context.CancelFunc + 
streamTasks map[string]*streamTask // streamID -> task (for poll lookups) + chatTasks map[string][]*streamTask // chatID -> in-flight tasks queue (FIFO) + taskMu sync.RWMutex +} + +// streamTask represents a streaming task for AI Bot. +// +// Mutable fields (Finished, StreamClosed, StreamClosedAt) must be read/written +// while holding WeComAIBotChannel.taskMu. Immutable fields (StreamID, ChatID, +// ResponseURL, Question, CreatedTime, Deadline, answerCh, ctx, cancel) are set +// once at creation and never modified, so they are safe to read without a lock. +type streamTask struct { + // immutable after creation + StreamID string + ChatID string // used by Send() to find this task + ResponseURL string // temporary URL for proactive reply (valid 1 hour, use once) + Question string + CreatedTime time.Time + Deadline time.Time // ~30s, we close the stream here and switch to response_url + answerCh chan string // receives agent reply from Send() + ctx context.Context // canceled when task is removed; used to interrupt the agent goroutine + cancel context.CancelFunc // call on task removal to cancel ctx + + // mutable — guarded by WeComAIBotChannel.taskMu + StreamClosed bool // stream returned finish:true; waiting for agent to reply via response_url + StreamClosedAt time.Time // set when StreamClosed becomes true; used for accelerated cleanup + Finished bool // fully done +} + +// WeComAIBotMessage represents the decrypted JSON message from WeCom AI Bot +// Ref: https://developer.work.weixin.qq.com/document/path/100719 +type WeComAIBotMessage struct { + MsgID string `json:"msgid"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid"` // only for group chat + ChatType string `json:"chattype"` // "single" or "group" + From struct { + UserID string `json:"userid"` + } `json:"from"` + ResponseURL string `json:"response_url"` // temporary URL for proactive reply + MsgType string `json:"msgtype"` + // text message + Text *struct { + Content string `json:"content"` 
+ } `json:"text,omitempty"` + // stream polling refresh + Stream *struct { + ID string `json:"id"` + } `json:"stream,omitempty"` + // image message + Image *struct { + URL string `json:"url"` + } `json:"image,omitempty"` + // mixed message (text + image) + Mixed *struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text *struct { + Content string `json:"content"` + } `json:"text,omitempty"` + Image *struct { + URL string `json:"url"` + } `json:"image,omitempty"` + } `json:"msg_item"` + } `json:"mixed,omitempty"` + // event field + Event *struct { + EventType string `json:"eventtype"` + } `json:"event,omitempty"` +} + +// WeComAIBotMsgItemImage holds the image payload inside a stream message item. +type WeComAIBotMsgItemImage struct { + Base64 string `json:"base64"` + MD5 string `json:"md5"` +} + +// WeComAIBotMsgItem is a single item inside a stream's msg_item list. +type WeComAIBotMsgItem struct { + MsgType string `json:"msgtype"` + Image *WeComAIBotMsgItemImage `json:"image,omitempty"` +} + +// WeComAIBotStreamInfo represents the detailed stream content in streaming responses. 
+type WeComAIBotStreamInfo struct { + ID string `json:"id"` + Finish bool `json:"finish"` + Content string `json:"content,omitempty"` + MsgItem []WeComAIBotMsgItem `json:"msg_item,omitempty"` +} + +// WeComAIBotStreamResponse represents the streaming response format +type WeComAIBotStreamResponse struct { + MsgType string `json:"msgtype"` + Stream WeComAIBotStreamInfo `json:"stream"` +} + +// WeComAIBotEncryptedResponse represents the encrypted response wrapper +// Fields match WXBizJsonMsgCrypt.generate() in Python SDK +type WeComAIBotEncryptedResponse struct { + Encrypt string `json:"encrypt"` + MsgSignature string `json:"msgsignature"` + Timestamp string `json:"timestamp"` + Nonce string `json:"nonce"` +} + +// NewWeComAIBotChannel creates a new WeCom AI Bot channel instance +func NewWeComAIBotChannel( + cfg config.WeComAIBotConfig, + messageBus *bus.MessageBus, +) (*WeComAIBotChannel, error) { + if cfg.Token == "" || cfg.EncodingAESKey == "" { + return nil, fmt.Errorf("token and encoding_aes_key are required for WeCom AI Bot") + } + + base := channels.NewBaseChannel("wecom_aibot", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) + + return &WeComAIBotChannel{ + BaseChannel: base, + config: cfg, + streamTasks: make(map[string]*streamTask), + chatTasks: make(map[string][]*streamTask), + }, nil +} + +// Name returns the channel name +func (c *WeComAIBotChannel) Name() string { + return "wecom_aibot" +} + +// Start initializes the WeCom AI Bot channel +func (c *WeComAIBotChannel) Start(ctx context.Context) error { + logger.InfoC("wecom_aibot", "Starting WeCom AI Bot channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Start cleanup goroutine for old tasks + go c.cleanupLoop() + + c.SetRunning(true) + logger.InfoC("wecom_aibot", "WeCom AI Bot channel started") + + return nil +} + +// Stop gracefully stops the WeCom AI Bot channel +func (c *WeComAIBotChannel) Stop(ctx 
context.Context) error { + logger.InfoC("wecom_aibot", "Stopping WeCom AI Bot channel...") + + if c.cancel != nil { + c.cancel() + } + + c.SetRunning(false) + logger.InfoC("wecom_aibot", "WeCom AI Bot channel stopped") + return nil +} + +// Send delivers the agent reply into the active streamTask for msg.ChatID. +// It writes into the earliest unfinished task in the queue (FIFO per chatID). +// If the stream has already closed (deadline passed), it posts directly to response_url. +func (c *WeComAIBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + c.taskMu.Lock() + queue := c.chatTasks[msg.ChatID] + // Only compact Finished tasks at the head of the queue. + // Tasks that are Finished in the middle are NOT removed here: doing a full + // scan on every Send() call would be O(n) and is unnecessary given that + // removeTask() always splices the task out of the queue immediately. + // Any Finished task left stranded in the middle (e.g. due to an unexpected + // code path) will be collected by cleanupOldTasks. + for len(queue) > 0 && queue[0].Finished { + queue = queue[1:] + } + c.chatTasks[msg.ChatID] = queue + var task *streamTask + var streamClosed bool + var responseURL string + if len(queue) > 0 { + task = queue[0] + // Read mutable fields while holding c.taskMu to avoid data races. + streamClosed = task.StreamClosed + responseURL = task.ResponseURL + } + c.taskMu.Unlock() + + if task == nil { + logger.DebugCF( + "wecom_aibot", + "Send: no active task for chat (may have timed out)", + map[string]any{ + "chat_id": msg.ChatID, + }, + ) + return nil + } + + if streamClosed { + // Stream already ended with a "please wait" notice; send the real reply via response_url. + // Note: task.StreamID and task.ChatID are immutable, safe to read without a lock. 
+ logger.InfoCF("wecom_aibot", "Sending reply via response_url", map[string]any{ + "stream_id": task.StreamID, + "chat_id": msg.ChatID, + }) + if responseURL != "" { + if err := c.sendViaResponseURL(responseURL, msg.Content); err != nil { + logger.ErrorCF("wecom_aibot", "Failed to send via response_url", map[string]any{ + "error": err, + "stream_id": task.StreamID, + }) + c.removeTask(task) + return fmt.Errorf("response_url delivery failed: %w", channels.ErrSendFailed) + } + } else { + logger.WarnCF("wecom_aibot", "Stream closed but no response_url available", map[string]any{ + "stream_id": task.StreamID, + }) + } + c.removeTask(task) + return nil + } + + // Stream still open: deliver via answerCh for the next poll response. + select { + case task.answerCh <- msg.Content: + case <-task.ctx.Done(): + // Task was canceled (cleanup removed it); silently drop the reply. + return nil + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +// WebhookPath returns the path for registering on the shared HTTP server +func (c *WeComAIBotChannel) WebhookPath() string { + if c.config.WebhookPath == "" { + return "/webhook/wecom-aibot" + } + return c.config.WebhookPath +} + +// ServeHTTP implements http.Handler for the shared HTTP server +func (c *WeComAIBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path +func (c *WeComAIBotChannel) HealthPath() string { + return c.WebhookPath() + "/health" +} + +// HealthHandler handles health check requests +func (c *WeComAIBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + +// handleWebhook handles incoming webhook requests from WeCom AI Bot +func (c *WeComAIBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Log all incoming requests for debugging + logger.DebugCF("wecom_aibot", "Received webhook request", map[string]any{ + "method": r.Method, + "path": 
r.URL.Path, + "query": r.URL.RawQuery, + }) + + switch r.Method { + case http.MethodGet: + // URL verification + c.handleVerification(ctx, w, r) + case http.MethodPost: + // Message callback + c.handleMessageCallback(ctx, w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComAIBotChannel) handleVerification( + ctx context.Context, + w http.ResponseWriter, + r *http.Request, +) { + msgSignature := r.URL.Query().Get("msg_signature") + timestamp := r.URL.Query().Get("timestamp") + nonce := r.URL.Query().Get("nonce") + echostr := r.URL.Query().Get("echostr") + + logger.DebugCF("wecom_aibot", "URL verification request", map[string]any{ + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + }) + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.ErrorC("wecom_aibot", "Signature verification failed") + http.Error(w, "Signature verification failed", http.StatusUnauthorized) + return + } + + // Decrypt echostr + // For WeCom AI Bot (智能机器人), receiveid should be empty string + decrypted, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to decrypt echostr", map[string]any{ + "error": err, + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Remove BOM and whitespace as per WeCom documentation + decrypted = strings.TrimPrefix(decrypted, "\ufeff") + decrypted = strings.TrimSpace(decrypted) + + logger.InfoC("wecom_aibot", "URL verification successful") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(decrypted)) +} + +// handleMessageCallback handles incoming messages from WeCom AI Bot +func (c *WeComAIBotChannel) handleMessageCallback( + ctx context.Context, + w http.ResponseWriter, + r 
*http.Request, +) { + msgSignature := r.URL.Query().Get("msg_signature") + timestamp := r.URL.Query().Get("timestamp") + nonce := r.URL.Query().Get("nonce") + + // Read request body (limit to 4 MB to prevent memory exhaustion) + const maxBodySize = 4 << 20 // 4 MB + body, err := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1)) + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to read request body", map[string]any{ + "error": err, + }) + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + if len(body) > maxBodySize { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + return + } + + // Parse JSON body to get encrypted message + // Format: {"encrypt": "base64_encrypted_string"} + var encryptedMsg struct { + Encrypt string `json:"encrypt"` + } + if unmarshalErr := json.Unmarshal(body, &encryptedMsg); unmarshalErr != nil { + logger.ErrorCF("wecom_aibot", "Failed to parse JSON body", map[string]any{ + "error": unmarshalErr, + "body": string(body), + }) + http.Error(w, "Failed to parse JSON", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.ErrorC("wecom_aibot", "Signature verification failed") + http.Error(w, "Signature verification failed", http.StatusUnauthorized) + return + } + + // Decrypt message + // For WeCom AI Bot (智能机器人), receiveid is empty string + decrypted, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to decrypt message", map[string]any{ + "error": err, + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted JSON message + var msg WeComAIBotMessage + if unmarshalErr := json.Unmarshal([]byte(decrypted), &msg); unmarshalErr != nil { + logger.ErrorCF("wecom_aibot", "Failed to parse decrypted JSON", map[string]any{ + "error": unmarshalErr, + 
"decrypted": decrypted, + }) + http.Error(w, "Failed to parse message", http.StatusInternalServerError) + return + } + + logger.DebugCF("wecom_aibot", "Decrypted message", map[string]any{ + "msgtype": msg.MsgType, + }) + + // Process the message and get streaming response + response := c.processMessage(ctx, msg, timestamp, nonce) + + // Check if response is empty (e.g. due to unsupported message type) + if response == "" { + response = c.encryptEmptyResponse(timestamp, nonce) + } + + // Return encrypted JSON response + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(response)) +} + +// processMessage processes the received message and returns encrypted response +func (c *WeComAIBotChannel) processMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + logger.DebugCF("wecom_aibot", "Processing message", map[string]any{ + "msgtype": msg.MsgType, + }) + + switch msg.MsgType { + case "text": + return c.handleTextMessage(ctx, msg, timestamp, nonce) + case "stream": + return c.handleStreamMessage(ctx, msg, timestamp, nonce) + case "image": + return c.handleImageMessage(ctx, msg, timestamp, nonce) + case "mixed": + return c.handleMixedMessage(ctx, msg, timestamp, nonce) + case "event": + return c.handleEventMessage(ctx, msg, timestamp, nonce) + default: + logger.WarnCF("wecom_aibot", "Unsupported message type", map[string]any{ + "msgtype": msg.MsgType, + }) + return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: c.generateStreamID(), + Finish: true, + Content: "Unsupported message type: " + msg.MsgType, + }, + }) + } +} + +// handleTextMessage handles text messages by starting a new streaming task +func (c *WeComAIBotChannel) handleTextMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + if msg.Text == nil { + logger.ErrorC("wecom_aibot", "text 
message missing text field") + return c.encryptEmptyResponse(timestamp, nonce) + } + + content := msg.Text.Content + userID := msg.From.UserID + if userID == "" { + userID = "unknown" + } + + // chatID: group chat uses chatid, single chat uses userid + chatID := msg.ChatID + if chatID == "" { + chatID = userID + } + + streamID := c.generateStreamID() + + // WeCom stops sending stream-refresh callbacks after 6 minutes. + // Set a slightly shorter deadline so we can send a timeout notice before it gives up. + deadline := time.Now().Add(30 * time.Second) + + // Each task gets its own context derived from the channel lifetime context. + // Canceling taskCancel interrupts the agent goroutine when the task is removed. + taskCtx, taskCancel := context.WithCancel(c.ctx) + + task := &streamTask{ + StreamID: streamID, + ChatID: chatID, + ResponseURL: msg.ResponseURL, + Question: content, + CreatedTime: time.Now(), + Deadline: deadline, + Finished: false, + answerCh: make(chan string, 1), + ctx: taskCtx, + cancel: taskCancel, + } + + c.taskMu.Lock() + c.streamTasks[streamID] = task + c.chatTasks[chatID] = append(c.chatTasks[chatID], task) + c.taskMu.Unlock() + + // Publish to agent asynchronously; agent will call Send() with reply. + // Use task.ctx (not c.ctx) so the agent goroutine is canceled when the task is removed. 
+ go func() { + sender := bus.SenderInfo{ + Platform: "wecom_aibot", + PlatformID: userID, + CanonicalID: identity.BuildCanonicalID("wecom_aibot", userID), + DisplayName: userID, + } + peerKind := "direct" + if msg.ChatType == "group" { + peerKind = "group" + } + peer := bus.Peer{Kind: peerKind, ID: chatID} + metadata := map[string]string{ + "channel": "wecom_aibot", + "chat_type": msg.ChatType, + "msg_type": "text", + "msgid": msg.MsgID, + "aibotid": msg.AIBotID, + "stream_id": streamID, + "response_url": msg.ResponseURL, + } + c.HandleMessage(task.ctx, peer, msg.MsgID, userID, chatID, + content, nil, metadata, sender) + }() + + // Return first streaming response immediately (finish=false, content empty) + return c.getStreamResponse(task, timestamp, nonce) +} + +// handleStreamMessage handles stream polling requests +func (c *WeComAIBotChannel) handleStreamMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + if msg.Stream == nil { + logger.ErrorC("wecom_aibot", "Stream message missing stream field") + return c.encryptEmptyResponse(timestamp, nonce) + } + + streamID := msg.Stream.ID + + c.taskMu.RLock() + task, exists := c.streamTasks[streamID] + c.taskMu.RUnlock() + + if !exists { + logger.DebugCF( + "wecom_aibot", + "Stream task not found (may be from previous session)", + map[string]any{ + "stream_id": streamID, + }, + ) + return c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: streamID, + Finish: true, + Content: "Task not found or already finished. 
Please resend your message to start a new session.", + }, + }) + } + + // Get next response + return c.getStreamResponse(task, timestamp, nonce) +} + +// handleImageMessage handles image messages +func (c *WeComAIBotChannel) handleImageMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + logger.WarnC("wecom_aibot", "Image message type not yet fully implemented") + if msg.Image == nil { + logger.ErrorC("wecom_aibot", "Image message missing image field") + return c.encryptEmptyResponse(timestamp, nonce) + } + + imageURL := msg.Image.URL + + // For now, just acknowledge receipt without echoing the image + return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: c.generateStreamID(), + Finish: true, + Content: fmt.Sprintf( + "Image received (URL: %s), but image messages are not yet supported", + imageURL, + ), + }, + }) +} + +// handleMixedMessage handles mixed (text + image) messages +func (c *WeComAIBotChannel) handleMixedMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + logger.WarnC("wecom_aibot", "Mixed message type not yet fully implemented") + return c.encryptResponse("", timestamp, nonce, WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: c.generateStreamID(), + Finish: true, + Content: "Mixed message type is not yet supported", + }, + }) +} + +// handleEventMessage handles event messages +func (c *WeComAIBotChannel) handleEventMessage( + ctx context.Context, + msg WeComAIBotMessage, + timestamp, nonce string, +) string { + eventType := "" + if msg.Event != nil { + eventType = msg.Event.EventType + } + logger.DebugCF("wecom_aibot", "Received event", map[string]any{ + "event_type": eventType, + }) + + // Send welcome message when user opens the chat window + if eventType == "enter_chat" && c.config.WelcomeMessage != "" { + streamID := c.generateStreamID() + return 
c.encryptResponse(streamID, timestamp, nonce, WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: streamID, + Finish: true, + Content: c.config.WelcomeMessage, + }, + }) + } + + return c.encryptEmptyResponse(timestamp, nonce) +} + +// getStreamResponse gets the next streaming response for a task. +// - If agent replied: return finish=true with the real answer. +// - If deadline passed: return finish=true with a "please wait" notice, keep task alive for response_url. +// - Otherwise: return finish=false (empty), client will poll again. +func (c *WeComAIBotChannel) getStreamResponse(task *streamTask, timestamp, nonce string) string { + var content string + var finish bool + var closeStreamOnly bool // close stream but do NOT remove task (response_url still pending) + + select { + case answer := <-task.answerCh: + // Agent replied before deadline — normal finish. + content = answer + finish = true + default: + if time.Now().After(task.Deadline) { + // Deadline reached: close the stream with a notice, then wait for agent via response_url. + content = "⏳ Processing, please wait. The results will be sent shortly." + finish = true + closeStreamOnly = true + logger.InfoCF( + "wecom_aibot", + "Stream deadline reached, switching to response_url mode", + map[string]any{ + "stream_id": task.StreamID, + "chat_id": task.ChatID, + "response_url": task.ResponseURL != "", + }, + ) + } + // else: still waiting, return finish=false + } + + if finish && !closeStreamOnly { + // Normal finish: remove from all maps. + c.removeTask(task) + } else if closeStreamOnly { + // Mark stream as closed and remove from streamTasks under a single lock + // to keep StreamClosed/StreamClosedAt consistent with map membership. 
+ c.taskMu.Lock() + task.StreamClosed = true + task.StreamClosedAt = time.Now() + delete(c.streamTasks, task.StreamID) + c.taskMu.Unlock() + } + + response := WeComAIBotStreamResponse{ + MsgType: "stream", + Stream: WeComAIBotStreamInfo{ + ID: task.StreamID, + Finish: finish, + Content: content, + }, + } + + return c.encryptResponse(task.StreamID, timestamp, nonce, response) +} + +// removeTask removes a task from both streamTasks and chatTasks, marks it finished, +// and cancels its context to interrupt the associated agent goroutine. +func (c *WeComAIBotChannel) removeTask(task *streamTask) { + // Cancel first so the agent goroutine stops as soon as possible, + // before we acquire the write lock. + task.cancel() + + c.taskMu.Lock() + task.Finished = true // written under c.taskMu, consistent with all readers + delete(c.streamTasks, task.StreamID) + queue := c.chatTasks[task.ChatID] + for i, t := range queue { + if t == task { + c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) + break + } + } + if len(c.chatTasks[task.ChatID]) == 0 { + delete(c.chatTasks, task.ChatID) + } + c.taskMu.Unlock() +} + +// sendViaResponseURL posts a markdown reply to the WeCom response_url. +// response_url is valid for 1 hour and can only be used once per callback. +// Returned errors are wrapped with channels.ErrRateLimit, channels.ErrTemporary, +// or channels.ErrSendFailed so the manager can apply the right retry policy. 
+func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) error { + payload := map[string]any{ + "msgtype": "markdown", + "markdown": map[string]string{ + "content": content, + }, + } + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, responseURL, bytes.NewBuffer(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json; charset=utf-8") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("post to response_url failed: %w: %w", channels.ErrTemporary, err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return nil + } + + respBody, _ := io.ReadAll(resp.Body) + switch { + case resp.StatusCode == http.StatusTooManyRequests: + return fmt.Errorf("response_url rate limited (%d): %s: %w", + resp.StatusCode, respBody, channels.ErrRateLimit) + case resp.StatusCode >= 500: + return fmt.Errorf("response_url server error (%d): %s: %w", + resp.StatusCode, respBody, channels.ErrTemporary) + default: + return fmt.Errorf("response_url returned %d: %s: %w", + resp.StatusCode, respBody, channels.ErrSendFailed) + } +} + +// encryptResponse encrypts a streaming response +func (c *WeComAIBotChannel) encryptResponse( + streamID, timestamp, nonce string, + response WeComAIBotStreamResponse, +) string { + // Marshal response to JSON + plaintext, err := json.Marshal(response) + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to marshal response", map[string]any{ + "error": err, + }) + return "" + } + + logger.DebugCF("wecom_aibot", "Encrypting response", map[string]any{ + "stream_id": streamID, + "finish": response.Stream.Finish, + "preview": utils.Truncate(response.Stream.Content, 100), 
+ }) + + // Encrypt message + encrypted, err := c.encryptMessage(string(plaintext), "") + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to encrypt message", map[string]any{ + "error": err, + }) + return "" + } + + // Generate signature + signature := computeSignature(c.config.Token, timestamp, nonce, encrypted) + + // Build encrypted response + encryptedResp := WeComAIBotEncryptedResponse{ + Encrypt: encrypted, + MsgSignature: signature, + Timestamp: timestamp, + Nonce: nonce, + } + + respJSON, err := json.Marshal(encryptedResp) + if err != nil { + logger.ErrorCF("wecom_aibot", "Failed to marshal encrypted response", map[string]any{ + "error": err, + }) + return "" + } + + logger.DebugCF("wecom_aibot", "Response encrypted", map[string]any{ + "stream_id": streamID, + }) + + return string(respJSON) +} + +// encryptEmptyResponse returns a minimal valid encrypted response +func (c *WeComAIBotChannel) encryptEmptyResponse(timestamp, nonce string) string { + // Construct a zero-value stream response and encrypt it so that + // WeCom always receives a syntactically valid encrypted JSON object. 
+ emptyResp := WeComAIBotStreamResponse{} + return c.encryptResponse("", timestamp, nonce, emptyResp) +} + +// encryptMessage encrypts a plain text message for WeCom AI Bot +func (c *WeComAIBotChannel) encryptMessage(plaintext, receiveid string) (string, error) { + aesKey, err := decodeWeComAESKey(c.config.EncodingAESKey) + if err != nil { + return "", err + } + + frame, err := packWeComFrame(plaintext, receiveid) + if err != nil { + return "", err + } + + // PKCS7 padding then AES-CBC encrypt + paddedFrame := pkcs7Pad(frame, blockSize) + ciphertext, err := encryptAESCBC(aesKey, paddedFrame) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// generateStreamID generates a random stream ID +func (c *WeComAIBotChannel) generateStreamID() string { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, 10) + for i := range b { + n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + b[i] = letters[n.Int64()] + } + return string(b) +} + +// cleanupLoop periodically cleans up old streaming tasks +func (c *WeComAIBotChannel) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.cleanupOldTasks() + case <-c.ctx.Done(): + return + } + } +} + +// cleanupOldTasks removes tasks that have exceeded their expected lifetime: +// - Active tasks (in streamTasks): cleaned up after 1 hour (response_url validity window). +// - StreamClosed tasks (in chatTasks only): cleaned up after streamClosedGracePeriod. +// These tasks are waiting for the agent to call Send() via response_url. If the agent +// crashes or times out without calling Send(), we must not let them accumulate indefinitely. +// The grace period is generous enough to cover typical LLM latency but far shorter than 1 hour, +// preventing chatTasks from filling up when many requests time out in quick succession. 
+const ( + streamClosedGracePeriod = 10 * time.Minute // max wait for agent after stream closes + taskMaxLifetime = 1 * time.Hour // absolute max (≈ response_url validity) +) + +func (c *WeComAIBotChannel) cleanupOldTasks() { + c.taskMu.Lock() + defer c.taskMu.Unlock() + + now := time.Now() + cutoff := now.Add(-taskMaxLifetime) + for id, task := range c.streamTasks { + if task.CreatedTime.Before(cutoff) { + delete(c.streamTasks, id) + task.cancel() // interrupt agent goroutine still waiting for LLM + queue := c.chatTasks[task.ChatID] + for i, t := range queue { + if t == task { + c.chatTasks[task.ChatID] = append(queue[:i], queue[i+1:]...) + break + } + } + if len(c.chatTasks[task.ChatID]) == 0 { + delete(c.chatTasks, task.ChatID) + } + logger.DebugCF("wecom_aibot", "Cleaned up expired task", map[string]any{ + "stream_id": id, + }) + } + } + // Clean up StreamClosed tasks from chatTasks. + // Two expiry conditions are checked: + // 1. Absolute expiry: task was created more than taskMaxLifetime ago. + // 2. Grace expiry: stream closed more than streamClosedGracePeriod ago + // (agent had enough time to reply; it is not coming back). + for chatID, queue := range c.chatTasks { + filtered := queue[:0] + for i, t := range queue { + absoluteExpired := t.CreatedTime.Before(cutoff) + graceExpired := t.StreamClosed && + !t.StreamClosedAt.IsZero() && + t.StreamClosedAt.Before(now.Add(-streamClosedGracePeriod)) + if t.Finished { + // Finished tasks should have been removed by removeTask(). + // Finding one here (especially not at position 0) means an + // unexpected code path left it stranded, causing the queue to + // grow silently. Log a warning so it is visible, then drop it. 
+ if i > 0 { + logger.WarnCF("wecom_aibot", + "Found stranded Finished task in the middle of chatTasks queue; "+ + "this should not happen — removeTask() should have spliced it out", + map[string]any{ + "chat_id": chatID, + "stream_id": t.StreamID, + "position": i, + }) + } + // The task is already finished; its context was already canceled + // by removeTask(), so no further action is required. + continue + } else if !absoluteExpired && !graceExpired { + filtered = append(filtered, t) + } else { + t.cancel() // cancel any lingering agent goroutine + } + } + if len(filtered) == 0 { + delete(c.chatTasks, chatID) + } else { + c.chatTasks[chatID] = filtered + } + } +} + +// handleHealth handles health check requests +func (c *WeComAIBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := "ok" + if !c.IsRunning() { + status = "not running" + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": status, + }) +} diff --git a/pkg/channels/wecom/aibot_test.go b/pkg/channels/wecom/aibot_test.go new file mode 100644 index 000000000..6f0664187 --- /dev/null +++ b/pkg/channels/wecom/aibot_test.go @@ -0,0 +1,210 @@ +package wecom + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestNewWeComAIBotChannel(t *testing.T) { + t.Run("success with valid config", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + WebhookPath: "/webhook/test", + } + + messageBus := bus.NewMessageBus() + ch, err := NewWeComAIBotChannel(cfg, messageBus) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if ch == nil { + t.Fatal("Expected channel to be created") + } + + if ch.Name() != "wecom_aibot" { + t.Errorf("Expected name 'wecom_aibot', got '%s'", ch.Name()) + } + }) + + 
t.Run("error with missing token", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + EncodingAESKey: "testkey1234567890123456789012345678901234567", + } + + messageBus := bus.NewMessageBus() + _, err := NewWeComAIBotChannel(cfg, messageBus) + + if err == nil { + t.Fatal("Expected error for missing token, got nil") + } + }) + + t.Run("error with missing encoding key", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + } + + messageBus := bus.NewMessageBus() + _, err := NewWeComAIBotChannel(cfg, messageBus) + + if err == nil { + t.Fatal("Expected error for missing encoding key, got nil") + } + }) +} + +func TestWeComAIBotChannelStartStop(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + } + + messageBus := bus.NewMessageBus() + ch, err := NewWeComAIBotChannel(cfg, messageBus) + if err != nil { + t.Fatalf("Failed to create channel: %v", err) + } + + ctx := context.Background() + + // Test Start + if err := ch.Start(ctx); err != nil { + t.Fatalf("Failed to start channel: %v", err) + } + + if !ch.IsRunning() { + t.Error("Expected channel to be running") + } + + // Test Stop + if err := ch.Stop(ctx); err != nil { + t.Fatalf("Failed to stop channel: %v", err) + } + + if ch.IsRunning() { + t.Error("Expected channel to be stopped") + } +} + +func TestWeComAIBotChannelWebhookPath(t *testing.T) { + t.Run("default path", func(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + } + + messageBus := bus.NewMessageBus() + ch, _ := NewWeComAIBotChannel(cfg, messageBus) + + expectedPath := "/webhook/wecom-aibot" + if ch.WebhookPath() != expectedPath { + t.Errorf("Expected webhook path '%s', got '%s'", expectedPath, ch.WebhookPath()) + } + }) + + t.Run("custom path", func(t *testing.T) { + customPath := 
"/custom/webhook" + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + WebhookPath: customPath, + } + + messageBus := bus.NewMessageBus() + ch, _ := NewWeComAIBotChannel(cfg, messageBus) + + if ch.WebhookPath() != customPath { + t.Errorf("Expected webhook path '%s', got '%s'", customPath, ch.WebhookPath()) + } + }) +} + +func TestGenerateStreamID(t *testing.T) { + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "testkey1234567890123456789012345678901234567", + } + + messageBus := bus.NewMessageBus() + ch, _ := NewWeComAIBotChannel(cfg, messageBus) + + // Generate multiple IDs and check they are unique + ids := make(map[string]bool) + for i := 0; i < 100; i++ { + id := ch.generateStreamID() + + if len(id) != 10 { + t.Errorf("Expected stream ID length 10, got %d", len(id)) + } + + if ids[id] { + t.Errorf("Duplicate stream ID generated: %s", id) + } + ids[id] = true + } +} + +func TestEncryptDecrypt(t *testing.T) { + // Use a valid 43-character base64 key (企业微信标准格式) + cfg := config.WeComAIBotConfig{ + Enabled: true, + Token: "test_token", + EncodingAESKey: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", // 43 characters + } + + messageBus := bus.NewMessageBus() + ch, _ := NewWeComAIBotChannel(cfg, messageBus) + + plaintext := "Hello, World!" 
+ receiveid := "" + + // Encrypt + encrypted, err := ch.encryptMessage(plaintext, receiveid) + if err != nil { + t.Fatalf("Failed to encrypt message: %v", err) + } + + if encrypted == "" { + t.Fatal("Encrypted message is empty") + } + + // Decrypt + decrypted, err := decryptMessageWithVerify(encrypted, cfg.EncodingAESKey, receiveid) + if err != nil { + t.Fatalf("Failed to decrypt message: %v", err) + } + + if decrypted != plaintext { + t.Errorf("Expected decrypted message '%s', got '%s'", plaintext, decrypted) + } +} + +func TestGenerateSignature(t *testing.T) { + token := "test_token" + timestamp := "1234567890" + nonce := "test_nonce" + encrypt := "encrypted_msg" + + signature := computeSignature(token, timestamp, nonce, encrypt) + + if signature == "" { + t.Error("Generated signature is empty") + } + + // Verify signature using verifySignature function + if !verifySignature(token, signature, timestamp, nonce, encrypt) { + t.Error("Generated signature does not verify correctly") + } +} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 292a71fd2..717815b9f 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -38,8 +38,7 @@ type WeComAppChannel struct { tokenMu sync.RWMutex ctx context.Context cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex + processedMsgs *MessageDeduplicator } // WeComXMLMessage represents the XML message structure from WeCom @@ -144,7 +143,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, - processedMsgs: make(map[string]bool), + processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), }, nil } @@ -342,18 +341,11 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp return result.MediaID, nil } -// sendImageMessage sends an image message using a media_id. 
-func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { +// sendWeComMessage marshals payload and POSTs it to the WeCom message API. +func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken string, payload any) error { apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - msg := WeComImageMessage{ - ToUser: userID, - MsgType: "image", - AgentID: c.config.AgentID, - } - msg.Image.MediaID = mediaID - - jsonData, err := json.Marshal(msg) + jsonData, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal message: %w", err) } @@ -400,6 +392,17 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use return nil } +// sendImageMessage sends an image message using a media_id. +func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { + msg := WeComImageMessage{ + ToUser: userID, + MsgType: "image", + AgentID: c.config.AgentID, + } + msg.Image.MediaID = mediaID + return c.sendWeComMessage(ctx, accessToken, msg) +} + // WebhookPath returns the path for registering on the shared HTTP server. func (c *WeComAppChannel) WebhookPath() string { if c.config.WebhookPath != "" { @@ -603,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag // Message deduplication: Use msg_id to prevent duplicate processing // As per WeCom documentation, use msg_id for deduplication msgID := fmt.Sprintf("%d", msg.MsgId) - c.msgMu.Lock() - if c.processedMsgs[msgID] { - c.msgMu.Unlock() + if !c.processedMsgs.MarkMessageProcessed(msgID) { logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return } - c.processedMsgs[msgID] = true - // Clean up old messages while still holding the lock to avoid a data race - // on len(). Reset the map but re-insert the current msgID so it remains - // deduplicated. 
- if len(c.processedMsgs) > 1000 { - c.processedMsgs = make(map[string]bool) - c.processedMsgs[msgID] = true - } - c.msgMu.Unlock() senderID := msg.FromUserName chatID := senderID // WeCom App uses user ID as chat ID for direct messages @@ -722,63 +714,15 @@ func (c *WeComAppChannel) getAccessToken() string { return c.accessToken } -// sendTextMessage sends a text message to a user +// sendTextMessage sends a text message to a user. func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { - apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - msg := WeComTextMessage{ ToUser: userID, MsgType: "text", AgentID: c.config.AgentID, } msg.Text.Content = content - - jsonData, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(req) - if err != nil { - return channels.ClassifyNetError(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(body, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", 
sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil + return c.sendWeComMessage(ctx, accessToken, msg) } // handleHealth handles health check requests diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go index 0d15e955b..7f230494f 100644 --- a/pkg/channels/wecom/app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -323,60 +323,6 @@ func TestWeComAppDecryptMessage(t *testing.T) { }) } -func TestWeComAppPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7Unpad(tt.input) - if tt.expected == nil { - // This case should return an error - if err == nil { - t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7Unpad() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) - } - }) - } -} - func TestWeComAppHandleVerification(t *testing.T) { msgBus := bus.NewMessageBus() aesKey := generateTestAESKeyApp() diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 0d0426c0d..9126a847d 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -9,7 +9,6 @@ import ( "io" "net/http" "strings" - "sync" "time" 
"github.com/sipeed/picoclaw/pkg/bus" @@ -28,8 +27,7 @@ type WeComBotChannel struct { client *http.Client ctx context.Context cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex + processedMsgs *MessageDeduplicator } // WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) @@ -108,7 +106,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, - processedMsgs: make(map[string]bool), + processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), }, nil } @@ -330,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag // Message deduplication: Use msg_id to prevent duplicate processing msgID := msg.MsgID - c.msgMu.Lock() - if c.processedMsgs[msgID] { - c.msgMu.Unlock() + if !c.processedMsgs.MarkMessageProcessed(msgID) { logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return } - c.processedMsgs[msgID] = true - // Clean up old messages while still holding the lock to avoid a data race - // on len(). Reset the map but re-insert the current msgID so it remains - // deduplicated. 
- if len(c.processedMsgs) > 1000 { - c.processedMsgs = make(map[string]bool) - c.processedMsgs[msgID] = true - } - c.msgMu.Unlock() senderID := msg.From.UserID diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go index 97b503ce8..c053578b1 100644 --- a/pkg/channels/wecom/bot_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -412,22 +412,9 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - t.Run("valid direct message callback", func(t *testing.T) { - // Create JSON message for direct chat (single) - jsonMsg := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - // Encrypt message + runBotMessageCallback := func(t *testing.T, jsonMsg string) *httptest.ResponseRecorder { + t.Helper() encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper encryptedWrapper := struct { XMLName xml.Name `xml:"xml"` Encrypt string `xml:"Encrypt"` @@ -435,20 +422,29 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { Encrypt: encrypted, } wrapperData, _ := xml.Marshal(encryptedWrapper) - timestamp := "1234567890" nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest( http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData), ) w := httptest.NewRecorder() - ch.handleMessageCallback(context.Background(), w, req) + return w + } + t.Run("valid direct message callback", func(t *testing.T) { + w := runBotMessageCallback(t, `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + 
"text": {"content": "Hello World"} + }`) if w.Code != http.StatusOK { t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) } @@ -458,8 +454,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { }) t.Run("valid group message callback", func(t *testing.T) { - // Create JSON message for group chat - jsonMsg := `{ + w := runBotMessageCallback(t, `{ "msgid": "test_msg_id_456", "aibotid": "test_aibot_id", "chatid": "group_chat_id_123", @@ -468,33 +463,7 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", "msgtype": "text", "text": {"content": "Hello Group"} - }` - - // Encrypt message - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - + }`) if w.Code != http.StatusOK { t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go index 39a27d04c..6510e6f81 100644 --- a/pkg/channels/wecom/common.go +++ b/pkg/channels/wecom/common.go @@ -1,12 +1,15 @@ package wecom import ( + "bytes" "crypto/aes" "crypto/cipher" + "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/binary" "fmt" + "math/big" "sort" "strings" ) @@ -14,25 +17,23 @@ import ( // blockSize is the PKCS7 block size used by WeCom (32) const blockSize = 32 +// computeSignature computes the WeCom message signature from the given parameters. 
+// It sorts [token, timestamp, nonce, encrypt], concatenates them and returns the SHA1 hex digest. +func computeSignature(token, timestamp, nonce, encrypt string) string { + params := []string{token, timestamp, nonce, encrypt} + sort.Strings(params) + str := strings.Join(params, "") + hash := sha1.Sum([]byte(str)) + return fmt.Sprintf("%x", hash) +} + // verifySignature verifies the message signature for WeCom // This is a common function used by both WeCom Bot and WeCom App func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { if token == "" { return true // Skip verification if token is not set } - - // Sort parameters - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature + return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature } // decryptMessage decrypts the encrypted message using AES @@ -53,64 +54,128 @@ func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (s return string(decoded), nil } - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + aesKey, err := decodeWeComAESKey(encodingAESKey) if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) + return "", err } - // Decode encrypted message cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) if err != nil { return "", fmt.Errorf("failed to decode message: %w", err) } - // AES decrypt + plainText, err := decryptAESCBC(aesKey, cipherText) + if err != nil { + return "", err + } + + return unpackWeComFrame(plainText, receiveid) +} + +// decodeWeComAESKey base64-decodes the 43-character EncodingAESKey (trailing "=" is +// appended automatically) and validates that the result is exactly 32 bytes. 
+// It is the single place that handles this repeated pattern in both encrypt and decrypt paths. +func decodeWeComAESKey(encodingAESKey string) ([]byte, error) { + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return nil, fmt.Errorf("failed to decode AES key: %w", err) + } + if len(aesKey) != 32 { + return nil, fmt.Errorf("invalid AES key length: %d", len(aesKey)) + } + return aesKey, nil +} + +// encryptAESCBC encrypts plaintext using AES-CBC with the given key, mirroring +// decryptAESCBC. IV = aesKey[:aes.BlockSize]. The caller must PKCS7-pad the +// plaintext to a multiple of aes.BlockSize before calling. +func encryptAESCBC(aesKey, plaintext []byte) ([]byte, error) { block, err := aes.NewCipher(aesKey) if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) + return nil, fmt.Errorf("failed to create cipher: %w", err) } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) + ciphertext := make([]byte, len(plaintext)) + cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext) + return ciphertext, nil +} - // Remove PKCS7 padding - plainText, err = pkcs7Unpad(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) +// packWeComFrame builds the WeCom wire format: +// +// random(16 ASCII digits) + msg_len(4, big-endian) + msg + receiveid +func packWeComFrame(msg, receiveid string) ([]byte, error) { + randomBytes := make([]byte, 16) + for i := range 16 { + n, err := rand.Int(rand.Reader, big.NewInt(10)) + if err != nil { + return nil, fmt.Errorf("failed to generate random: %w", err) + } + randomBytes[i] = byte('0' + n.Int64()) } + msgBytes := []byte(msg) + msgLenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(msgLenBytes, 
uint32(len(msgBytes))) + var buf bytes.Buffer + buf.Write(randomBytes) + buf.Write(msgLenBytes) + buf.Write(msgBytes) + buf.WriteString(receiveid) + return buf.Bytes(), nil +} - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") +// unpackWeComFrame parses the WeCom wire format produced by packWeComFrame. +// If receiveid is non-empty it verifies the frame's trailing receiveid field. +func unpackWeComFrame(data []byte, receiveid string) (string, error) { + if len(data) < 20 { + return "", fmt.Errorf("decrypted frame too short: %d bytes", len(data)) } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") + msgLen := binary.BigEndian.Uint32(data[16:20]) + if int(msgLen) > len(data)-20 { + return "", fmt.Errorf("invalid message length: %d", msgLen) } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) + msg := data[20 : 20+msgLen] + if receiveid != "" && len(data) > 20+int(msgLen) { + actualReceiveID := string(data[20+msgLen:]) if actualReceiveID != receiveid { return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) } } - return string(msg), nil } +// decryptAESCBC decrypts ciphertext using AES-CBC with the given key. +// IV = aesKey[:aes.BlockSize]. PKCS7 padding is stripped from the returned plaintext. 
+func decryptAESCBC(aesKey, ciphertext []byte) ([]byte, error) { + if len(ciphertext) == 0 { + return nil, fmt.Errorf("ciphertext is empty") + } + if len(ciphertext)%aes.BlockSize != 0 { + return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size", len(ciphertext)) + } + block, err := aes.NewCipher(aesKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + iv := aesKey[:aes.BlockSize] + plaintext := make([]byte, len(ciphertext)) + cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ciphertext) + plaintext, err = pkcs7Unpad(plaintext) + if err != nil { + return nil, fmt.Errorf("failed to unpad: %w", err) + } + return plaintext, nil +} + +// pkcs7Pad adds PKCS7 padding +func pkcs7Pad(data []byte, blockSize int) []byte { + padding := blockSize - (len(data) % blockSize) + if padding == 0 { + padding = blockSize + } + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + // pkcs7Unpad removes PKCS7 padding with validation func pkcs7Unpad(data []byte) ([]byte, error) { if len(data) == 0 { diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go new file mode 100644 index 000000000..865be668e --- /dev/null +++ b/pkg/channels/wecom/dedupe.go @@ -0,0 +1,54 @@ +package wecom + +import "sync" + +const wecomMaxProcessedMessages = 1000 + +// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer) +// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest +// messages without causing "amnesia cliffs" when the limit is reached. +type MessageDeduplicator struct { + mu sync.Mutex + msgs map[string]bool + ring []string + idx int + max int +} + +// NewMessageDeduplicator creates a new deduplicator with the specified capacity. 
+func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator { + if maxEntries <= 0 { + maxEntries = wecomMaxProcessedMessages + } + return &MessageDeduplicator{ + msgs: make(map[string]bool, maxEntries), + ring: make([]string, maxEntries), + max: maxEntries, + } +} + +// MarkMessageProcessed marks msgID as processed and returns false for duplicates. +func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // 1. Check for duplicate + if d.msgs[msgID] { + return false + } + + // 2. Evict the oldest message at our current ring position (if any) + oldestID := d.ring[d.idx] + if oldestID != "" { + delete(d.msgs, oldestID) + } + + // 3. Store the new message + d.msgs[msgID] = true + d.ring[d.idx] = msgID + + // 4. Advance the circle queue index + d.idx = (d.idx + 1) % d.max + + return true +} diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go new file mode 100644 index 000000000..10dff4cfe --- /dev/null +++ b/pkg/channels/wecom/dedupe_test.go @@ -0,0 +1,83 @@ +package wecom + +import ( + "sync" + "testing" +) + +func TestMessageDeduplicator_DuplicateDetection(t *testing.T) { + d := NewMessageDeduplicator(wecomMaxProcessedMessages) + + if ok := d.MarkMessageProcessed("msg-1"); !ok { + t.Fatalf("first message should be accepted") + } + + if ok := d.MarkMessageProcessed("msg-1"); ok { + t.Fatalf("duplicate message should be rejected") + } +} + +func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) { + d := NewMessageDeduplicator(wecomMaxProcessedMessages) + + const goroutines = 64 + var wg sync.WaitGroup + wg.Add(goroutines) + + results := make(chan bool, goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + results <- d.MarkMessageProcessed("msg-concurrent") + }() + } + + wg.Wait() + close(results) + + successes := 0 + for ok := range results { + if ok { + successes++ + } + } + + if successes != 1 { + t.Fatalf("expected exactly 1 successful 
mark, got %d", successes) + } +} + +func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) { + // Create a deduplicator with a very small capacity to test eviction easily. + capacity := 3 + d := NewMessageDeduplicator(capacity) + + // Fill the queue. + d.MarkMessageProcessed("msg-1") + d.MarkMessageProcessed("msg-2") + d.MarkMessageProcessed("msg-3") + + // At this point, the queue is full. msg-1 is the oldest. + if len(d.msgs) != 3 { + t.Fatalf("expected map size to be 3, got %d", len(d.msgs)) + } + + // This should evict msg-1 and add msg-4. + if ok := d.MarkMessageProcessed("msg-4"); !ok { + t.Fatalf("msg-4 should be accepted") + } + + if len(d.msgs) != 3 { + t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs)) + } + + // msg-1 should now be forgotten (evicted). + if ok := d.MarkMessageProcessed("msg-1"); !ok { + t.Fatalf("msg-1 should be accepted again because it was evicted") + } + + // msg-2 should have been evicted when we added msg-1 back. + if ok := d.MarkMessageProcessed("msg-2"); !ok { + t.Fatalf("msg-2 should be accepted again because it was evicted") + } +} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go index 3ef1ecdf3..bc5a70fa3 100644 --- a/pkg/channels/wecom/init.go +++ b/pkg/channels/wecom/init.go @@ -13,4 +13,7 @@ func init() { channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { return NewWeComAppChannel(cfg.Channels.WeComApp, b) }) + channels.RegisterFactory("wecom_aibot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComAIBotChannel(cfg.Channels.WeComAIBot, b) + }) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 779928574..55f4e34fa 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -180,6 +180,18 @@ type AgentDefaults struct { MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature *float64 `json:"temperature,omitempty" 
env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` +} + +const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB + +func (d *AgentDefaults) GetMaxMediaSize() int { + if d.MaxMediaSize > 0 { + return d.MaxMediaSize + } + return DefaultMaxMediaSize } // GetModelName returns the effective model name for the agent defaults. @@ -192,19 +204,20 @@ func (d *AgentDefaults) GetModelName() string { } type ChannelsConfig struct { - WhatsApp WhatsAppConfig `json:"whatsapp"` - Telegram TelegramConfig `json:"telegram"` - Feishu FeishuConfig `json:"feishu"` - Discord DiscordConfig `json:"discord"` - MaixCam MaixCamConfig `json:"maixcam"` - QQ QQConfig `json:"qq"` - DingTalk DingTalkConfig `json:"dingtalk"` - Slack SlackConfig `json:"slack"` - LINE LINEConfig `json:"line"` - OneBot OneBotConfig `json:"onebot"` - WeCom WeComConfig `json:"wecom"` - WeComApp WeComAppConfig `json:"wecom_app"` - Pico PicoConfig `json:"pico"` + WhatsApp WhatsAppConfig `json:"whatsapp"` + Telegram TelegramConfig `json:"telegram"` + Feishu FeishuConfig `json:"feishu"` + Discord DiscordConfig `json:"discord"` + MaixCam MaixCamConfig `json:"maixcam"` + QQ QQConfig `json:"qq"` + DingTalk DingTalkConfig `json:"dingtalk"` + Slack SlackConfig `json:"slack"` + LINE LINEConfig `json:"line"` + OneBot OneBotConfig `json:"onebot"` + WeCom WeComConfig `json:"wecom"` + WeComApp WeComAppConfig `json:"wecom_app"` + WeComAIBot WeComAIBotConfig `json:"wecom_aibot"` + Pico PicoConfig `json:"pico"` } // GroupTriggerConfig controls when the bot responds in group chats. 
@@ -236,6 +249,7 @@ type WhatsAppConfig struct { type TelegramConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"` Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -252,12 +266,14 @@ type FeishuConfig struct { VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } type DiscordConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -360,6 +376,18 @@ type WeComAppConfig struct { ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` } +type WeComAIBotConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ENCODING_AES_KEY"` + WebhookPath string `json:"webhook_path" 
env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REPLY_TIMEOUT"` + MaxSteps int `json:"max_steps" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_MAX_STEPS"` // Maximum streaming steps + WelcomeMessage string `json:"welcome_message" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_WELCOME_MESSAGE"` // Sent on enter_chat event; empty = no welcome + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_AIBOT_REASONING_CHANNEL_ID"` +} + type PicoConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` @@ -386,6 +414,7 @@ type DevicesConfig struct { type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` OpenAI OpenAIProviderConfig `json:"openai"` + LiteLLM ProviderConfig `json:"litellm"` OpenRouter ProviderConfig `json:"openrouter"` Groq ProviderConfig `json:"groq"` Zhipu ProviderConfig `json:"zhipu"` @@ -410,6 +439,7 @@ type ProvidersConfig struct { func (p ProvidersConfig) IsEmpty() bool { return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" && p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" && + p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" && p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" && p.Groq.APIKey == "" && p.Groq.APIBase == "" && p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" && @@ -519,11 +549,22 @@ type PerplexityConfig struct { MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } +type GLMSearchConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"` + // SearchEngine specifies the search backend: "search_std" (default), + // "search_pro", 
"search_pro_sogou", or "search_pro_quark". + SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"` +} + type WebToolsConfig struct { Brave BraveConfig `json:"brave"` Tavily TavilyConfig `json:"tavily"` DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` Perplexity PerplexityConfig `json:"perplexity"` + GLMSearch GLMSearchConfig `json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` @@ -554,6 +595,7 @@ type ToolsConfig struct { Exec ExecConfig `json:"exec"` Skills SkillsToolsConfig `json:"skills"` MediaCleanup MediaCleanupConfig `json:"media_cleanup"` + MCP MCPConfig `json:"mcp"` } type SkillsToolsConfig struct { @@ -583,6 +625,34 @@ type ClawHubRegistryConfig struct { MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"` } +// MCPServerConfig defines configuration for a single MCP server +type MCPServerConfig struct { + // Enabled indicates whether this MCP server is active + Enabled bool `json:"enabled"` + // Command is the executable to run (e.g., "npx", "python", "/path/to/server") + Command string `json:"command"` + // Args are the arguments to pass to the command + Args []string `json:"args,omitempty"` + // Env are environment variables to set for the server process (stdio only) + Env map[string]string `json:"env,omitempty"` + // EnvFile is the path to a file containing environment variables (stdio only) + EnvFile string `json:"env_file,omitempty"` + // Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set) + Type string `json:"type,omitempty"` + // URL is used for SSE/HTTP transport + URL string `json:"url,omitempty"` + // Headers are HTTP headers to 
send with requests (sse/http only) + Headers map[string]string `json:"headers,omitempty"` +} + +// MCPConfig defines configuration for all MCP servers +type MCPConfig struct { + // Enabled globally enables/disables MCP integration + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + // Servers is a map of server name to server configuration + Servers map[string]MCPServerConfig `json:"servers,omitempty"` +} + func LoadConfig(path string) (*Config, error) { cfg := DefaultConfig() @@ -639,7 +709,8 @@ func (c *Config) migrateChannelConfigs() { } // OneBot: group_trigger_prefix -> group_trigger.prefixes - if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && + len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix } } @@ -749,6 +820,7 @@ func (c *Config) findMatches(modelName string) []ModelConfig { // HasProvidersConfig checks if any provider in the old providers config has configuration. func (c *Config) HasProvidersConfig() bool { +<<<<<<< HEAD v := c.Providers return v.Anthropic.APIKey != "" || v.Anthropic.APIBase != "" || v.OpenAI.APIKey != "" || v.OpenAI.APIBase != "" || @@ -769,6 +841,9 @@ func (c *Config) HasProvidersConfig() bool { v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" || v.Qwen.APIKey != "" || v.Qwen.APIBase != "" || v.Mistral.APIKey != "" || v.Mistral.APIBase != "" +======= + return !c.Providers.IsEmpty() +>>>>>>> origin_picoclaw/main } // ValidateModelList validates all ModelConfig entries in the model_list. 
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 6af7c209e..10ebc7c90 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -435,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) { } // TestDefaultConfig_DMScope verifies the default dm_scope value +// TestDefaultConfig_SummarizationThresholds verifies summarization defaults +func TestDefaultConfig_SummarizationThresholds(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 { + t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold) + } + if cfg.Agents.Defaults.SummarizeTokenPercent != 75 { + t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent) + } +} + func TestDefaultConfig_DMScope(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 385c2f653..518d3421d 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -26,13 +26,15 @@ func DefaultConfig() *Config { return &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ - Workspace: workspacePath, - RestrictToWorkspace: true, - Provider: "", - Model: "", - MaxTokens: 32768, - Temperature: nil, // nil means use provider default - MaxToolIterations: 50, + Workspace: workspacePath, + RestrictToWorkspace: true, + Provider: "", + Model: "", + MaxTokens: 32768, + Temperature: nil, // nil means use provider default + MaxToolIterations: 50, + SummarizeMessageThreshold: 20, + SummarizeTokenPercent: 75, }, }, Bindings: []AgentBinding{}, @@ -137,6 +139,16 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + WeComAIBot: WeComAIBotConfig{ + Enabled: false, + Token: "", + EncodingAESKey: "", + WebhookPath: "/webhook/wecom-aibot", + AllowFrom: FlexibleStringSlice{}, + ReplyTimeout: 5, + MaxSteps: 10, + WelcomeMessage: "Hello! I'm your AI assistant. 
How can I help you today?", + }, Pico: PicoConfig{ Enabled: false, Token: "", @@ -339,6 +351,13 @@ func DefaultConfig() *Config { APIKey: "", MaxResults: 5, }, + GLMSearch: GLMSearchConfig{ + Enabled: false, + APIKey: "", + BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search", + SearchEngine: "search_std", + MaxResults: 5, + }, }, Cron: CronToolsConfig{ ExecTimeoutMinutes: 5, @@ -359,6 +378,10 @@ func DefaultConfig() *Config { TTLSeconds: 300, }, }, + MCP: MCPConfig{ + Enabled: false, + Servers: map[string]MCPServerConfig{}, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/config/migration.go b/pkg/config/migration.go index e1e0fe0d5..111efa722 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"litellm"}, + protocol: "litellm", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "litellm", + Model: "litellm/auto", + APIKey: p.LiteLLM.APIKey, + APIBase: p.LiteLLM.APIBase, + Proxy: p.LiteLLM.Proxy, + RequestTimeout: p.LiteLLM.RequestTimeout, + }, true + }, + }, { providerNames: []string{"openrouter"}, protocol: "openrouter", diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 7070db4be..6106a5b17 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) { } } +func TestConvertProvidersToModelList_LiteLLM(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + LiteLLM: ProviderConfig{ + APIKey: "litellm-key", + APIBase: "http://localhost:4000/v1", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "litellm" { + 
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm") + } + if result[0].Model != "litellm/auto" { + t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto") + } + if result[0].APIBase != "http://localhost:4000/v1" { + t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1") + } +} + func TestConvertProvidersToModelList_Multiple(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ @@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}}, + LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"}, Anthropic: ProviderConfig{APIKey: "key2"}, OpenRouter: ProviderConfig{APIKey: "key3"}, Groq: ProviderConfig{APIKey: "key4"}, diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go index a7aef8c3a..3b7eeeefb 100644 --- a/pkg/heartbeat/service_test.go +++ b/pkg/heartbeat/service_test.go @@ -47,79 +47,63 @@ func TestExecuteHeartbeat_Async(t *testing.T) { } } -func TestExecuteHeartbeat_Error(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - hs := NewHeartbeatService(tmpDir, 30, true) - hs.stopChan = make(chan struct{}) // Enable for testing - - hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - return &tools.ToolResult{ - ForLLM: "Heartbeat failed: connection error", - ForUser: "", - Silent: false, - IsError: true, - Async: false, - } - }) - - // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) - - hs.executeHeartbeat() - - // Check log file for error message - logFile := filepath.Join(tmpDir, "heartbeat.log") - data, err := os.ReadFile(logFile) - if err != nil { - t.Fatalf("Failed to read log file: %v", err) +func 
TestExecuteHeartbeat_ResultLogging(t *testing.T) { + tests := []struct { + name string + result *tools.ToolResult + wantLog string + }{ + { + name: "error result", + result: &tools.ToolResult{ + ForLLM: "Heartbeat failed: connection error", + ForUser: "", + Silent: false, + IsError: true, + Async: false, + }, + wantLog: "error message", + }, + { + name: "silent result", + result: &tools.ToolResult{ + ForLLM: "Heartbeat completed successfully", + ForUser: "", + Silent: true, + IsError: false, + Async: false, + }, + wantLog: "completion message", + }, } - logContent := string(data) - if logContent == "" { - t.Error("Expected log file to contain error message") - } -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) -func TestExecuteHeartbeat_Silent(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing - hs := NewHeartbeatService(tmpDir, 30, true) - hs.stopChan = make(chan struct{}) // Enable for testing + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return tt.result + }) - hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { - return &tools.ToolResult{ - ForLLM: "Heartbeat completed successfully", - ForUser: "", - Silent: true, - IsError: false, - Async: false, - } - }) + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) + hs.executeHeartbeat() - // Create HEARTBEAT.md - os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644) - - hs.executeHeartbeat() - - // Check log file for completion message - logFile := filepath.Join(tmpDir, "heartbeat.log") - data, err := os.ReadFile(logFile) - 
if err != nil { - t.Fatalf("Failed to read log file: %v", err) - } - - logContent := string(data) - if logContent == "" { - t.Error("Expected log file to contain completion message") + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + if string(data) == "" { + t.Errorf("Expected log file to contain %s", tt.wantLog) + } + }) } } diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go new file mode 100644 index 000000000..7b63cc979 --- /dev/null +++ b/pkg/mcp/manager.go @@ -0,0 +1,532 @@ +package mcp + +import ( + "bufio" + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "sync/atomic" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// headerTransport is an http.RoundTripper that adds custom headers to requests +type headerTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original + req = req.Clone(req.Context()) + + // Add custom headers + for key, value := range t.headers { + req.Header.Set(key, value) + } + + // Use the base transport + base := t.base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +// loadEnvFile loads environment variables from a file in .env format +// Each line should be in the format: KEY=value +// Lines starting with # are comments +// Empty lines are ignored +func loadEnvFile(path string) (map[string]string, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open env file: %w", err) + } + defer file.Close() + + envVars := make(map[string]string) + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := 
strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse KEY=value + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum) + } + + // Remove surrounding quotes if present + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + + envVars[key] = value + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading env file: %w", err) + } + + return envVars, nil +} + +// ServerConnection represents a connection to an MCP server +type ServerConnection struct { + Name string + Client *mcp.Client + Session *mcp.ClientSession + Tools []*mcp.Tool +} + +// Manager manages multiple MCP server connections +type Manager struct { + servers map[string]*ServerConnection + mu sync.RWMutex + closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race + wg sync.WaitGroup // tracks in-flight CallTool calls +} + +// NewManager creates a new MCP manager +func NewManager() *Manager { + return &Manager{ + servers: make(map[string]*ServerConnection), + } +} + +// LoadFromConfig loads MCP servers from configuration +func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error { + return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath()) +} + +// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path. +// This is the minimal dependency version that doesn't require the full Config object. 
+func (m *Manager) LoadFromMCPConfig( + ctx context.Context, + mcpCfg config.MCPConfig, + workspacePath string, +) error { + if !mcpCfg.Enabled { + logger.InfoCF("mcp", "MCP integration is disabled", nil) + return nil + } + + if len(mcpCfg.Servers) == 0 { + logger.InfoCF("mcp", "No MCP servers configured", nil) + return nil + } + + logger.InfoCF("mcp", "Initializing MCP servers", + map[string]any{ + "count": len(mcpCfg.Servers), + }) + + var wg sync.WaitGroup + errs := make(chan error, len(mcpCfg.Servers)) + enabledCount := 0 + + for name, serverCfg := range mcpCfg.Servers { + if !serverCfg.Enabled { + logger.DebugCF("mcp", "Skipping disabled server", + map[string]any{ + "server": name, + }) + continue + } + + enabledCount++ + wg.Add(1) + go func(name string, serverCfg config.MCPServerConfig, workspace string) { + defer wg.Done() + + // Resolve relative envFile paths relative to workspace + if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) { + if workspace == "" { + err := fmt.Errorf( + "workspace path is empty while resolving relative envFile %q for server %s", + serverCfg.EnvFile, + name, + ) + logger.ErrorCF("mcp", "Invalid MCP server configuration", + map[string]any{ + "server": name, + "env_file": serverCfg.EnvFile, + "error": err.Error(), + }) + errs <- err + return + } + serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile) + } + + if err := m.ConnectServer(ctx, name, serverCfg); err != nil { + logger.ErrorCF("mcp", "Failed to connect to MCP server", + map[string]any{ + "server": name, + "error": err.Error(), + }) + errs <- fmt.Errorf("failed to connect to server %s: %w", name, err) + } + }(name, serverCfg, workspacePath) + } + + wg.Wait() + close(errs) + + // Collect errors + var allErrors []error + for err := range errs { + allErrors = append(allErrors, err) + } + + connectedCount := len(m.GetServers()) + + // If all enabled servers failed to connect, return aggregated error + if enabledCount > 0 && connectedCount == 0 { + 
logger.ErrorCF("mcp", "All MCP servers failed to connect", + map[string]any{ + "failed": len(allErrors), + "total": enabledCount, + }) + return errors.Join(allErrors...) + } + + if len(allErrors) > 0 { + logger.WarnCF("mcp", "Some MCP servers failed to connect", + map[string]any{ + "failed": len(allErrors), + "connected": connectedCount, + "total": enabledCount, + }) + // Don't fail completely if some servers successfully connected + } + + logger.InfoCF("mcp", "MCP server initialization complete", + map[string]any{ + "connected": connectedCount, + "total": enabledCount, + }) + + return nil +} + +// ConnectServer connects to a single MCP server +func (m *Manager) ConnectServer( + ctx context.Context, + name string, + cfg config.MCPServerConfig, +) error { + logger.InfoCF("mcp", "Connecting to MCP server", + map[string]any{ + "server": name, + "command": cfg.Command, + "args_count": len(cfg.Args), + }) + + // Create client + client := mcp.NewClient(&mcp.Implementation{ + Name: "picoclaw", + Version: "1.0.0", + }, nil) + + // Create transport based on configuration + // Auto-detect transport type if not explicitly specified + var transport mcp.Transport + transportType := cfg.Type + + // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio + if transportType == "" { + if cfg.URL != "" { + transportType = "sse" + } else if cfg.Command != "" { + transportType = "stdio" + } else { + return fmt.Errorf("either URL or command must be provided") + } + } + + switch transportType { + case "sse", "http": + if cfg.URL == "" { + return fmt.Errorf("URL is required for SSE/HTTP transport") + } + logger.DebugCF("mcp", "Using SSE/HTTP transport", + map[string]any{ + "server": name, + "url": cfg.URL, + }) + + sseTransport := &mcp.StreamableClientTransport{ + Endpoint: cfg.URL, + } + + // Add custom headers if provided + if len(cfg.Headers) > 0 { + // Create a custom HTTP client with header-injecting transport + sseTransport.HTTPClient = &http.Client{ + 
Transport: &headerTransport{ + base: http.DefaultTransport, + headers: cfg.Headers, + }, + } + logger.DebugCF("mcp", "Added custom HTTP headers", + map[string]any{ + "server": name, + "header_count": len(cfg.Headers), + }) + } + + transport = sseTransport + case "stdio": + if cfg.Command == "" { + return fmt.Errorf("command is required for stdio transport") + } + logger.DebugCF("mcp", "Using stdio transport", + map[string]any{ + "server": name, + "command": cfg.Command, + }) + // Create command with context + cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...) + + // Build environment variables with proper override semantics + // Use a map to ensure config variables override file variables + envMap := make(map[string]string) + + // Start with parent process environment + for _, e := range cmd.Environ() { + if idx := strings.Index(e, "="); idx > 0 { + envMap[e[:idx]] = e[idx+1:] + } + } + + // Load environment variables from file if specified + if cfg.EnvFile != "" { + envVars, err := loadEnvFile(cfg.EnvFile) + if err != nil { + return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) + } + for k, v := range envVars { + envMap[k] = v + } + logger.DebugCF("mcp", "Loaded environment variables from file", + map[string]any{ + "server": name, + "envFile": cfg.EnvFile, + "var_count": len(envVars), + }) + } + + // Environment variables from config override those from file + for k, v := range cfg.Env { + envMap[k] = v + } + + // Convert map to slice + env := make([]string, 0, len(envMap)) + for k, v := range envMap { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + cmd.Env = env + + transport = &mcp.CommandTransport{Command: cmd} + default: + return fmt.Errorf( + "unsupported transport type: %s (supported: stdio, sse, http)", + transportType, + ) + } + + // Connect to server + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + // Get server info + initResult := 
session.InitializeResult() + logger.InfoCF("mcp", "Connected to MCP server", + map[string]any{ + "server": name, + "serverName": initResult.ServerInfo.Name, + "serverVersion": initResult.ServerInfo.Version, + "protocol": initResult.ProtocolVersion, + }) + + // List available tools if supported + var tools []*mcp.Tool + if initResult.Capabilities.Tools != nil { + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + logger.WarnCF("mcp", "Error listing tool", + map[string]any{ + "server": name, + "error": err.Error(), + }) + continue + } + tools = append(tools, tool) + } + + logger.InfoCF("mcp", "Listed tools from MCP server", + map[string]any{ + "server": name, + "toolCount": len(tools), + }) + } + + // Store connection + m.mu.Lock() + m.servers[name] = &ServerConnection{ + Name: name, + Client: client, + Session: session, + Tools: tools, + } + m.mu.Unlock() + + return nil +} + +// GetServers returns all connected servers +func (m *Manager) GetServers() map[string]*ServerConnection { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]*ServerConnection, len(m.servers)) + for k, v := range m.servers { + result[k] = v + } + return result +} + +// GetServer returns a specific server connection +func (m *Manager) GetServer(name string) (*ServerConnection, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + conn, ok := m.servers[name] + return conn, ok +} + +// CallTool calls a tool on a specific server +func (m *Manager) CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, +) (*mcp.CallToolResult, error) { + // Check if closed before acquiring lock (fast path) + if m.closed.Load() { + return nil, fmt.Errorf("manager is closed") + } + + m.mu.RLock() + // Double-check after acquiring lock to prevent TOCTOU race + if m.closed.Load() { + m.mu.RUnlock() + return nil, fmt.Errorf("manager is closed") + } + conn, ok := m.servers[serverName] + if ok { + m.wg.Add(1) // Add to WaitGroup while holding the lock + } 
+ m.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("server %s not found", serverName) + } + defer m.wg.Done() + + params := &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + } + + result, err := conn.Session.CallTool(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to call tool: %w", err) + } + + return result, nil +} + +// Close closes all server connections +func (m *Manager) Close() error { + // Use Swap to atomically set closed=true and get the previous value + // This prevents TOCTOU race with CallTool's closed check + if m.closed.Swap(true) { + return nil // already closed + } + + // Wait for all in-flight CallTool calls to finish before closing sessions + // After closed=true is set, no new CallTool can start (they check closed first) + m.wg.Wait() + + m.mu.Lock() + defer m.mu.Unlock() + + logger.InfoCF("mcp", "Closing all MCP server connections", + map[string]any{ + "count": len(m.servers), + }) + + var errs []error + for name, conn := range m.servers { + if err := conn.Session.Close(); err != nil { + logger.ErrorCF("mcp", "Failed to close server connection", + map[string]any{ + "server": name, + "error": err.Error(), + }) + errs = append(errs, fmt.Errorf("server %s: %w", name, err)) + } + } + + m.servers = make(map[string]*ServerConnection) + + if len(errs) > 0 { + return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...)) + } + + return nil +} + +// GetAllTools returns all tools from all connected servers +func (m *Manager) GetAllTools() map[string][]*mcp.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string][]*mcp.Tool) + for name, conn := range m.servers { + if len(conn.Tools) > 0 { + result[name] = conn.Tools + } + } + return result +} diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go new file mode 100644 index 000000000..8ce81d09e --- /dev/null +++ b/pkg/mcp/manager_test.go @@ -0,0 +1,298 @@ +package mcp + +import ( + "context" + "os" + "path/filepath" + 
"strings" + "testing" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestLoadEnvFile(t *testing.T) { + tests := []struct { + name string + content string + expected map[string]string + expectErr bool + }{ + { + name: "basic env file", + content: `API_KEY=secret123 +DATABASE_URL=postgres://localhost/db +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with comments and empty lines", + content: `# This is a comment +API_KEY=secret123 + +# Another comment +DATABASE_URL=postgres://localhost/db + +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with quoted values", + content: `API_KEY="secret with spaces" +NAME='single quoted' +PLAIN=no-quotes`, + expected: map[string]string{ + "API_KEY": "secret with spaces", + "NAME": "single quoted", + "PLAIN": "no-quotes", + }, + expectErr: false, + }, + { + name: "with spaces around equals", + content: `API_KEY = secret123 +DATABASE_URL= postgres://localhost/db +PORT =8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "invalid format - no equals", + content: `INVALID_LINE`, + expectErr: true, + }, + { + name: "empty file", + content: ``, + expected: map[string]string{}, + expectErr: false, + }, + { + name: "only comments", + content: `# Comment 1 +# Comment 2`, + expected: map[string]string{}, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + result, err := 
loadEnvFile(envFile) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result)) + } + + for key, expectedValue := range tt.expected { + if actualValue, ok := result[key]; !ok { + t.Errorf("Expected key %s not found", key) + } else if actualValue != expectedValue { + t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue) + } + } + }) + } +} + +func TestLoadEnvFileNotFound(t *testing.T) { + _, err := loadEnvFile("/nonexistent/file.env") + if err == nil { + t.Error("Expected error for nonexistent file") + } +} + +func TestEnvFilePriority(t *testing.T) { + // Create a temporary .env file + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + envContent := `API_KEY=from_file +DATABASE_URL=from_file +SHARED_VAR=from_file` + + if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil { + t.Fatalf("Failed to create .env file: %v", err) + } + + // Load envFile + envVars, err := loadEnvFile(envFile) + if err != nil { + t.Fatalf("Failed to load env file: %v", err) + } + + // Verify envFile variables + if envVars["API_KEY"] != "from_file" { + t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"]) + } + + // Simulate config.Env overriding envFile + configEnv := map[string]string{ + "SHARED_VAR": "from_config", + "NEW_VAR": "from_config", + } + + // Merge: envFile first, then config overrides + merged := make(map[string]string) + for k, v := range envVars { + merged[k] = v + } + for k, v := range configEnv { + merged[k] = v + } + + // Verify priority: config.Env should override envFile + if merged["SHARED_VAR"] != "from_config" { + t.Errorf( + "Expected SHARED_VAR=from_config (config should override file), got %s", + merged["SHARED_VAR"], + ) + } + if merged["API_KEY"] != "from_file" { + 
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"]) + } + if merged["NEW_VAR"] != "from_config" { + t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"]) + } +} + +func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { + mgr := NewManager() + + mcpCfg := config.MCPConfig{ + Enabled: true, + Servers: map[string]config.MCPServerConfig{ + "test-server": { + Enabled: true, + Command: "echo", + Args: []string{"ok"}, + EnvFile: ".env", + }, + }, + } + + err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "") + if err == nil { + t.Fatal("expected error for relative env_file with empty workspace path, got nil") + } + + if !strings.Contains(err.Error(), "workspace path is empty") { + t.Fatalf("expected workspace path validation error, got: %v", err) + } +} + +func TestNewManager_InitialState(t *testing.T) { + mgr := NewManager() + if mgr == nil { + t.Fatal("expected manager instance, got nil") + } + if len(mgr.GetServers()) != 0 { + t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers())) + } +} + +func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { + mgr := NewManager() + + err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp") + if err != nil { + t.Fatalf("expected nil error when MCP disabled, got: %v", err) + } + + err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp") + if err != nil { + t.Fatalf("expected nil error when no servers configured, got: %v", err) + } +} + +func TestGetServers_ReturnsCopy(t *testing.T) { + mgr := NewManager() + mgr.servers["s1"] = &ServerConnection{Name: "s1"} + + servers := mgr.GetServers() + delete(servers, "s1") + + if _, ok := mgr.GetServer("s1"); !ok { + t.Fatal("expected internal manager state to remain unchanged") + } +} + +func TestGetAllTools_FiltersEmptyTools(t *testing.T) { + mgr := NewManager() + mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil} + 
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}} + + all := mgr.GetAllTools() + if _, ok := all["empty"]; ok { + t.Fatal("expected server without tools to be excluded") + } + if _, ok := all["with-tools"]; !ok { + t.Fatal("expected server with tools to be included") + } +} + +func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { + t.Run("manager closed", func(t *testing.T) { + mgr := NewManager() + mgr.closed.Store(true) + + _, err := mgr.CallTool(context.Background(), "s1", "tool", nil) + if err == nil || !strings.Contains(err.Error(), "manager is closed") { + t.Fatalf("expected manager closed error, got: %v", err) + } + }) + + t.Run("server missing", func(t *testing.T) { + mgr := NewManager() + + _, err := mgr.CallTool(context.Background(), "missing", "tool", nil) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected server not found error, got: %v", err) + } + }) +} + +func TestClose_IdempotentOnEmptyManager(t *testing.T) { + mgr := NewManager() + + if err := mgr.Close(); err != nil { + t.Fatalf("first close should succeed, got: %v", err) + } + if err := mgr.Close(); err != nil { + t.Fatalf("second close should be idempotent, got: %v", err) + } +} diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go new file mode 100644 index 000000000..e12e2c5ab --- /dev/null +++ b/pkg/memory/jsonl.go @@ -0,0 +1,460 @@ +package memory + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "hash/fnv" + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/fileutil" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const ( + // numLockShards is the fixed number of mutexes used to serialize + // per-session access. Using a sharded array instead of a map keeps + // memory bounded regardless of how many sessions are created over + // the lifetime of the process — important for a long-running daemon. 
+ numLockShards = 64 + + // maxLineSize is the maximum size of a single JSON line in a .jsonl + // file. Tool results (read_file, web search, etc.) can be large, so + // we set a generous limit. The scanner starts at 64 KB and grows + // only as needed up to this cap. + maxLineSize = 10 * 1024 * 1024 // 10 MB +) + +// sessionMeta holds per-session metadata stored in a .meta.json file. +type sessionMeta struct { + Key string `json:"key"` + Summary string `json:"summary"` + Skip int `json:"skip"` + Count int `json:"count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// JSONLStore implements Store using append-only JSONL files. +// +// Each session is stored as two files: +// +// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only +// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset) +// +// Messages are never physically deleted from the JSONL file. Instead, +// TruncateHistory records a "skip" offset in the metadata file and +// GetHistory ignores lines before that offset. This keeps all writes +// append-only, which is both fast and crash-safe. +type JSONLStore struct { + dir string + locks [numLockShards]sync.Mutex +} + +// NewJSONLStore creates a new JSONL-backed store rooted at dir. +func NewJSONLStore(dir string) (*JSONLStore, error) { + err := os.MkdirAll(dir, 0o755) + if err != nil { + return nil, fmt.Errorf("memory: create directory: %w", err) + } + return &JSONLStore{dir: dir}, nil +} + +// sessionLock returns a mutex for the given session key. +// Keys are mapped to a fixed pool of shards via FNV hash, so +// memory usage is O(1) regardless of total session count. 
+func (s *JSONLStore) sessionLock(key string) *sync.Mutex { + h := fnv.New32a() + h.Write([]byte(key)) + return &s.locks[h.Sum32()%numLockShards] +} + +func (s *JSONLStore) jsonlPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".jsonl") +} + +func (s *JSONLStore) metaPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".meta.json") +} + +// sanitizeKey converts a session key to a safe filename component. +// Mirrors pkg/session.sanitizeFilename so that migration paths match. +// +// Note: this is a lossy mapping — "telegram:123" and "telegram_123" +// both produce the same filename. This is an intentional tradeoff: +// keys with colons (e.g. from channels) are by far the common case, +// and a bidirectional encoding (like URL-encoding) would complicate +// file listings and debugging. +func sanitizeKey(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + +// readMeta loads the metadata file for a session. +// Returns a zero-value sessionMeta if the file does not exist. +func (s *JSONLStore) readMeta(key string) (sessionMeta, error) { + data, err := os.ReadFile(s.metaPath(key)) + if os.IsNotExist(err) { + return sessionMeta{Key: key}, nil + } + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err) + } + var meta sessionMeta + err = json.Unmarshal(data, &meta) + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err) + } + return meta, nil +} + +// writeMeta atomically writes the metadata file using the project's +// standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error { + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return fmt.Errorf("memory: encode meta: %w", err) + } + return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644) +} + +// readMessages reads valid JSON lines from a .jsonl file, skipping +// the first `skip` lines without unmarshaling them. 
This avoids the +// cost of json.Unmarshal on logically truncated messages. +// Malformed trailing lines (e.g. from a crash) are silently skipped. +func readMessages(path string, skip int) ([]providers.Message, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return []providers.Message{}, nil + } + if err != nil { + return nil, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + var msgs []providers.Message + scanner := bufio.NewScanner(f) + // Allow large lines for tool results (read_file, web search, etc.). + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + lineNum := 0 + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + lineNum++ + if lineNum <= skip { + continue + } + var msg providers.Message + if err := json.Unmarshal(line, &msg); err != nil { + // Corrupt line — likely a partial write from a crash. + // Log so operators know data was skipped, but don't + // fail the entire read; this is the standard JSONL + // recovery pattern. + log.Printf("memory: skipping corrupt line %d in %s: %v", + lineNum, filepath.Base(path), err) + continue + } + msgs = append(msgs, msg) + } + if scanner.Err() != nil { + return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err()) + } + + if msgs == nil { + msgs = []providers.Message{} + } + return msgs, nil +} + +// countLines counts the total number of non-empty lines in a .jsonl file. +// Used by TruncateHistory to reconcile a stale meta.Count without +// the overhead of unmarshaling every message. 
+func countLines(path string) (int, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + n := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + for scanner.Scan() { + if len(scanner.Bytes()) > 0 { + n++ + } + } + return n, scanner.Err() +} + +func (s *JSONLStore) AddMessage( + _ context.Context, sessionKey, role, content string, +) error { + return s.addMsg(sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +func (s *JSONLStore) AddFullMessage( + _ context.Context, sessionKey string, msg providers.Message, +) error { + return s.addMsg(sessionKey, msg) +} + +// addMsg is the shared implementation for AddMessage and AddFullMessage. +func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + // Append the message as a single JSON line. + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message: %w", err) + } + line = append(line, '\n') + + f, err := os.OpenFile( + s.jsonlPath(sessionKey), + os.O_CREATE|os.O_WRONLY|os.O_APPEND, + 0o644, + ) + if err != nil { + return fmt.Errorf("memory: open jsonl for append: %w", err) + } + _, writeErr := f.Write(line) + if writeErr != nil { + f.Close() + return fmt.Errorf("memory: append message: %w", writeErr) + } + // Flush to physical storage before closing. This matches the + // durability guarantee of writeMeta and rewriteJSONL (which use + // WriteFileAtomic with fsync). Without Sync, a power loss could + // leave the append in the kernel page cache only — lost on reboot. + if syncErr := f.Sync(); syncErr != nil { + f.Close() + return fmt.Errorf("memory: sync jsonl: %w", syncErr) + } + if closeErr := f.Close(); closeErr != nil { + return fmt.Errorf("memory: close jsonl: %w", closeErr) + } + + // Update metadata. 
+ meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.Count == 0 && meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Count++ + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) GetHistory( + _ context.Context, sessionKey string, +) ([]providers.Message, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return nil, err + } + + // Pass meta.Skip so readMessages skips those lines without + // unmarshaling them — avoids wasted CPU on truncated messages. + msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return nil, err + } + + return msgs, nil +} + +func (s *JSONLStore) GetSummary( + _ context.Context, sessionKey string, +) (string, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return "", err + } + return meta.Summary, nil +} + +func (s *JSONLStore) SetSummary( + _ context.Context, sessionKey, summary string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Summary = summary + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) TruncateHistory( + _ context.Context, sessionKey string, keepLast int, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + + // Always reconcile meta.Count with the actual line count on disk. + // A crash between the JSONL append and the meta update in addMsg + // leaves meta.Count stale (e.g. file has 101 lines but meta says + // 100). 
Counting lines is cheap — no unmarshal, just a scan — and + // TruncateHistory is not a hot path, so always re-count. + n, countErr := countLines(s.jsonlPath(sessionKey)) + if countErr != nil { + return countErr + } + meta.Count = n + + if keepLast <= 0 { + meta.Skip = meta.Count + } else { + effective := meta.Count - meta.Skip + if keepLast < effective { + meta.Skip = meta.Count - keepLast + } + } + meta.UpdatedAt = time.Now() + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) SetHistory( + _ context.Context, + sessionKey string, + history []providers.Message, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Skip = 0 + meta.Count = len(history) + meta.UpdatedAt = now + + // Write meta BEFORE rewriting the JSONL file. If we crash between + // the two writes, meta has Skip=0 and the old file is still intact, + // so GetHistory reads from line 1 — returning "too many" messages + // rather than losing data. The next SetHistory call corrects this. + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, history) +} + +// Compact physically rewrites the JSONL file, dropping all logically +// skipped lines. This reclaims disk space that accumulates after +// repeated TruncateHistory calls. +// +// It is safe to call at any time; if there is nothing to compact +// (skip == 0) the method returns immediately. +func (s *JSONLStore) Compact( + _ context.Context, sessionKey string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + if meta.Skip == 0 { + return nil + } + + // Read only the active messages, skipping truncated lines + // without unmarshaling them. 
+ active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return err + } + + // Write meta BEFORE rewriting the JSONL file. If the process + // crashes between the two writes, meta has Skip=0 and the old + // (uncompacted) file is still intact, so GetHistory reads from + // line 1 — returning previously-truncated messages rather than + // losing data. The next Compact or TruncateHistory corrects this. + meta.Skip = 0 + meta.Count = len(active) + meta.UpdatedAt = time.Now() + + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, active) +} + +// rewriteJSONL atomically replaces the JSONL file with the given messages +// using the project's standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) rewriteJSONL( + sessionKey string, msgs []providers.Message, +) error { + var buf bytes.Buffer + for i, msg := range msgs { + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message %d: %w", i, err) + } + buf.Write(line) + buf.WriteByte('\n') + } + return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644) +} + +func (s *JSONLStore) Close() error { + return nil +} diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go new file mode 100644 index 000000000..356ff14ff --- /dev/null +++ b/pkg/memory/jsonl_test.go @@ -0,0 +1,835 @@ +package memory + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func newTestStore(t *testing.T) *JSONLStore { + t.Helper() + store, err := NewJSONLStore(t.TempDir()) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + return store +} + +func TestNewJSONLStore_CreatesDirectory(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested", "sessions") + store, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + + info, err := os.Stat(dir) + if err != nil 
{ + t.Fatalf("Stat: %v", err) + } + if !info.IsDir() { + t.Errorf("expected directory, got file") + } +} + +func TestAddMessage_BasicRoundtrip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "hello") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s1", "assistant", "hi there") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Errorf("msg[0] = %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "hi there" { + t.Errorf("msg[1] = %+v", history[1]) + } +} + +func TestAddMessage_AutoCreatesSession(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Adding a message to a non-existent session should work. 
+ err := store.AddMessage(ctx, "new-session", "user", "first message") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "new-session") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } +} + +func TestAddFullMessage_WithToolCalls(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "assistant", + Content: "Let me search that.", + ToolCalls: []providers.ToolCall{ + { + ID: "call_abc", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"golang jsonl"}`, + }, + }, + }, + } + + err := store.AddFullMessage(ctx, "tc", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + tc := history[0].ToolCalls[0] + if tc.ID != "call_abc" { + t.Errorf("tool call ID = %q", tc.ID) + } + if tc.Function == nil || tc.Function.Name != "web_search" { + t.Errorf("tool call function = %+v", tc.Function) + } +} + +func TestAddFullMessage_ToolCallID(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "tool", + Content: "search results here", + ToolCallID: "call_abc", + } + + err := store.AddFullMessage(ctx, "tr", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tr") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].ToolCallID != "call_abc" { + t.Errorf("ToolCallID = %q", history[0].ToolCallID) + } +} + +func TestGetHistory_EmptySession(t 
*testing.T) { + store := newTestStore(t) + ctx := context.Background() + + history, err := store.GetHistory(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if history == nil { + t.Fatal("expected non-nil empty slice") + } + if len(history) != 0 { + t.Errorf("expected 0 messages, got %d", len(history)) + } +} + +func TestGetHistory_Ordering(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage( + ctx, "order", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage(%d): %v", i, err) + } + } + + history, err := store.GetHistory(ctx, "order") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Fatalf("expected 5, got %d", len(history)) + } + for i := 0; i < 5; i++ { + expected := string(rune('a' + i)) + if history[i].Content != expected { + t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected) + } + } +} + +func TestSetSummary_GetSummary(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // No summary yet. + summary, err := store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "" { + t.Errorf("expected empty, got %q", summary) + } + + // Set a summary. + err = store.SetSummary(ctx, "s1", "talked about Go") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "talked about Go" { + t.Errorf("summary = %q", summary) + } + + // Update summary. 
+ err = store.SetSummary(ctx, "s1", "updated summary") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "updated summary" { + t.Errorf("summary = %q", summary) + } +} + +func TestTruncateHistory_KeepLast(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + err := store.AddMessage( + ctx, "trunc", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "trunc", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "trunc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Should be the last 4: g, h, i, j + if history[0].Content != "g" { + t.Errorf("first kept = %q, want 'g'", history[0].Content) + } + if history[3].Content != "j" { + t.Errorf("last kept = %q, want 'j'", history[3].Content) + } +} + +func TestTruncateHistory_KeepZero(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "empty", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "empty", 0) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "empty") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 0 { + t.Errorf("expected 0, got %d", len(history)) + } +} + +func TestTruncateHistory_KeepMoreThanExists(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + err := store.AddMessage(ctx, "few", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Keep 100, but only 3 exist — should keep all. 
+ err := store.TruncateHistory(ctx, "few", 100) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "few") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Errorf("expected 3, got %d", len(history)) + } +} + +func TestSetHistory_ReplacesAll(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add some initial messages. + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "replace", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Replace with new history. + newHistory := []providers.Message{ + {Role: "user", Content: "new1"}, + {Role: "assistant", Content: "new2"}, + } + err := store.SetHistory(ctx, "replace", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "replace") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2, got %d", len(history)) + } + if history[0].Content != "new1" || history[1].Content != "new2" { + t.Errorf("history = %+v", history) + } +} + +func TestSetHistory_ResetsSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add messages and truncate. + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "skip-reset", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "skip-reset", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // SetHistory should reset skip to 0. 
+ newHistory := []providers.Message{ + {Role: "user", Content: "fresh"}, + } + err = store.SetHistory(ctx, "skip-reset", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "skip-reset") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].Content != "fresh" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestColonInKey(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "telegram:123", "user", "hi") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + + // Verify the file is named with underscore. + jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl") + if _, statErr := os.Stat(jsonlFile); statErr != nil { + t.Errorf("expected file %s to exist: %v", jsonlFile, statErr) + } +} + +func TestCompact_RemovesSkippedMessages(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages, then truncate to keep last 3. + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "compact", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // Before compact: file still has 10 lines. + allOnDisk, err := readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 10 { + t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk)) + } + + // Compact. + err = store.Compact(ctx, "compact") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // After compact: file should have only 3 lines. 
+ allOnDisk, err = readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 3 { + t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk)) + } + + // GetHistory should still return the same 3 messages. + history, err := store.GetHistory(ctx, "compact") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + if history[0].Content != "h" || history[2].Content != "j" { + t.Errorf("wrong content: %+v", history) + } +} + +func TestCompact_NoOpWhenNoSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "noop", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Compact without prior truncation — should be a no-op. + err := store.Compact(ctx, "noop") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + history, err := store.GetHistory(ctx, "noop") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Errorf("expected 5, got %d", len(history)) + } +} + +func TestCompact_ThenAppend(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 8; i++ { + err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "cap", 2) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + err = store.Compact(ctx, "cap") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Append after compaction should work correctly. 
+ err = store.AddMessage(ctx, "cap", "user", "new") + if err != nil { + t.Fatalf("AddMessage after compact: %v", err) + } + + history, err := store.GetHistory(ctx, "cap") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + // g, h (kept from truncation), new (appended after compaction). + if history[0].Content != "g" { + t.Errorf("first = %q, want 'g'", history[0].Content) + } + if history[2].Content != "new" { + t.Errorf("last = %q, want 'new'", history[2].Content) + } +} + +func TestTruncateHistory_StaleMetaCount(t *testing.T) { + // Simulates a crash between JSONL append and meta update in addMsg: + // file has N+1 lines but meta.Count is still N. TruncateHistory must + // reconcile with the real line count so that keepLast is accurate. + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages normally (meta.Count = 10). + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Simulate crash: append a line to JSONL but do NOT update meta. + // This leaves meta.Count = 10 while the file has 11 lines. + jsonlPath := store.jsonlPath("stale") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n") + if err != nil { + t.Fatalf("write orphan: %v", err) + } + f.Close() + + // TruncateHistory(keepLast=4) should keep the last 4 of 11 lines, + // not the last 4 of 10. 
+ err = store.TruncateHistory(ctx, "stale", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "stale") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan] + if history[0].Content != "h" { + t.Errorf("first kept = %q, want 'h'", history[0].Content) + } + if history[3].Content != "orphan" { + t.Errorf("last kept = %q, want 'orphan'", history[3].Content) + } +} + +func TestCrashRecovery_PartialLine(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write a valid message first. + err := store.AddMessage(ctx, "crash", "user", "valid") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + // Simulate a crash by appending a partial JSON line directly. + jsonlPath := store.jsonlPath("crash") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"incomple`) + if err != nil { + t.Fatalf("write partial: %v", err) + } + f.Close() + + // GetHistory should return only the valid message. + history, err := store.GetHistory(ctx, "crash") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 valid message, got %d", len(history)) + } + if history[0].Content != "valid" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestPersistence_AcrossInstances(t *testing.T) { + dir := t.TempDir() + ctx := context.Background() + + // Write with first instance. 
+ store1, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + err = store1.AddMessage(ctx, "persist", "user", "remember me") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store1.SetSummary(ctx, "persist", "a test session") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + store1.Close() + + // Read with second instance. + store2, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store2.Close() + + history, err := store2.GetHistory(ctx, "persist") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 || history[0].Content != "remember me" { + t.Errorf("history = %+v", history) + } + + summary, err := store2.GetSummary(ctx, "persist") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "a test session" { + t.Errorf("summary = %q", summary) + } +} + +func TestConcurrent_AddAndRead(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + var wg sync.WaitGroup + const goroutines = 10 + const msgsPerGoroutine = 20 + + // Concurrent writes. + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < msgsPerGoroutine; i++ { + _ = store.AddMessage(ctx, "concurrent", "user", "msg") + } + }() + } + wg.Wait() + + history, err := store.GetHistory(ctx, "concurrent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + expected := goroutines * msgsPerGoroutine + if len(history) != expected { + t.Errorf("expected %d messages, got %d", expected, len(history)) + } +} + +func TestConcurrent_SummarizeRace(t *testing.T) { + // Simulates the #704 race: one goroutine adds messages while + // another truncates + sets summary — like summarizeSession(). + store := newTestStore(t) + ctx := context.Background() + + // Seed with some messages. 
+ for i := 0; i < 20; i++ { + err := store.AddMessage(ctx, "race", "user", "seed") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + var wg sync.WaitGroup + + // Writer goroutine (main agent loop). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 50; i++ { + _ = store.AddMessage(ctx, "race", "user", "new") + } + }() + + // Summarizer goroutine (background task). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + _ = store.SetSummary(ctx, "race", "summary") + _ = store.TruncateHistory(ctx, "race", 5) + } + }() + + wg.Wait() + + // Verify the store is still in a consistent state. + _, err := store.GetHistory(ctx, "race") + if err != nil { + t.Fatalf("GetHistory after race: %v", err) + } + _, err = store.GetSummary(ctx, "race") + if err != nil { + t.Fatalf("GetSummary after race: %v", err) + } +} + +func TestMultipleSessions_Isolation(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "msg for s1") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s2", "user", "msg for s2") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + h1, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory s1: %v", err) + } + h2, err := store.GetHistory(ctx, "s2") + if err != nil { + t.Fatalf("GetHistory s2: %v", err) + } + + if len(h1) != 1 || h1[0].Content != "msg for s1" { + t.Errorf("s1 history = %+v", h1) + } + if len(h2) != 1 || h2[0].Content != "msg for s2" { + t.Errorf("s2 history = %+v", h2) + } +} + +func BenchmarkAddMessage(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = store.AddMessage(ctx, "bench", "user", "benchmark message content") + } +} + +func BenchmarkGetHistory_100(b *testing.B) { + dir := b.TempDir() + 
store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 100; i++ { + _ = store.AddMessage(ctx, "bench", "user", "message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} + +func BenchmarkGetHistory_1000(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 1000; i++ { + _ = store.AddMessage(ctx, "bench", "user", "message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go new file mode 100644 index 000000000..c9d5176ab --- /dev/null +++ b/pkg/memory/migration.go @@ -0,0 +1,108 @@ +package memory + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// jsonSession mirrors pkg/session.Session for migration purposes. +type jsonSession struct { + Key string `json:"key"` + Messages []providers.Message `json:"messages"` + Summary string `json:"summary,omitempty"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` +} + +// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir, +// writes them into the Store, and renames each migrated file to +// .json.migrated as a backup. Returns the number of sessions migrated. +// +// Files that fail to parse are logged and skipped. Already-migrated +// files (.json.migrated) are ignored, making the function idempotent. 
+func MigrateFromJSON( + ctx context.Context, sessionsDir string, store Store, +) (int, error) { + entries, err := os.ReadDir(sessionsDir) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: read sessions dir: %w", err) + } + + migrated := 0 + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".json") { + continue + } + // Skip already-migrated files. + if strings.HasSuffix(name, ".migrated") { + continue + } + + srcPath := filepath.Join(sessionsDir, name) + + data, readErr := os.ReadFile(srcPath) + if readErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, readErr) + continue + } + + var sess jsonSession + if parseErr := json.Unmarshal(data, &sess); parseErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, parseErr) + continue + } + + // Use the key from the JSON content, not the filename. + // Filenames are sanitized (":" → "_") but keys are not. + key := sess.Key + if key == "" { + key = strings.TrimSuffix(name, ".json") + } + + // Use SetHistory (atomic replace) instead of per-message + // AddFullMessage. This makes migration idempotent: if the + // process crashes after writing messages but before the + // rename below, a retry replaces the partial data cleanly + // instead of duplicating messages. + if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set history: %w", + name, setErr, + ) + } + + if sess.Summary != "" { + if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set summary: %w", + name, sumErr, + ) + } + } + + // Rename to .migrated as backup (not delete). 
+ renameErr := os.Rename(srcPath, srcPath+".migrated") + if renameErr != nil { + log.Printf("memory: migrate: rename %s: %v", name, renameErr) + } + + migrated++ + } + + return migrated, nil +} diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go new file mode 100644 index 000000000..3170758b7 --- /dev/null +++ b/pkg/memory/migration_test.go @@ -0,0 +1,384 @@ +package memory + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func writeJSONSession( + t *testing.T, dir string, filename string, sess jsonSession, +) { + t.Helper() + data, err := json.MarshalIndent(sess, "", " ") + if err != nil { + t.Fatalf("marshal session: %v", err) + } + err = os.WriteFile(filepath.Join(dir, filename), data, 0o644) + if err != nil { + t.Fatalf("write session file: %v", err) + } +} + +func TestMigrateFromJSON_Basic(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "test.json", jsonSession{ + Key: "test", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + }, + Summary: "A greeting.", + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 migrated, got %d", count) + } + + history, err := store.GetHistory(ctx, "test") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Content != "hello" || history[1].Content != "hi" { + t.Errorf("unexpected messages: %+v", history) + } + + summary, err := store.GetSummary(ctx, "test") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "A greeting." 
{ + t.Errorf("summary = %q", summary) + } +} + +func TestMigrateFromJSON_WithToolCalls(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "tools.json", jsonSession{ + Key: "tools", + Messages: []providers.Message{ + { + Role: "assistant", + Content: "Searching...", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"test"}`, + }, + }, + }, + }, + { + Role: "tool", + Content: "result", + ToolCallID: "call_1", + }, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "tools") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + if history[0].ToolCalls[0].Function.Name != "web_search" { + t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name) + } + if history[1].ToolCallID != "call_1" { + t.Errorf("ToolCallID = %q", history[1].ToolCallID) + } +} + +func TestMigrateFromJSON_MultipleFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + writeJSONSession(t, sessionsDir, key+".json", jsonSession{ + Key: key, + Messages: []providers.Message{{Role: "user", Content: "msg " + key}}, + Created: time.Now(), + Updated: time.Now(), + }) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 3 { + t.Errorf("expected 3, got %d", count) + } + + for i := 0; i < 3; i++ { + 
key := string(rune('a' + i)) + history, histErr := store.GetHistory(ctx, key) + if histErr != nil { + t.Fatalf("GetHistory(%q): %v", key, histErr) + } + if len(history) != 1 { + t.Errorf("session %q: expected 1 msg, got %d", key, len(history)) + } + } +} + +func TestMigrateFromJSON_InvalidJSON(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // One valid, one invalid. + writeJSONSession(t, sessionsDir, "good.json", jsonSession{ + Key: "good", + Messages: []providers.Message{{Role: "user", Content: "ok"}}, + Created: time.Now(), + Updated: time.Now(), + }) + err := os.WriteFile( + filepath.Join(sessionsDir, "bad.json"), + []byte("{invalid json"), + 0o644, + ) + if err != nil { + t.Fatalf("write bad file: %v", err) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 (bad file skipped), got %d", count) + } + + history, err := store.GetHistory(ctx, "good") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_RenamesFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "rename.json", jsonSession{ + Key: "rename", + Messages: []providers.Message{{Role: "user", Content: "hi"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + _, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + + // Original .json should not exist. + _, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json")) + if !os.IsNotExist(statErr) { + t.Error("rename.json should have been renamed") + } + // .json.migrated should exist. 
+ _, statErr = os.Stat( + filepath.Join(sessionsDir, "rename.json.migrated"), + ) + if statErr != nil { + t.Errorf("rename.json.migrated should exist: %v", statErr) + } +} + +func TestMigrateFromJSON_Idempotent(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "idem.json", jsonSession{ + Key: "idem", + Messages: []providers.Message{{Role: "user", Content: "once"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count1, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count1 != 1 { + t.Errorf("first run: expected 1, got %d", count1) + } + + // Second run should find only .migrated files, skip them. + count2, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count2 != 0 { + t.Errorf("second run: expected 0, got %d", count2) + } + + history, err := store.GetHistory(ctx, "idem") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_ColonInKey(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // File is named telegram_123 (sanitized), but the key inside is telegram:123. + writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{ + Key: "telegram:123", + Messages: []providers.Message{{Role: "user", Content: "from telegram"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + // Accessible via the original key "telegram:123". 
+ history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + if history[0].Content != "from telegram" { + t.Errorf("content = %q", history[0].Content) + } + + // In the file-based store, "telegram:123" and "telegram_123" both + // sanitize to the same filename, so they share storage. This is + // expected — the colon-to-underscore mapping is a one-way function. + history2, err := store.GetHistory(ctx, "telegram_123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history2) != 1 { + t.Errorf("expected 1 (same file), got %d", len(history2)) + } +} + +func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) { + // Simulates a crash during migration: first run writes messages + // but doesn't rename the .json file. Second run must replace + // (not duplicate) the messages thanks to SetHistory semantics. + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "retry.json", jsonSession{ + Key: "retry", + Messages: []providers.Message{ + {Role: "user", Content: "one"}, + {Role: "assistant", Content: "two"}, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + // First migration succeeds — writes messages and renames file. + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + // Simulate "crash before rename": restore the .json file. + src := filepath.Join(sessionsDir, "retry.json.migrated") + dst := filepath.Join(sessionsDir, "retry.json") + if renameErr := os.Rename(src, dst); renameErr != nil { + t.Fatalf("restore .json: %v", renameErr) + } + + // Second migration should re-import without duplicating messages. 
+ count, err = MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "retry") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + // Must be exactly 2 messages (not 4 from duplication). + if len(history) != 2 { + t.Fatalf("expected 2 messages (no duplicates), got %d", len(history)) + } + if history[0].Content != "one" || history[1].Content != "two" { + t.Errorf("unexpected messages: %+v", history) + } +} + +func TestMigrateFromJSON_NonexistentDir(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + count, err := MigrateFromJSON(ctx, "/nonexistent/path", store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 0 { + t.Errorf("expected 0, got %d", count) + } +} diff --git a/pkg/memory/store.go b/pkg/memory/store.go new file mode 100644 index 000000000..b6e11707d --- /dev/null +++ b/pkg/memory/store.go @@ -0,0 +1,42 @@ +package memory + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Store defines an interface for persistent session storage. +// Each method is an atomic operation — there is no separate Save() call. +type Store interface { + // AddMessage appends a simple text message to a session. + AddMessage(ctx context.Context, sessionKey, role, content string) error + + // AddFullMessage appends a complete message (with tool calls, etc.) to a session. + AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error + + // GetHistory returns all messages for a session in insertion order. + // Returns an empty slice (not nil) if the session does not exist. + GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error) + + // GetSummary returns the conversation summary for a session. + // Returns an empty string if no summary exists. 
+ GetSummary(ctx context.Context, sessionKey string) (string, error) + + // SetSummary updates the conversation summary for a session. + SetSummary(ctx context.Context, sessionKey, summary string) error + + // TruncateHistory removes all but the last keepLast messages from a session. + // If keepLast <= 0, all messages are removed. + TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error + + // SetHistory replaces all messages in a session with the provided history. + SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error + + // Compact reclaims storage by physically removing logically truncated + // data. Backends that do not accumulate dead data may return nil. + Compact(ctx context.Context, sessionKey string) error + + // Close releases any resources held by the store. + Close() error +} diff --git a/pkg/migrate/internal/common_test.go b/pkg/migrate/internal/common_test.go index a089157f5..a67293c19 100644 --- a/pkg/migrate/internal/common_test.go +++ b/pkg/migrate/internal/common_test.go @@ -118,64 +118,55 @@ func TestPlanWorkspaceMigration(t *testing.T) { assert.GreaterOrEqual(t, len(actions), 1) } -func TestPlanWorkspaceMigrationWithExistingDestination(t *testing.T) { - tmpDir := t.TempDir() - srcWorkspace := filepath.Join(tmpDir, "src", "workspace") - dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") +func TestPlanWorkspaceMigrationExistingFile(t *testing.T) { + tests := []struct { + name string + force bool + wantActionType ActionType + }{ + { + name: "backup when not forced", + force: false, + wantActionType: ActionBackup, + }, + { + name: "copy when forced", + force: true, + wantActionType: ActionCopy, + }, + } - err := os.MkdirAll(srcWorkspace, 0o755) - require.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + srcWorkspace := filepath.Join(tmpDir, "src", "workspace") + dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") - err = 
os.MkdirAll(dstWorkspace, 0o755) - require.NoError(t, err) + err := os.MkdirAll(srcWorkspace, 0o755) + require.NoError(t, err) - err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) - require.NoError(t, err) + err = os.MkdirAll(dstWorkspace, 0o755) + require.NoError(t, err) - err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) - require.NoError(t, err) + err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) + require.NoError(t, err) - actions, err := PlanWorkspaceMigration( - srcWorkspace, - dstWorkspace, - []string{"file1.txt"}, - []string{}, - false, - ) - require.NoError(t, err) + err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) + require.NoError(t, err) - require.GreaterOrEqual(t, len(actions), 1) - assert.Equal(t, ActionBackup, actions[0].Type) -} + actions, err := PlanWorkspaceMigration( + srcWorkspace, + dstWorkspace, + []string{"file1.txt"}, + []string{}, + tt.force, + ) + require.NoError(t, err) -func TestPlanWorkspaceMigrationForce(t *testing.T) { - tmpDir := t.TempDir() - srcWorkspace := filepath.Join(tmpDir, "src", "workspace") - dstWorkspace := filepath.Join(tmpDir, "dst", "workspace") - - err := os.MkdirAll(srcWorkspace, 0o755) - require.NoError(t, err) - - err = os.MkdirAll(dstWorkspace, 0o755) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(srcWorkspace, "file1.txt"), []byte("source"), 0o644) - require.NoError(t, err) - - err = os.WriteFile(filepath.Join(dstWorkspace, "file1.txt"), []byte("existing"), 0o644) - require.NoError(t, err) - - actions, err := PlanWorkspaceMigration( - srcWorkspace, - dstWorkspace, - []string{"file1.txt"}, - []string{}, - true, - ) - require.NoError(t, err) - - require.GreaterOrEqual(t, len(actions), 1) - assert.Equal(t, ActionCopy, actions[0].Type) + require.GreaterOrEqual(t, len(actions), 1) + assert.Equal(t, tt.wantActionType, actions[0].Type) + }) + } } func 
TestPlanWorkspaceMigrationNonExistentSource(t *testing.T) { diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go index 74ec33b98..6c4f6a767 100644 --- a/pkg/providers/claude_cli_provider.go +++ b/pkg/providers/claude_cli_provider.go @@ -100,44 +100,12 @@ func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDe } if len(tools) > 0 { - parts = append(parts, p.buildToolsPrompt(tools)) + parts = append(parts, buildCLIToolsPrompt(tools)) } return strings.Join(parts, "\n\n") } -// buildToolsPrompt creates the tool definitions section for the system prompt. -func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string { - var sb strings.Builder - - sb.WriteString("## Available Tools\n\n") - sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") - sb.WriteString("```json\n") - sb.WriteString( - `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, - ) - sb.WriteString("\n```\n\n") - sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") - sb.WriteString("### Tool Definitions:\n\n") - - for _, tool := range tools { - if tool.Type != "function" { - continue - } - sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) - if tool.Function.Description != "" { - sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) - } - if len(tool.Function.Parameters) > 0 { - paramsJSON, _ := json.Marshal(tool.Function.Parameters) - sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) - } - sb.WriteString("\n") - } - - return sb.String() -} - // parseClaudeCliResponse parses the JSON output from the claude CLI. 
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) { var resp claudeCliJSONResponse diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index 3a3cafaca..d4d648f5a 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -660,12 +660,11 @@ func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) { // --- buildToolsPrompt tests --- func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}}, {Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if strings.Contains(got, "skip_me") { t.Error("buildToolsPrompt() should skip non-function tools") } @@ -675,11 +674,10 @@ func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { } func TestBuildToolsPrompt_NoDescription(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if !strings.Contains(got, "bare_tool") { t.Error("should include tool name") } @@ -689,14 +687,13 @@ func TestBuildToolsPrompt_NoDescription(t *testing.T) { } func TestBuildToolsPrompt_NoParameters(t *testing.T) { - p := NewClaudeCliProvider("/workspace") tools := []ToolDefinition{ {Type: "function", Function: ToolFunctionDefinition{ Name: "no_params_tool", Description: "A tool with no parameters", }}, } - got := p.buildToolsPrompt(tools) + got := buildCLIToolsPrompt(tools) if strings.Contains(got, "Parameters:") { t.Error("should not include Parameters: section when nil") } diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go index 4c783ece5..13f53ad9e 100644 --- 
a/pkg/providers/codex_cli_provider.go +++ b/pkg/providers/codex_cli_provider.go @@ -115,7 +115,7 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio } if len(tools) > 0 { - sb.WriteString(p.buildToolsPrompt(tools)) + sb.WriteString(buildCLIToolsPrompt(tools)) sb.WriteString("\n\n") } @@ -128,38 +128,6 @@ func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinitio return sb.String() } -// buildToolsPrompt creates a tool definitions section for the prompt. -func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string { - var sb strings.Builder - - sb.WriteString("## Available Tools\n\n") - sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") - sb.WriteString("```json\n") - sb.WriteString( - `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, - ) - sb.WriteString("\n```\n\n") - sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") - sb.WriteString("### Tool Definitions:\n\n") - - for _, tool := range tools { - if tool.Type != "function" { - continue - } - sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) - if tool.Function.Description != "" { - sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) - } - if len(tool.Function.Parameters) > 0 { - paramsJSON, _ := json.Marshal(tool.Function.Parameters) - sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) - } - sb.WriteString("\n") - } - - return sb.String() -} - // codexEvent represents a single JSONL event from `codex exec --json`. 
type codexEvent struct { Type string `json:"type"` diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 20348dc27..5199e77b3 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = "https://openrouter.ai/api/v1" } } + case "litellm": + if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" { + sel.apiKey = cfg.Providers.LiteLLM.APIKey + sel.apiBase = cfg.Providers.LiteLLM.APIBase + sel.proxy = cfg.Providers.LiteLLM.Proxy + if sel.apiBase == "" { + sel.apiBase = "http://localhost:4000/v1" + } + } case "zhipu", "glm": if cfg.Providers.Zhipu.APIKey != "" { sel.apiKey = cfg.Providers.Zhipu.APIKey diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index a119ca158..caa0666e0 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. -// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot +// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot // Returns the provider, the model ID (without protocol prefix), and any error. 
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { if cfg == nil { @@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil - case "openrouter", "groq", "zhipu", "gemini", "nvidia", + case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "mistral": // All other OpenAI-compatible HTTP providers @@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string { return "https://api.openai.com/v1" case "openrouter": return "https://openrouter.ai/api/v1" + case "litellm": + return "http://localhost:4000/v1" case "groq": return "https://api.groq.com/openai/v1" case "zhipu": diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 31cae3442..17bc55d25 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -136,6 +136,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { } } +func TestGetDefaultAPIBase_LiteLLM(t *testing.T) { + if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" { + t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1") + } +} + +func TestCreateProviderFromConfig_LiteLLM(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-litellm", + Model: "litellm/my-proxy-alias", + APIKey: "test-key", + APIBase: "http://localhost:4000/v1", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-proxy-alias" { + t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias") + } +} + func TestCreateProviderFromConfig_Anthropic(t *testing.T) { cfg := &config.ModelConfig{ ModelName: 
"test-anthropic", diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index 222aff6f2..36ccda4a1 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) { wantProxy string wantErrSubstr string }{ + { + name: "explicit litellm provider uses configured base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "litellm" + cfg.Providers.LiteLLM.APIKey = "litellm-key" + cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1" + cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:4000/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit litellm provider defaults base when only key is configured", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "litellm" + cfg.Providers.LiteLLM.APIKey = "litellm-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:4000/v1", + }, { name: "explicit claude-cli provider routes to cli provider type", setup: func(cfg *config.Config) { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index b04a6ba2b..8cd436795 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -116,7 +116,7 @@ func (p *Provider) Chat( requestBody := map[string]any{ "model": model, - "messages": stripSystemParts(messages), + "messages": serializeMessages(messages), } if len(tools) > 0 { @@ -289,24 +289,62 @@ func parseResponse(body []byte) (*LLMResponse, error) { // It mirrors protocoltypes.Message but omits SystemParts, which is an // internal field that would be unknown to third-party endpoints. 
type openaiMessage struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } -// stripSystemParts converts []Message to []openaiMessage, dropping the -// SystemParts field so it doesn't leak into the JSON payload sent to -// OpenAI-compatible APIs (some strict endpoints reject unknown fields). -func stripSystemParts(messages []Message) []openaiMessage { - out := make([]openaiMessage, len(messages)) - for i, m := range messages { - out[i] = openaiMessage{ - Role: m.Role, - Content: m.Content, - ToolCalls: m.ToolCalls, - ToolCallID: m.ToolCallID, +// serializeMessages converts internal Message structs to the OpenAI wire format. +// - Strips SystemParts (unknown to third-party endpoints) +// - Converts messages with Media to multipart content format (text + image_url parts) +// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages +func serializeMessages(messages []Message) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + if len(m.Media) == 0 { + out = append(out, openaiMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + }) + continue } + + // Multipart content format for messages with media + parts := make([]map[string]any, 0, 1+len(m.Media)) + if m.Content != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": m.Content, + }) + } + for _, mediaURL := range m.Media { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } + + msg := map[string]any{ + "role": m.Role, + "content": parts, + } + if 
m.ToolCallID != "" { + msg["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + msg["tool_calls"] = m.ToolCalls + } + if m.ReasoningContent != "" { + msg["reasoning_content"] = m.ReasoningContent + } + out = append(out, msg) } return out } @@ -323,7 +361,7 @@ func normalizeModel(model, apiBase string) string { prefix := strings.ToLower(before) switch prefix { - case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral", "vivgrid": + case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral", "vivgrid": return after default: return model diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index dc3f93d9f..9b6b56d36 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -5,8 +5,11 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { @@ -146,6 +149,56 @@ func TestProviderChat_ParsesReasoningContent(t *testing.T) { } } +func TestProviderChat_PreservesReasoningContentInHistory(t *testing.T) { + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + + // Simulate a multi-turn conversation where the assistant's previous + // reply included
reasoning_content (e.g. from kimi-k2.5). + messages := []Message{ + {Role: "user", Content: "What is 1+1?"}, + {Role: "assistant", Content: "2", ReasoningContent: "Let me think... 1+1=2"}, + {Role: "user", Content: "What about 2+2?"}, + } + + _, err := p.Chat(t.Context(), messages, nil, "kimi-k2.5", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + // Verify reasoning_content is preserved in the serialized request. + reqMessages, ok := requestBody["messages"].([]any) + if !ok { + t.Fatalf("messages is not []any: %T", requestBody["messages"]) + } + assistantMsg, ok := reqMessages[1].(map[string]any) + if !ok { + t.Fatalf("assistant message is not map[string]any: %T", reqMessages[1]) + } + if assistantMsg["reasoning_content"] != "Let me think... 1+1=2" { + t.Errorf("reasoning_content not preserved in request, got %v", assistantMsg["reasoning_content"]) + } +} + func TestProviderChat_HTTPError(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad request", http.StatusBadRequest) @@ -206,6 +259,11 @@ func TestProviderChat_StripsGroqOllamaDeepseekVivgridPrefixes(t *testing.T) { input string wantModel string }{ + { + name: "strips litellm prefix and preserves proxy model name", + input: "litellm/my-proxy-alias", + wantModel: "my-proxy-alias", + }, { name: "strips groq prefix and keeps nested model", input: "groq/openai/gpt-oss-120b", @@ -372,3 +430,97 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) { t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) } } + +func TestSerializeMessages_PlainText(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, + } + result := serializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatal(err) + } + + var msgs []map[string]any + 
json.Unmarshal(data, &msgs) + + if msgs[0]["content"] != "hello" { + t.Fatalf("expected plain string content, got %v", msgs[0]["content"]) + } + if msgs[1]["reasoning_content"] != "thinking..." { + t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + } +} + +func TestSerializeMessages_WithMedia(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } + + textPart := content[0].(map[string]any) + if textPart["type"] != "text" || textPart["text"] != "describe this" { + t.Fatalf("text part mismatch: %v", textPart) + } + + imgPart := content[1].(map[string]any) + if imgPart["type"] != "image_url" { + t.Fatalf("expected image_url type, got %v", imgPart["type"]) + } + imgURL := imgPart["image_url"].(map[string]any) + if imgURL["url"] != "data:image/png;base64,abc123" { + t.Fatalf("image url mismatch: %v", imgURL["url"]) + } +} + +func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["tool_call_id"] != "call_1" { + t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"]) + } + // Content should be multipart array + if _, ok := msgs[0]["content"].([]any); !ok { + t.Fatalf("expected array content, got %T", msgs[0]["content"]) + } +} + +func 
TestSerializeMessages_StripsSystemParts(t *testing.T) { + messages := []protocoltypes.Message{ + { + Role: "system", + Content: "you are helpful", + SystemParts: []protocoltypes.ContentBlock{ + {Type: "text", Text: "you are helpful"}, + }, + }, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + raw := string(data) + if strings.Contains(raw, "system_parts") { + t.Fatal("system_parts should not appear in serialized output") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 99f13334e..194c1aa6f 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -65,6 +65,7 @@ type ContentBlock struct { type Message struct { Role string `json:"role"` Content string `json:"content"` + Media []string `json:"media,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"` SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters ToolCalls []ToolCall `json:"tool_calls,omitempty"` diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go index 49218b1b1..a33e1eb5c 100644 --- a/pkg/providers/toolcall_utils.go +++ b/pkg/providers/toolcall_utils.go @@ -5,7 +5,43 @@ package providers -import "encoding/json" +import ( + "encoding/json" + "fmt" + "strings" +) + +// buildCLIToolsPrompt creates the tool definitions section for a CLI provider system prompt. 
+func buildCLIToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString( + `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`, + ) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} // NormalizeToolCall normalizes a ToolCall to ensure all fields are properly populated. // It handles cases where Name/Arguments might be in different locations (top-level vs Function) diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index fcbcf934b..30d84635a 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -64,6 +64,29 @@ type SkillsLoader struct { builtinSkills string // builtin skills } +// SkillRoots returns all unique skill root directories used by this loader. +// The order follows resolution priority: workspace > global > builtin. 
+func (sl *SkillsLoader) SkillRoots() []string { + roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills} + seen := make(map[string]struct{}, len(roots)) + out := make([]string, 0, len(roots)) + + for _, root := range roots { + trimmed := strings.TrimSpace(root) + if trimmed == "" { + continue + } + clean := filepath.Clean(trimmed) + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} + out = append(out, clean) + } + + return out +} + func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader { return &SkillsLoader{ workspace: workspace, diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go index 9428bea62..31619f9c2 100644 --- a/pkg/skills/loader_test.go +++ b/pkg/skills/loader_test.go @@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) { }) } } + +func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) { + tmp := t.TempDir() + workspace := filepath.Join(tmp, "workspace") + global := filepath.Join(tmp, "global") + builtin := filepath.Join(tmp, "builtin") + + sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n") + roots := sl.SkillRoots() + + assert.Equal(t, []string{ + filepath.Join(workspace, "skills"), + global, + builtin, + }, roots) +} diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go new file mode 100644 index 000000000..6e53cf354 --- /dev/null +++ b/pkg/tools/mcp_tool.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPManager defines the interface for MCP manager operations +// This allows for easier testing with mock implementations +type MCPManager interface { + CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, + ) (*mcp.CallToolResult, error) +} + +// MCPTool wraps an MCP tool to implement the Tool interface +type MCPTool struct { + manager MCPManager + serverName string + 
tool *mcp.Tool +} + +// NewMCPTool creates a new MCP tool wrapper +func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool { + return &MCPTool{ + manager: manager, + serverName: serverName, + tool: tool, + } +} + +// sanitizeIdentifierComponent normalizes a string so it can be safely used +// as part of a tool/function identifier for downstream providers. +// It: +// - lowercases the string +// - replaces any character not in [a-z0-9_-] with '_' +// - collapses multiple consecutive '_' into a single '_' +// - trims leading/trailing '_' +// - falls back to "unnamed" if the result is empty +// - truncates overly long components to a reasonable length +func sanitizeIdentifierComponent(s string) string { + const maxLen = 64 + + s = strings.ToLower(s) + var b strings.Builder + b.Grow(len(s)) + + prevUnderscore := false + for _, r := range s { + isAllowed := (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '_' || r == '-' + + if !isAllowed { + // Normalize any disallowed character to '_' + if !prevUnderscore { + b.WriteRune('_') + prevUnderscore = true + } + continue + } + + if r == '_' { + if prevUnderscore { + continue + } + prevUnderscore = true + } else { + prevUnderscore = false + } + + b.WriteRune(r) + } + + result := strings.Trim(b.String(), "_") + if result == "" { + result = "unnamed" + } + + if len(result) > maxLen { + result = result[:maxLen] + } + + return result +} + +// Name returns the tool name, prefixed with the server name. +// The total length is capped at 64 characters (OpenAI-compatible API limit). +// A short hash of the original (unsanitized) server and tool names is appended +// whenever sanitization is lossy or the name is truncated, ensuring that two +// names which differ only in disallowed characters remain distinct after sanitization. 
+func (t *MCPTool) Name() string { + // Prefix with server name to avoid conflicts, and sanitize components + sanitizedServer := sanitizeIdentifierComponent(t.serverName) + sanitizedTool := sanitizeIdentifierComponent(t.tool.Name) + full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool) + + // Check if sanitization was lossless (only lowercasing, no char replacement/truncation) + lossless := strings.ToLower(t.serverName) == sanitizedServer && + strings.ToLower(t.tool.Name) == sanitizedTool + + const maxTotal = 64 + if lossless && len(full) <= maxTotal { + return full + } + + // Sanitization was lossy or name too long: append hash of the ORIGINAL names + // (not the sanitized names) so different originals always yield different hashes. + h := fnv.New32a() + _, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name)) + suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars + + base := full + if len(base) > maxTotal-9 { + base = strings.TrimRight(full[:maxTotal-9], "_") + } + return base + "_" + suffix +} + +// Description returns the tool description +func (t *MCPTool) Description() string { + desc := t.tool.Description + if desc == "" { + desc = fmt.Sprintf("MCP tool from %s server", t.serverName) + } + // Add server info to description + return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc) +} + +// Parameters returns the tool parameters schema +func (t *MCPTool) Parameters() map[string]any { + // The InputSchema is already a JSON Schema object + schema := t.tool.InputSchema + + // Handle nil schema + if schema == nil { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + // Try direct conversion first (fast path) + if schemaMap, ok := schema.(map[string]any); ok { + return schemaMap + } + + // Handle json.RawMessage and []byte - unmarshal directly + var jsonData []byte + if rawMsg, ok := schema.(json.RawMessage); ok { + jsonData = rawMsg + } else if bytes, ok := schema.([]byte); ok { + jsonData 
= bytes + } + + if jsonData != nil { + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err == nil { + return result + } + // Fallback on error + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + // For other types (structs, etc.), convert via JSON marshal/unmarshal + var err error + jsonData, err = json.Marshal(schema) + if err != nil { + // Fallback to empty schema if marshaling fails + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err != nil { + // Fallback to empty schema if unmarshaling fails + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + return result +} + +// Execute executes the MCP tool +func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args) + if err != nil { + return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err) + } + + if result == nil { + nilErr := fmt.Errorf("MCP tool returned nil result without error") + return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr) + } + + // Handle error result from server + if result.IsError { + errMsg := extractContentText(result.Content) + return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)). 
+ WithError(fmt.Errorf("MCP tool error: %s", errMsg)) + } + + // Extract text content from result + output := extractContentText(result.Content) + + return &ToolResult{ + ForLLM: output, + IsError: false, + } +} + +// extractContentText extracts text from MCP content array +func extractContentText(content []mcp.Content) string { + var parts []string + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + parts = append(parts, v.Text) + case *mcp.ImageContent: + // For images, just indicate that an image was returned + parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + default: + // For other content types, use string representation + parts = append(parts, fmt.Sprintf("[Content: %T]", v)) + } + } + return strings.Join(parts, "\n") +} diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go new file mode 100644 index 000000000..95bb0f992 --- /dev/null +++ b/pkg/tools/mcp_tool_test.go @@ -0,0 +1,492 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MockMCPManager is a mock implementation of MCPManager interface for testing +type MockMCPManager struct { + callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) +} + +func (m *MockMCPManager) CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, +) (*mcp.CallToolResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, serverName, toolName, arguments) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "mock result"}, + }, + IsError: false, + }, nil +} + +// TestNewMCPTool verifies MCP tool creation +func TestNewMCPTool(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": 
map[string]any{ + "type": "string", + "description": "Test input", + }, + }, + }, + } + + mcpTool := NewMCPTool(manager, "test_server", tool) + + if mcpTool == nil { + t.Fatal("NewMCPTool should not return nil") + } + // Verify tool properties we can access + if mcpTool.Name() != "mcp_test_server_test_tool" { + t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name()) + } +} + +// TestMCPTool_Name verifies tool name with server prefix +func TestMCPTool_Name(t *testing.T) { + tests := []struct { + name string + serverName string + toolName string + expected string + }{ + { + name: "simple name", + serverName: "github", + toolName: "create_issue", + expected: "mcp_github_create_issue", + }, + { + name: "filesystem server", + serverName: "filesystem", + toolName: "read_file", + expected: "mcp_filesystem_read_file", + }, + { + name: "remote server", + serverName: "remote-api", + toolName: "fetch_data", + expected: "mcp_remote-api_fetch_data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: tt.toolName} + mcpTool := NewMCPTool(manager, tt.serverName, tool) + + result := mcpTool.Name() + if result != tt.expected { + t.Errorf("Expected name '%s', got '%s'", tt.expected, result) + } + }) + } +} + +// TestMCPTool_Description verifies tool description generation +func TestMCPTool_Description(t *testing.T) { + tests := []struct { + name string + serverName string + toolDescription string + expectContains []string + }{ + { + name: "with description", + serverName: "github", + toolDescription: "Create a GitHub issue", + expectContains: []string{"[MCP:github]", "Create a GitHub issue"}, + }, + { + name: "empty description", + serverName: "filesystem", + toolDescription: "", + expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ 
+ Name: "test_tool", + Description: tt.toolDescription, + } + mcpTool := NewMCPTool(manager, tt.serverName, tool) + + result := mcpTool.Description() + + for _, expected := range tt.expectContains { + if !strings.Contains(result, expected) { + t.Errorf("Description should contain '%s', got: %s", expected, result) + } + } + }) + } +} + +// TestMCPTool_Parameters verifies parameter schema conversion +func TestMCPTool_Parameters(t *testing.T) { + tests := []struct { + name string + inputSchema any + expectType string + checkProperty string + expectProperty bool + }{ + { + name: "map schema", + inputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query", + }, + }, + "required": []string{"query"}, + }, + expectType: "object", + checkProperty: "query", + expectProperty: true, + }, + { + name: "nil schema", + inputSchema: nil, + expectType: "object", + expectProperty: false, + }, + { + name: "json.RawMessage schema", + inputSchema: []byte(`{ + "type": "object", + "properties": { + "repo": { + "type": "string", + "description": "Repository name" + }, + "stars": { + "type": "integer", + "description": "Minimum stars" + } + }, + "required": ["repo"] + }`), + expectType: "object", + checkProperty: "repo", + expectProperty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: tt.inputSchema, + } + mcpTool := NewMCPTool(manager, "test_server", tool) + + params := mcpTool.Parameters() + + if params == nil { + t.Fatal("Parameters should not be nil") + } + + if params["type"] != tt.expectType { + t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"]) + } + + // Check if property exists when expected + if tt.checkProperty != "" { + properties, ok := params["properties"].(map[string]any) + if !ok && tt.expectProperty { + t.Errorf("Expected properties 
to be a map") + return + } + if ok { + _, hasProperty := properties[tt.checkProperty] + if hasProperty != tt.expectProperty { + t.Errorf("Expected property '%s' existence: %v, got: %v", + tt.checkProperty, tt.expectProperty, hasProperty) + } + } + } + }) + } +} + +// TestMCPTool_Execute_Success tests successful tool execution +func TestMCPTool_Execute_Success(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + // Verify correct parameters passed + if serverName != "github" { + t.Errorf("Expected serverName 'github', got '%s'", serverName) + } + if toolName != "search_repos" { + t.Errorf("Expected toolName 'search_repos', got '%s'", toolName) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Found 3 repositories"}, + }, + IsError: false, + }, nil + }, + } + + tool := &mcp.Tool{ + Name: "search_repos", + Description: "Search GitHub repositories", + } + mcpTool := NewMCPTool(manager, "github", tool) + + ctx := context.Background() + args := map[string]any{ + "query": "golang mcp", + } + + result := mcpTool.Execute(ctx, args) + + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Errorf("Expected no error, got error: %s", result.ForLLM) + } + if result.ForLLM != "Found 3 repositories" { + t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM) + } +} + +// TestMCPTool_Execute_ManagerError tests execution when manager returns error +func TestMCPTool_Execute_ManagerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return nil, fmt.Errorf("connection failed") + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, 
map[string]any{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool execution failed") { + t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "connection failed") { + t.Errorf("Error message should include original error, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_ServerError tests execution when server returns error +func TestMCPTool_Execute_ServerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Invalid API key"}, + }, + IsError: true, + }, nil + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]any{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool returned error") { + t.Errorf("Error message should mention server error, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Invalid API key") { + t.Errorf("Error message should include server message, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_MultipleContent tests execution with multiple content items +func TestMCPTool_Execute_MultipleContent(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "First line"}, + &mcp.TextContent{Text: "Second line"}, + &mcp.TextContent{Text: "Third line"}, + }, + IsError: 
false, + }, nil + }, + } + + tool := &mcp.Tool{Name: "multi_output"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]any{}) + + if result.IsError { + t.Errorf("Expected no error, got: %s", result.ForLLM) + } + + expected := "First line\nSecond line\nThird line" + if result.ForLLM != expected { + t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM) + } +} + +// TestExtractContentText_TextContent tests text content extraction +func TestExtractContentText_TextContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Hello World"}, + &mcp.TextContent{Text: "Second message"}, + } + + result := extractContentText(content) + expected := "Hello World\nSecond message" + + if result != expected { + t.Errorf("Expected '%s', got '%s'", expected, result) + } +} + +// TestExtractContentText_ImageContent tests image content extraction +func TestExtractContentText_ImageContent(t *testing.T) { + content := []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("base64data"), + MIMEType: "image/png", + }, + } + + result := extractContentText(content) + + if !strings.Contains(result, "[Image:") { + t.Errorf("Expected image indicator, got: %s", result) + } + if !strings.Contains(result, "image/png") { + t.Errorf("Expected MIME type in output, got: %s", result) + } +} + +// TestExtractContentText_MixedContent tests mixed content types +func TestExtractContentText_MixedContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Description"}, + &mcp.ImageContent{ + Data: []byte("data"), + MIMEType: "image/jpeg", + }, + &mcp.TextContent{Text: "More text"}, + } + + result := extractContentText(content) + + if !strings.Contains(result, "Description") { + t.Errorf("Should contain text content, got: %s", result) + } + if !strings.Contains(result, "[Image:") { + t.Errorf("Should contain image indicator, got: %s", result) + } + if !strings.Contains(result, "More text") { 
+ t.Errorf("Should contain second text, got: %s", result) + } +} + +// TestExtractContentText_EmptyContent tests empty content array +func TestExtractContentText_EmptyContent(t *testing.T) { + content := []mcp.Content{} + + result := extractContentText(content) + + if result != "" { + t.Errorf("Expected empty string for empty content, got: %s", result) + } +} + +// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface +func TestMCPTool_InterfaceCompliance(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: "test"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + // Verify it implements Tool interface + var _ Tool = mcpTool +} + +// TestMCPTool_Parameters_MapSchema tests schema that's already a map +func TestMCPTool_Parameters_MapSchema(t *testing.T) { + manager := &MockMCPManager{} + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "The name parameter", + }, + }, + "required": []string{"name"}, + } + + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: schema, + } + mcpTool := NewMCPTool(manager, "test_server", tool) + + params := mcpTool.Parameters() + + // Should return the schema as-is when it's already a map + if params["type"] != "object" { + t.Errorf("Expected type 'object', got '%v'", params["type"]) + } + + props, ok := params["properties"].(map[string]any) + if !ok { + t.Error("Properties should be a map") + } + + nameParam, ok := props["name"].(map[string]any) + if !ok { + t.Error("Name parameter should exist") + } + + if nameParam["type"] != "string" { + t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 15ef4ff73..d1e4a373e 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -3,6 +3,7 @@ package tools import ( "context" "fmt" + "sync/atomic" ) type SendCallback func(channel, chatID, content string) error @@ 
-11,7 +12,7 @@ type MessageTool struct { sendCallback SendCallback defaultChannel string defaultChatID string - sentInRound bool // Tracks whether a message was sent in the current processing round + sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -50,12 +51,12 @@ func (t *MessageTool) Parameters() map[string]any { func (t *MessageTool) SetContext(channel, chatID string) { t.defaultChannel = channel t.defaultChatID = chatID - t.sentInRound = false // Reset send tracking for new processing round + t.sentInRound.Store(false) // Reset send tracking for new processing round } // HasSentInRound returns true if the message tool sent a message during the current round. func (t *MessageTool) HasSentInRound() bool { - return t.sentInRound + return t.sentInRound.Load() } func (t *MessageTool) SetSendCallback(callback SendCallback) { @@ -94,7 +95,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes } } - t.sentInRound = true + t.sentInRound.Store(true) // Silent: user already received the message directly return &ToolResult{ ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d37a093a8..0ba983e02 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry { func (r *ToolRegistry) Register(tool Tool) { r.mu.Lock() defer r.mu.Unlock() - r.tools[tool.Name()] = tool + name := tool.Name() + if _, exists := r.tools[name]; exists { + logger.WarnCF("tools", "Tool registration overwrites existing tool", + map[string]any{"name": name}) + } + r.tools[name] = tool } func (r *ToolRegistry) Get(name string) (Tool, bool) { diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index cdfe0d6ce..244f0d4a2 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "sync" 
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -121,37 +122,53 @@ func RunToolLoop( } messages = append(messages, assistantMsg) - // 7. Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "tool": tc.Name, - "iteration": iteration, - }) + // 7. Execute tool calls in parallel + type indexedResult struct { + result *ToolResult + tc providers.ToolCall + } - // Execute tool (no async callback for subagents - they run independently) - var toolResult *ToolResult - if config.Tools != nil { - toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) - } else { - toolResult = ErrorResult("No tools available") + results := make([]indexedResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + results[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + results[idx].result = toolResult + }(i, tc) + } + wg.Wait() + + // Append results in original order + for _, r := range results { + contentForLLM := r.result.ForLLM + if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } - // Determine content for LLM - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = 
toolResult.Err.Error() - } - - // Add tool result message - toolResultMsg := providers.Message{ + messages = append(messages, providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) + ToolCallID: r.tc.ID, + }) } } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 10498126b..7b14686c9 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -109,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in return "", fmt.Errorf("failed to read response: %w", err) } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body)) + } + var searchResp struct { Web struct { Results []struct { @@ -391,6 +395,88 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil } +type GLMSearchProvider struct { + apiKey string + baseURL string + searchEngine string + proxy string + client *http.Client +} + +func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := p.baseURL + if searchURL == "" { + searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search" + } + + payload := map[string]any{ + "search_query": query, + "search_engine": p.searchEngine, + "search_intent": false, + "count": count, + "content_size": "medium", + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := p.client.Do(req) + if err != nil { + return "", fmt.Errorf("request 
failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body)) + } + + var searchResp struct { + SearchResult []struct { + Title string `json:"title"` + Content string `json:"content"` + Link string `json:"link"` + } `json:"search_result"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.SearchResult + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link)) + if item.Content != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Content)) + } + } + + return strings.Join(lines, "\n"), nil +} + type WebSearchTool struct { provider SearchProvider maxResults int @@ -409,6 +495,11 @@ type WebSearchToolOptions struct { PerplexityAPIKey string PerplexityMaxResults int PerplexityEnabled bool + GLMSearchAPIKey string + GLMSearchBaseURL string + GLMSearchEngine string + GLMSearchMaxResults int + GLMSearchEnabled bool Proxy string } @@ -416,7 +507,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 - // Priority: Perplexity > Brave > Tavily > DuckDuckGo + // Priority: Perplexity > Brave > Tavily > DuckDuckGo > GLM Search if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { client, err := createHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { @@ -458,6 +549,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if 
opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } + } else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" { + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err) + } + searchEngine := opts.GLMSearchEngine + if searchEngine == "" { + searchEngine = "search_std" + } + provider = &GLMSearchProvider{ + apiKey: opts.GLMSearchAPIKey, + baseURL: opts.GLMSearchBaseURL, + searchEngine: searchEngine, + proxy: opts.Proxy, + client: client, + } + if opts.GLMSearchMaxResults > 0 { + maxResults = opts.GLMSearchMaxResults + } } else { return nil, nil } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 8a8b88131..bdd30d385 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -681,3 +681,135 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) } } + +func TestWebTool_GLMSearch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + if r.Header.Get("Authorization") != "Bearer test-glm-key" { + t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization")) + } + + var payload map[string]any + json.NewDecoder(r.Body).Decode(&payload) + if payload["search_query"] != "test query" { + t.Errorf("Expected search_query 'test query', got %v", payload["search_query"]) + } + if payload["search_engine"] != "search_std" { + t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"]) + } + + response := map[string]any{ + "id": "web-search-test", + "created": 1709568000, + "search_result": []map[string]any{ + { + 
"title": "Test GLM Result", + "content": "GLM search snippet", + "link": "https://example.com/glm", + "media": "Example", + "publish_date": "2026-03-04", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-glm-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "Test GLM Result") { + t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "https://example.com/glm") { + t.Errorf("Expected URL in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "via GLM Search") { + t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser) + } +} + +func TestWebTool_GLMSearch_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"invalid api key"}`)) + })) + defer server.Close() + + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "bad-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if !result.IsError { + t.Errorf("Expected IsError=true for 401 response") + } + if !strings.Contains(result.ForLLM, "status 401") { + t.Errorf("Expected status 401 in error, got: %s", 
result.ForLLM) + } +} + +func TestWebTool_GLMSearch_Priority(t *testing.T) { + // GLM Search should only be selected when all other providers are disabled + tool, err := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: true, + DuckDuckGoMaxResults: 5, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + // DuckDuckGo should win over GLM Search + if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok { + t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider) + } + + // With DuckDuckGo disabled, GLM Search should be selected + tool2, err := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: false, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if _, ok := tool2.provider.(*GLMSearchProvider); !ok { + t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider) + } +} diff --git a/pkg/utils/media.go b/pkg/utils/media.go index a34889fb8..3e1c5d88e 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -3,6 +3,7 @@ package utils import ( "io" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -52,11 +53,12 @@ type DownloadOptions struct { Timeout time.Duration ExtraHeaders map[string]string LoggerPrefix string + ProxyURL string } // DownloadFile downloads a file from URL to a local temp directory. // Returns the local file path or empty string on error. 
-func DownloadFile(url, filename string, opts DownloadOptions) string { +func DownloadFile(urlStr, filename string, opts DownloadOptions) string { // Set defaults if opts.Timeout == 0 { opts.Timeout = 60 * time.Second @@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("GET", urlStr, nil) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{ "error": err.Error(), @@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } client := &http.Client{Timeout: opts.Timeout} + if opts.ProxyURL != "" { + proxyURL, parseErr := url.Parse(opts.ProxyURL) + if parseErr != nil { + logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{ + "error": parseErr.Error(), + "proxy": opts.ProxyURL, + }) + return "" + } + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } resp, err := client.Do(req) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{ "error": err.Error(), - "url": url, + "url": urlStr, }) return "" } @@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { if resp.StatusCode != http.StatusOK { logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{ "status": resp.StatusCode, - "url": url, + "url": urlStr, }) return "" } diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index f973e77fe..e949d7a22 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -10,12 +10,19 @@ import ( "net/http" "os" "path/filepath" + "strings" "time" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) +type Transcriber interface { + Name() string 
+	Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
+}
+
 type GroqTranscriber struct {
 	apiKey  string
 	apiBase string
@@ -152,8 +159,29 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
 	return &result, nil
 }
 
-func (t *GroqTranscriber) IsAvailable() bool {
-	available := t.apiKey != ""
-	logger.DebugCF("voice", "Checking transcriber availability", map[string]any{"available": available})
-	return available
+// Name returns the provider identifier of this transcriber, always "groq".
+func (t *GroqTranscriber) Name() string {
+	return "groq"
+}
+
+// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
+// nil if cfg is nil or no supported transcription provider is configured.
+func DetectTranscriber(cfg *config.Config) Transcriber {
+	// A nil config has nothing to inspect; report "no transcriber" instead of
+	// panicking on the field accesses below.
+	if cfg == nil {
+		return nil
+	}
+	// Direct Groq provider config takes priority.
+	if key := cfg.Providers.Groq.APIKey; key != "" {
+		return NewGroqTranscriber(key)
+	}
+	// Fall back to the first model-list entry that uses the groq/ protocol
+	// and carries a non-empty API key.
+	for _, mc := range cfg.ModelList {
+		if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" {
+			return NewGroqTranscriber(mc.APIKey)
+		}
+	}
+	return nil
 }
diff --git a/pkg/voice/transcriber_test.go b/pkg/voice/transcriber_test.go
new file mode 100644
index 000000000..9b6add333
--- /dev/null
+++ b/pkg/voice/transcriber_test.go
@@ -0,0 +1,179 @@
+package voice
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/sipeed/picoclaw/pkg/config"
+)
+
+// Ensure GroqTranscriber satisfies the Transcriber interface at compile time.
+var _ Transcriber = (*GroqTranscriber)(nil)
+
+func TestGroqTranscriberName(t *testing.T) {
+	tr := NewGroqTranscriber("sk-test")
+	if got := tr.Name(); got != "groq" {
+		t.Errorf("Name() = %q, want %q", got, "groq")
+	}
+}
+
+func TestDetectTranscriber(t *testing.T) {
+	tests := []struct {
+		name     string
+		cfg      *config.Config
+		wantNil  bool
+		wantName string
+		wantKey  string
+	}{
+		{
+			name:    "nil config",
+			cfg:     nil,
+			wantNil: true,
+		},
+		{
+			name:    "no config",
+			cfg:     &config.Config{},
+			wantNil: true,
+		},
+		{
+			name: "groq provider key",
+			cfg: &config.Config{
+				Providers: config.ProvidersConfig{
+					Groq: config.ProviderConfig{APIKey: "sk-groq-direct"},
+				},
+			},
+			wantName: "groq",
+			wantKey:  "sk-groq-direct",
+		},
+		{
+			name: "groq via model list",
+			cfg: &config.Config{
+				ModelList: []config.ModelConfig{
+					{Model: "openai/gpt-4o", APIKey: "sk-openai"},
+					{Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"},
+				},
+			},
+			wantName: "groq",
+			wantKey:  "sk-groq-model",
+		},
+		{
+			name: "groq model list entry without key is skipped",
+			cfg: &config.Config{
+				ModelList: []config.ModelConfig{
+					{Model: "groq/llama-3.3-70b", APIKey: ""},
+				},
+			},
+			wantNil: true,
+		},
+		{
+			name: "provider key takes priority over model list",
+			cfg: &config.Config{
+				Providers: config.ProvidersConfig{
+					Groq: config.ProviderConfig{APIKey: "sk-groq-direct"},
+				},
+				ModelList: []config.ModelConfig{
+					{Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"},
+				},
+			},
+			wantName: "groq",
+			wantKey:  "sk-groq-direct",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			tr := DetectTranscriber(tc.cfg)
+			if tc.wantNil {
+				if tr != nil {
+					t.Errorf("DetectTranscriber() = %v, want nil", tr)
+				}
+				return
+			}
+			if tr == nil {
+				t.Fatal("DetectTranscriber() = nil, want non-nil")
+			}
+			if got := tr.Name(); got != tc.wantName {
+				t.Errorf("Name() = %q, want %q", got, tc.wantName)
+			}
+			// The key held by the transcriber shows which config source won.
+			// NOTE(review): assumes NewGroqTranscriber stores its argument in
+			// apiKey — confirm against the constructor.
+			groq, ok := tr.(*GroqTranscriber)
+			if !ok {
+				t.Fatalf("DetectTranscriber() returned %T, want *GroqTranscriber", tr)
+			}
+			if groq.apiKey != tc.wantKey {
+				t.Errorf("apiKey = %q, want %q", groq.apiKey, tc.wantKey)
+			}
+		})
+	}
+}
+
+func TestTranscribe(t *testing.T) {
+	// Write a minimal fake audio file so the transcriber can open and send it.
+	tmpDir := t.TempDir()
+	audioPath := filepath.Join(tmpDir, "clip.ogg")
+	if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
+		t.Fatalf("failed to write fake audio file: %v", err)
+	}
+
+	t.Run("success", func(t *testing.T) {
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if r.URL.Path != "/audio/transcriptions" {
+				t.Errorf("unexpected path: %s", r.URL.Path)
+			}
+			if r.Header.Get("Authorization") != "Bearer sk-test" {
+				t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
+			}
+			w.Header().Set("Content-Type", "application/json")
+			_ = json.NewEncoder(w).Encode(TranscriptionResponse{
+				Text:     "hello world",
+				Language: "en",
+				Duration: 1.5,
+			})
+		}))
+		defer srv.Close()
+
+		tr := NewGroqTranscriber("sk-test")
+		tr.apiBase = srv.URL
+
+		resp, err := tr.Transcribe(context.Background(), audioPath)
+		if err != nil {
+			t.Fatalf("Transcribe() error: %v", err)
+		}
+		if resp.Text != "hello world" {
+			t.Errorf("Text = %q, want %q", resp.Text, "hello world")
+		}
+		if resp.Language != "en" {
+			t.Errorf("Language = %q, want %q", resp.Language, "en")
+		}
+	})
+
+	t.Run("api error", func(t *testing.T) {
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
+		}))
+		defer srv.Close()
+
+		tr := NewGroqTranscriber("sk-bad")
+		tr.apiBase = srv.URL
+
+		_, err := tr.Transcribe(context.Background(), audioPath)
+		if err == nil {
+			t.Fatal("expected error for non-200 response, got nil")
+		}
+	})
+
+	t.Run("missing file", func(t *testing.T) {
+		tr := NewGroqTranscriber("sk-test")
+		_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
+		if err == nil {
+			t.Fatal("expected error for missing file, got nil")
+		}
+	})
+}
diff --git a/scripts/test-docker-mcp.sh b/scripts/test-docker-mcp.sh
new file mode 100755
index 000000000..9d582ffa0
--- /dev/null
+++ b/scripts/test-docker-mcp.sh @@ -0,0 +1,49 @@ +#!/bin/sh +# Test script for MCP tools in Docker (full-featured image) + +set -e + +COMPOSE_FILE="docker/docker-compose.full.yml" +SERVICE="picoclaw-agent" + +echo "🧪 Testing MCP tools in Docker container (full-featured image)..." +echo "" + +# Build the image +echo "📦 Building Docker image..." +docker compose -f "$COMPOSE_FILE" build "$SERVICE" + +# Test npx +echo "✅ Testing npx..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npx --version' + +# Test npm +echo "✅ Testing npm..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npm --version' + +# Test node +echo "✅ Testing Node.js..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'node --version' + +# Test git +echo "✅ Testing git..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'git --version' + +# Test python +echo "✅ Testing Python..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'python3 --version' + +# Test uv +echo "✅ Testing uv..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'uv --version' + +# Test MCP server installation (quick) +echo "✅ Testing @modelcontextprotocol/server-filesystem MCP server install with npx..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c '