+
+Téléchargez l'APK depuis [picoclaw.io](https://picoclaw.io/download/) et installez-le directement. Pas besoin de Termux !
+
+**Option 2 : Termux**
+
+
+Terminal Launcher (pour les environnements à ressources limitées)
1. Installez [Termux](https://github.com/termux/termux-app) (téléchargez depuis [GitHub Releases](https://github.com/termux/termux-app/releases), ou cherchez dans F-Droid / Google Play)
2. Exécutez les commandes suivantes :
@@ -321,13 +341,6 @@ Suivez ensuite la section Terminal Launcher ci-dessous pour terminer la configur
-**Option 2 : Installation APK (bientôt disponible)**
-
-Un APK Android autonome avec WebUI intégré est en développement. Restez à l'écoute !
-
-
-Terminal Launcher (pour les environnements à ressources limitées)
-
Pour les environnements minimaux où seul le binaire principal `picoclaw` est disponible (sans Launcher UI), vous pouvez tout configurer via la ligne de commande et un fichier de configuration JSON.
**1. Initialiser**
diff --git a/README.id.md b/README.id.md
index 3fe4c1276..d3c556dde 100644
--- a/README.id.md
+++ b/README.id.md
@@ -56,6 +56,8 @@
## 📢 Berita
+2026-03-31 📱 **Dukungan Android!** PicoClaw sekarang berjalan di Android! Unduh APK di [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 Dirilis!** Perombakan arsitektur Agent (SubTurn, Hooks, Steering, EventBus), integrasi WeChat/WeCom, penguatan keamanan (.security.yml, penyaringan data sensitif), provider baru (AWS Bedrock, Azure, Xiaomi MiMo), dan 35 perbaikan bug. PicoClaw telah mencapai **26K Stars**!
2026-03-17 🚀 **v0.2.3 Dirilis!** UI system tray (Windows & Linux), pelacakan status sub-agent (`spawn_status`), eksperimental Gateway hot-reload, gerbang keamanan Cron, dan 2 perbaikan keamanan. PicoClaw telah mencapai **25K Stars**!
@@ -301,7 +303,25 @@ Untuk dokumentasi TUI lengkap, lihat [docs.picoclaw.io](https://docs.picoclaw.io
Berikan kehidupan kedua untuk ponsel lama Anda! Ubah menjadi Asisten AI pintar dengan PicoClaw.
-**Opsi 1: Termux (tersedia sekarang)**
+**Opsi 1: Instal APK**
+
+Pratinjau:
+
+
+
+
+
+
+
+
+
+
+Unduh APK dari [picoclaw.io](https://picoclaw.io/download/) dan instal langsung. Tanpa Termux!
+
+**Opsi 2: Termux**
+
+
+Terminal Launcher (untuk lingkungan dengan sumber daya terbatas)
1. Instal [Termux](https://github.com/termux/termux-app) (unduh dari [GitHub Releases](https://github.com/termux/termux-app/releases), atau cari di F-Droid / Google Play)
2. Jalankan perintah berikut:
@@ -318,13 +338,6 @@ Kemudian ikuti bagian Terminal Launcher di bawah untuk menyelesaikan konfigurasi
-**Opsi 2: Instal APK (segera hadir)**
-
-APK Android mandiri dengan WebUI bawaan sedang dalam pengembangan. Pantau terus!
-
-
-Terminal Launcher (untuk lingkungan dengan sumber daya terbatas)
-
Untuk lingkungan minimal di mana hanya binary inti `picoclaw` yang tersedia (tanpa Launcher UI), Anda dapat mengonfigurasi semuanya melalui command line dan file konfigurasi JSON.
**1. Inisialisasi**
diff --git a/README.it.md b/README.it.md
index 8748aea9c..6fe6c5e17 100644
--- a/README.it.md
+++ b/README.it.md
@@ -56,6 +56,8 @@
## 📢 Novità
+2026-03-31 📱 **Supporto Android!** PicoClaw ora funziona su Android! Scarica l'APK su [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 rilasciata!** Revisione dell'architettura Agent (SubTurn, Hooks, Steering, EventBus), integrazione WeChat/WeCom, rafforzamento della sicurezza (.security.yml, filtraggio dati sensibili), nuovi provider (AWS Bedrock, Azure, Xiaomi MiMo) e 35 correzioni di bug. PicoClaw raggiunge **26K Stars**!
2026-03-17 🚀 **v0.2.3 rilasciata!** Interfaccia system tray (Windows & Linux), query sullo stato dei sub-agent (`spawn_status`), hot-reload sperimentale del Gateway, gate di sicurezza per Cron e 2 correzioni di sicurezza. PicoClaw raggiunge **25K Stars**!
@@ -301,7 +303,25 @@ Per la documentazione dettagliata del TUI, vedi [docs.picoclaw.io](https://docs.
Dai una seconda vita al tuo telefono di dieci anni fa! Trasformalo in un assistente IA intelligente con PicoClaw.
-**Opzione 1: Termux (disponibile ora)**
+**Opzione 1: Installazione APK**
+
+Anteprima:
+
+
+
+
+
+
+
+
+
+
+Scarica l'APK da [picoclaw.io](https://picoclaw.io/download/) e installa direttamente. Senza Termux!
+
+**Opzione 2: Termux**
+
+
+Terminal Launcher (per ambienti con risorse limitate)
1. Installa [Termux](https://github.com/termux/termux-app) (scarica da [GitHub Releases](https://github.com/termux/termux-app/releases), o cerca su F-Droid / Google Play)
2. Esegui i seguenti comandi:
@@ -318,13 +338,6 @@ Poi segui la sezione Terminal Launcher qui sotto per completare la configurazion
-**Opzione 2: APK Install (prossimamente)**
-
-Un APK Android standalone con WebUI integrato è in sviluppo. Resta sintonizzato!
-
-
-Terminal Launcher (per ambienti con risorse limitate)
-
Per ambienti minimali dove è disponibile solo il binario core `picoclaw` (senza Launcher UI), puoi configurare tutto tramite riga di comando e un file di configurazione JSON.
**1. Inizializza**
diff --git a/README.ja.md b/README.ja.md
index 3772ff532..793c41fcb 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -56,6 +56,8 @@
## 📢 ニュース
+2026-03-31 📱 **Android サポート!** PicoClawがAndroidで動作!APKは[picoclaw.io](https://picoclaw.io/download)からダウンロード
+
2026-03-25 🚀 **v0.2.4 リリース!** Agent アーキテクチャ全面刷新(SubTurn、Hooks、Steering、EventBus)、WeChat/WeCom 統合、セキュリティ強化(.security.yml、機密データフィルタリング)、新プロバイダー(AWS Bedrock、Azure、Xiaomi MiMo)、35 件のバグ修正。PicoClaw **26K ⭐** 達成!
2026-03-17 🚀 **v0.2.3 リリース!** システムトレイ UI(Windows & Linux)、サブエージェントステータス追跡(`spawn_status`)、実験的 Gateway ホットリロード、cron セキュリティゲート、セキュリティ修正 2 件。PicoClaw **25K ⭐** 達成!
@@ -301,7 +303,25 @@ TUI の詳細なドキュメントは [docs.picoclaw.io](https://docs.picoclaw.i
10 年前のスマホに第二の人生を!PicoClaw でスマート AI アシスタントに変身させましょう。
-**オプション 1: Termux(現在利用可能)**
+**オプション 1: APK インストール**
+
+プレビュー:
+
+
+
+
+
+
+
+
+
+
+[picoclaw.io](https://picoclaw.io/download/) から APK をダウンロードして直接インストール。Termux 不要!
+
+**オプション 2: Termux**
+
+
+Terminal Launcher(リソース制約環境向け)
1. [Termux](https://github.com/termux/termux-app) をインストール([GitHub Releases](https://github.com/termux/termux-app/releases) からダウンロード、または F-Droid / Google Play で検索)
2. 以下のコマンドを実行:
@@ -318,13 +338,6 @@ termux-chroot ./picoclaw onboard # chroot で標準的な Linux ファイル
-**オプション 2: APK インストール(近日公開)**
-
-内蔵 WebUI を備えたスタンドアロン Android APK を開発中です。お楽しみに!
-
-
-Terminal Launcher(リソース制約環境向け)
-
`picoclaw` コアバイナリのみが利用可能な最小環境(Launcher UI なし)では、コマンドラインと JSON 設定ファイルですべてを設定できます。
**1. 初期化**
diff --git a/README.md b/README.md
index 947fed9e2..a48a53d47 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,8 @@
## 📢 News
+2026-03-31 📱 **Android Support!** PicoClaw now runs on Android! Download the APK at [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 Released!** Agent architecture overhaul (SubTurn, Hooks, Steering, EventBus), WeChat/WeCom integration, security hardening (.security.yml, sensitive data filtering), new providers (AWS Bedrock, Azure, Xiaomi MiMo), and 35 bug fixes. PicoClaw has reached **26K Stars**!
2026-03-17 🚀 **v0.2.3 Released!** System tray UI (Windows & Linux), sub-agent status query (`spawn_status`), experimental Gateway hot-reload, Cron security gating, and 2 security fixes. PicoClaw has reached **25K Stars**!
@@ -301,7 +303,25 @@ For detailed TUI documentation, see [docs.picoclaw.io](https://docs.picoclaw.io)
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw.
-**Option 1: Termux (available now)**
+**Option 1: APK Install**
+
+Preview:
+
+
+
+
+
+
+
+
+
+
+Download the APK from [picoclaw.io](https://picoclaw.io/download/) and install directly. No Termux required!
+
+**Option 2: Termux**
+
+
+Terminal Launcher (for resource-constrained environments)
1. Install [Termux](https://github.com/termux/termux-app) (download from [GitHub Releases](https://github.com/termux/termux-app/releases), or search in F-Droid / Google Play)
2. Run the following commands:
@@ -318,13 +338,6 @@ Then follow the Terminal Launcher section below to complete configuration.
-**Option 2: APK Install (coming soon)**
-
-A standalone Android APK with built-in WebUI is in development. Stay tuned!
-
-
-Terminal Launcher (for resource-constrained environments)
-
For minimal environments where only the `picoclaw` core binary is available (no Launcher UI), you can configure everything via the command line and a JSON config file.
**1. Initialize**
@@ -441,7 +454,7 @@ For full provider configuration details, see [Providers & Models](docs/providers
## 💬 Channels (Chat Apps)
-Talk to your PicoClaw through 17+ messaging platforms:
+Talk to your PicoClaw through 18+ messaging platforms:
| Channel | Setup | Protocol | Docs |
|---------|-------|----------|------|
@@ -456,6 +469,7 @@ Talk to your PicoClaw through 17+ messaging platforms:
| **Feishu / Lark** | Medium (App ID + Secret) | WebSocket/SDK | [Guide](docs/channels/feishu/README.md) |
| **LINE** | Medium (credentials + webhook) | Webhook | [Guide](docs/channels/line/README.md) |
| **WeCom** | Easy (QR login or manual) | WebSocket | [Guide](docs/channels/wecom/README.md) |
+| **VK** | Easy (group token) | Long Poll | [Guide](docs/channels/vk/README.md) |
| **IRC** | Medium (server + nick) | IRC protocol | [Guide](docs/chat-apps.md#irc) |
| **OneBot** | Medium (WebSocket URL) | OneBot v11 | [Guide](docs/channels/onebot/README.md) |
| **MaixCam** | Easy (enable) | TCP socket | [Guide](docs/channels/maixcam/README.md) |
diff --git a/README.my.md b/README.my.md
index c07cdd005..f00fb438c 100644
--- a/README.my.md
+++ b/README.my.md
@@ -56,6 +56,8 @@
## 📢 Berita
+2026-03-31 📱 **Sokongan Android!** PicoClaw sekarang berjalan di Android! Muat turun APK di [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 Dikeluarkan!** Penstrukturan semula seni bina Agent (SubTurn, Hooks, Steering, EventBus), integrasi WeChat/WeCom, penguatan keselamatan (.security.yml, penapisan data sensitif), penyedia baharu (AWS Bedrock, Azure, Xiaomi MiMo), dan 35 pembetulan pepijat. PicoClaw mencapai **26K Stars**!
2026-03-17 🚀 **v0.2.3 Dikeluarkan!** UI dulang sistem (Windows & Linux), pertanyaan status sub-agent (`spawn_status`), muat semula panas Gateway eksperimental, kawalan keselamatan Cron, dan 2 pembetulan keselamatan. PicoClaw mencapai **25K Stars**!
@@ -298,7 +300,25 @@ Untuk dokumentasi TUI terperinci, lihat [docs.picoclaw.io](https://docs.picoclaw
Berikan telefon lama anda kehidupan baru! Jadikannya Pembantu AI pintar dengan PicoClaw.
-**Pilihan 1: Termux (tersedia sekarang)**
+**Pilihan 1: Pasang APK**
+
+Pratonton:
+
+
+
+
+
+
+
+
+
+
+Muat turun APK dari [picoclaw.io](https://picoclaw.io/download/) dan pasang secara langsung. Tiada Termux diperlukan!
+
+**Pilihan 2: Termux**
+
+
+Pelancar Terminal (untuk persekitaran terhad sumber)
1. Pasang [Termux](https://github.com/termux/termux-app) (muat turun dari [GitHub Releases](https://github.com/termux/termux-app/releases), atau cari di F-Droid / Google Play)
2. Jalankan arahan berikut:
@@ -315,13 +335,6 @@ Kemudian ikuti bahagian Pelancar Terminal di bawah untuk melengkapkan konfiguras
-**Pilihan 2: APK (akan datang)**
-
-APK Android bebas dengan WebUI terbina dalam sedang dalam pembangunan. Nantikan!
-
-
-Pelancar Terminal (untuk persekitaran terhad sumber)
-
Untuk persekitaran minimal di mana hanya binari teras `picoclaw` tersedia (tiada UI Pelancar), anda boleh mengkonfigurasi semua melalui baris arahan dan fail konfigurasi JSON.
**1. Mulakan**
diff --git a/README.pt-br.md b/README.pt-br.md
index dfe7cb0f2..db11d4d82 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -56,6 +56,8 @@
## 📢 Novidades
+2026-03-31 📱 **Suporte Android!** PicoClaw agora roda no Android! Baixe o APK em [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 Lançada!** Reformulação da arquitetura Agent (SubTurn, Hooks, Steering, EventBus), integração WeChat/WeCom, fortalecimento de segurança (.security.yml, filtragem de dados sensíveis), novos providers (AWS Bedrock, Azure, Xiaomi MiMo) e 35 correções de bugs. O PicoClaw atingiu **26K Stars**!
2026-03-17 🚀 **v0.2.3 Lançada!** UI na bandeja do sistema (Windows e Linux), consulta de status de sub-agent (`spawn_status`), hot-reload experimental do Gateway, controle de segurança do Cron e 2 correções de segurança. O PicoClaw atingiu **25K Stars**!
@@ -301,7 +303,25 @@ Para documentação detalhada do TUI, veja [docs.picoclaw.io](https://docs.picoc
Dê uma segunda vida ao seu celular de uma década! Transforme-o em um Assistente de IA inteligente com o PicoClaw.
-**Opção 1: Termux (disponível agora)**
+**Opção 1: Instalação via APK**
+
+Pré-visualização:
+
+
+
+
+
+
+
+
+
+
+Baixe o APK de [picoclaw.io](https://picoclaw.io/download/) e instale diretamente. Sem necessidade de Termux!
+
+**Opção 2: Termux**
+
+
+Terminal Launcher (para ambientes com recursos limitados)
1. Instale o [Termux](https://github.com/termux/termux-app) (baixe nas [GitHub Releases](https://github.com/termux/termux-app/releases), ou pesquise no F-Droid / Google Play)
2. Execute os seguintes comandos:
@@ -318,13 +338,6 @@ Em seguida, siga a seção Terminal Launcher abaixo para concluir a configuraç
-**Opção 2: Instalação via APK (em breve)**
-
-Um APK Android independente com WebUI integrado está em desenvolvimento. Fique ligado!
-
-
-Terminal Launcher (para ambientes com recursos limitados)
-
Para ambientes mínimos onde apenas o binário principal `picoclaw` está disponível (sem Launcher UI), você pode configurar tudo via linha de comando e um arquivo de configuração JSON.
**1. Inicializar**
diff --git a/README.vi.md b/README.vi.md
index 6c8f6ad44..78b8a9a59 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -56,6 +56,8 @@
## 📢 Tin tức
+2026-03-31 📱 **Hỗ trợ Android!** PicoClaw giờ chạy trên Android! Tải APK tại [picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 đã phát hành!** Tái cấu trúc kiến trúc Agent (SubTurn, Hooks, Steering, EventBus), tích hợp WeChat/WeCom, tăng cường bảo mật (.security.yml, lọc dữ liệu nhạy cảm), provider mới (AWS Bedrock, Azure, Xiaomi MiMo) và 35 bản vá lỗi. PicoClaw đã đạt **26K Stars**!
2026-03-17 🚀 **v0.2.3 đã phát hành!** Giao diện system tray (Windows & Linux), truy vấn trạng thái sub-agent (`spawn_status`), thử nghiệm Gateway hot-reload, bảo mật Cron, và 2 bản vá bảo mật. PicoClaw đã đạt **25K Stars**!
@@ -301,7 +303,25 @@ Sử dụng menu TUI để: **1)** Cấu hình Provider -> **2)** Cấu hình Ch
Hãy cho chiếc điện thoại cũ của bạn một cuộc sống mới! Biến nó thành Trợ lý AI thông minh với PicoClaw.
-**Tùy chọn 1: Termux (có sẵn ngay)**
+**Tùy chọn 1: Cài đặt APK**
+
+Xem trước:
+
+
+
+
+
+
+
+
+
+
+Tải APK từ [picoclaw.io](https://picoclaw.io/download/) và cài đặt trực tiếp. Không cần Termux!
+
+**Tùy chọn 2: Termux**
+
+
+Terminal Launcher (cho môi trường hạn chế tài nguyên)
1. Cài đặt [Termux](https://github.com/termux/termux-app) (tải từ [GitHub Releases](https://github.com/termux/termux-app/releases), hoặc tìm kiếm trong F-Droid / Google Play)
2. Chạy các lệnh sau:
@@ -318,13 +338,6 @@ Sau đó làm theo phần Terminal Launcher bên dưới để hoàn tất cấu
-**Tùy chọn 2: Cài đặt APK (sắp ra mắt)**
-
-Một APK Android độc lập với WebUI tích hợp đang được phát triển. Hãy đón chờ!
-
-
-Terminal Launcher (cho môi trường hạn chế tài nguyên)
-
Đối với các môi trường tối giản chỉ có binary lõi `picoclaw` (không có Launcher UI), bạn có thể cấu hình mọi thứ qua dòng lệnh và tệp cấu hình JSON.
**1. Khởi tạo**
diff --git a/README.zh.md b/README.zh.md
index b92e8e889..2ba0913fc 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -56,6 +56,8 @@
## 📢 新闻
+2026-03-31 📱 **Android 支持!** PicoClaw 现可在 Android 上运行!APK 下载地址:[picoclaw.io](https://picoclaw.io/download)
+
2026-03-25 🚀 **v0.2.4 发布!** Agent 架构全面重构(SubTurn、Hook、Steering、EventBus)、微信/企业微信深度集成、安全体系升级(.security.yml、敏感数据过滤)、新增 Provider(AWS Bedrock、Azure、小米 MiMo),以及 35 项 Bug 修复。PicoClaw 已达 **26K ⭐**!
2026-03-17 🚀 **v0.2.3 发布!** 系统托盘 UI(Windows & Linux)、子 Agent 状态查询 (`spawn_status`)、实验性 Gateway 热重载、Cron 安全门控,以及 2 项安全修复。PicoClaw 已达 **25K ⭐**!
@@ -301,7 +303,25 @@ picoclaw-launcher-tui
让你十年前的旧手机焕发新生!将它变成你的 AI 助手。
-**方式一:Termux(现已可用)**
+**方式一:APK 安装**
+
+预览:
+
+
+
+
+
+
+
+
+
+
+从 [picoclaw.io](https://picoclaw.io/download/) 下载 APK 并直接安装,无需 Termux!
+
+**方式二:Termux**
+
+
+Terminal Launcher(适用于资源受限环境)
1. 安装 [Termux](https://github.com/termux/termux-app)(可从 [GitHub Releases](https://github.com/termux/termux-app/releases) 下载,或在 F-Droid / Google Play 中搜索)
2. 执行以下命令:
@@ -318,13 +338,6 @@ termux-chroot ./picoclaw onboard # chroot 提供标准 Linux 文件系统布
-**方式二:APK 安装(即将推出)**
-
-内置 WebUI 的独立 Android APK 正在开发中,敬请期待!
-
-
-Terminal Launcher(适用于资源受限环境)
-
对于只有 `picoclaw` 核心二进制文件的极简环境(无 Launcher UI),可通过命令行和 JSON 配置文件完成所有配置。
**1. 初始化**
@@ -435,7 +448,7 @@ PicoClaw 通过 `model_list` 配置支持 30+ LLM Provider,使用 `协议/模
## 💬 Channels(聊天应用)
-通过 17+ 消息平台与你的 PicoClaw 对话:
+通过 18+ 消息平台与你的 PicoClaw 对话:
| Channel | 配置难度 | 协议 | 文档 |
|---------|----------|------|------|
@@ -450,6 +463,7 @@ PicoClaw 通过 `model_list` 配置支持 30+ LLM Provider,使用 `协议/模
| **飞书 / Lark** | 中等(App ID + Secret) | WebSocket/SDK | [指南](docs/channels/feishu/README.zh.md) |
| **LINE** | 中等(credentials + webhook) | Webhook | [指南](docs/channels/line/README.zh.md) |
| **企业微信** | 简单(扫码登录或手动配置) | WebSocket | [指南](docs/channels/wecom/README.zh.md) |
+| **VK** | 简单(群组 token) | Long Poll | [指南](docs/channels/vk/README.md) |
| **IRC** | 中等(server + nick) | IRC 协议 | [指南](docs/zh/chat-apps.md#irc) |
| **OneBot** | 中等(WebSocket URL) | OneBot v11 | [指南](docs/channels/onebot/README.zh.md) |
| **MaixCam** | 简单(启用即可) | TCP socket | [指南](docs/channels/maixcam/README.zh.md) |
diff --git a/assets/fui_log_page.jpg b/assets/fui_log_page.jpg
new file mode 100644
index 000000000..188c46982
Binary files /dev/null and b/assets/fui_log_page.jpg differ
diff --git a/assets/fui_main_page.jpg b/assets/fui_main_page.jpg
new file mode 100644
index 000000000..f9c5b5c34
Binary files /dev/null and b/assets/fui_main_page.jpg differ
diff --git a/assets/fui_setting_page.jpg b/assets/fui_setting_page.jpg
new file mode 100644
index 000000000..3481088e3
Binary files /dev/null and b/assets/fui_setting_page.jpg differ
diff --git a/assets/fui_web_page.jpg b/assets/fui_web_page.jpg
new file mode 100644
index 000000000..2f57c64c7
Binary files /dev/null and b/assets/fui_web_page.jpg differ
diff --git a/assets/wechat.png b/assets/wechat.png
index 07a05dd91..66ffa99e9 100644
Binary files a/assets/wechat.png and b/assets/wechat.png differ
diff --git a/cmd/membench/eval.go b/cmd/membench/eval.go
new file mode 100644
index 000000000..bddee76fd
--- /dev/null
+++ b/cmd/membench/eval.go
@@ -0,0 +1,366 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/seahorse"
+)
+
+// EvalResult holds per-sample evaluation results for one mode.
+type EvalResult struct {
+ Mode string `json:"mode"`
+ SampleID string `json:"sampleId"`
+ QAResults []QAResult `json:"qaResults"`
+ Agg AggMetrics `json:"aggregated"`
+}
+
+// QAResult holds metrics for a single QA pair.
+type QAResult struct {
+ Question string `json:"question"`
+ Category int `json:"category"`
+ GoldAnswer string `json:"goldAnswer"`
+ TokenF1 float64 `json:"tokenF1"`
+ HitRate float64 `json:"hitRate"`
+}
+
+// AggMetrics holds aggregated evaluation metrics.
+type AggMetrics struct {
+ OverallF1 float64 `json:"overallF1"`
+ OverallHitRate float64 `json:"overallHitRate"`
+ ByCategory map[int]*CatMetrics `json:"byCategory"`
+ TotalQuestions int `json:"totalQuestions"`
+}
+
+// CatMetrics holds metrics for a single category.
+type CatMetrics struct {
+ F1 float64 `json:"f1"`
+ HitRate float64 `json:"hitRate"`
+ QuestionCount int `json:"questionCount"`
+}
+
+// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
+func EvalLegacy(
+ ctx context.Context,
+ samples []LocomoSample,
+ legacy *LegacyStore,
+ budgetTokens int,
+) []EvalResult {
+ results := make([]EvalResult, 0, len(samples))
+ for si := range samples {
+ sample := &samples[si]
+ history := legacy.GetHistory(sample.SampleID)
+
+ // Convert messages to content strings
+ allContent := make([]string, 0, len(history))
+ for _, msg := range history {
+ allContent = append(allContent, msg.Content)
+ }
+
+ qaResults := make([]QAResult, 0, len(sample.QA))
+ for qi := range sample.QA {
+ qa := &sample.QA[qi]
+ // Budget truncate the full history
+ truncated, _ := BudgetTruncate(allContent, budgetTokens)
+ context := StringListToContent(truncated)
+
+ f1 := TokenOverlapF1(context, qa.AnswerString())
+ hitRate := RecallHitRate(qa.Evidence, sample, context)
+
+ qaResults = append(qaResults, QAResult{
+ Question: qa.Question,
+ Category: qa.Category,
+ GoldAnswer: qa.AnswerString(),
+ TokenF1: f1,
+ HitRate: hitRate,
+ })
+ }
+
+ results = append(results, EvalResult{
+ Mode: "legacy",
+ SampleID: sample.SampleID,
+ QAResults: qaResults,
+ Agg: aggregateMetrics(qaResults),
+ })
+ }
+ return results
+}
+
+// EvalSeahorse evaluates using seahorse short memory (per-keyword search + expand).
+func EvalSeahorse(
+ ctx context.Context,
+ samples []LocomoSample,
+ ir *SeahorseIngestResult,
+ budgetTokens int,
+) []EvalResult {
+ store := ir.Engine.GetRetrieval().Store()
+ retrieval := ir.Engine.GetRetrieval()
+
+ results := make([]EvalResult, 0, len(samples))
+ for si := range samples {
+ sample := &samples[si]
+ convID, ok := ir.ConvMap[sample.SampleID]
+ if !ok {
+ log.Printf("WARN: no conversation ID for sample %s", sample.SampleID)
+ continue
+ }
+
+ qaResults := make([]QAResult, 0, len(sample.QA))
+ for qi := range sample.QA {
+ qa := &sample.QA[qi]
+ keywords := ExtractKeywords(qa.Question)
+
+ // Search each keyword individually and union results,
+ // tracking best BM25 rank per message for relevance sorting.
+ bestRank := map[int64]float64{}
+ for _, kw := range keywords {
+ searchResults, err := store.SearchMessages(ctx, seahorse.SearchInput{
+ Pattern: kw,
+ ConversationID: convID,
+ Limit: 20,
+ })
+ if err != nil {
+ log.Printf("WARN: search failed for keyword %q: %v", kw, err)
+ continue
+ }
+ for _, sr := range searchResults {
+ if sr.MessageID > 0 {
+ if prev, ok := bestRank[sr.MessageID]; !ok || sr.Rank < prev {
+ bestRank[sr.MessageID] = sr.Rank
+ }
+ }
+ }
+ }
+ // Sort messageIDs by rank ascending (best/most-negative first).
+ // BudgetTruncate walks from the front, keeping best-ranked messages.
+ // Note: SQLite FTS5 bm25() returns negative values where more
+ // negative = better match.
+ messageIDs := make([]int64, 0, len(bestRank))
+ for id := range bestRank {
+ messageIDs = append(messageIDs, id)
+ }
+ sort.Slice(messageIDs, func(i, j int) bool {
+ return bestRank[messageIDs[i]] < bestRank[messageIDs[j]]
+ })
+
+ // Expand messages to get full content
+ var contentParts []string
+ if len(messageIDs) > 0 {
+ expandResult, err := retrieval.ExpandMessages(ctx, messageIDs)
+ if err != nil {
+ log.Printf("WARN: expand failed for sample %s: %v", sample.SampleID, err)
+ } else {
+ for _, msg := range expandResult.Messages {
+ contentParts = append(contentParts, msg.Content)
+ }
+ }
+ }
+
+ if len(contentParts) == 0 {
+ qaResults = append(qaResults, QAResult{
+ Question: qa.Question,
+ Category: qa.Category,
+ GoldAnswer: qa.AnswerString(),
+ TokenF1: 0.0,
+ HitRate: 0.0,
+ })
+ continue
+ }
+
+ // Budget truncate (drop worst-ranked)
+ truncated, _ := BudgetTruncate(contentParts, budgetTokens)
+ context := StringListToContent(truncated)
+
+ f1 := TokenOverlapF1(context, qa.AnswerString())
+ hitRate := RecallHitRate(qa.Evidence, sample, context)
+
+ qaResults = append(qaResults, QAResult{
+ Question: qa.Question,
+ Category: qa.Category,
+ GoldAnswer: qa.AnswerString(),
+ TokenF1: f1,
+ HitRate: hitRate,
+ })
+ }
+
+ results = append(results, EvalResult{
+ Mode: "seahorse",
+ SampleID: sample.SampleID,
+ QAResults: qaResults,
+ Agg: aggregateMetrics(qaResults),
+ })
+ }
+ return results
+}
+
+// aggregateMetrics computes overall and per-category metrics.
+func aggregateMetrics(qaResults []QAResult) AggMetrics {
+ byCat := map[int]*CatMetrics{}
+ totalF1 := 0.0
+ totalHitRate := 0.0
+ for _, qr := range qaResults {
+ totalF1 += qr.TokenF1
+ totalHitRate += qr.HitRate
+ cat, ok := byCat[qr.Category]
+ if !ok {
+ cat = &CatMetrics{}
+ byCat[qr.Category] = cat
+ }
+ cat.F1 += qr.TokenF1
+ cat.HitRate += qr.HitRate
+ cat.QuestionCount++
+ }
+ n := len(qaResults)
+ if n == 0 {
+ n = 1
+ }
+ agg := AggMetrics{
+ OverallF1: totalF1 / float64(n),
+ OverallHitRate: totalHitRate / float64(n),
+ ByCategory: byCat,
+ TotalQuestions: len(qaResults),
+ }
+ for _, cat := range agg.ByCategory {
+ if cat.QuestionCount > 0 {
+ cat.F1 /= float64(cat.QuestionCount)
+ cat.HitRate /= float64(cat.QuestionCount)
+ }
+ }
+ return agg
+}
+
+// SaveResults writes per-sample eval results to JSON files.
+func SaveResults(results []EvalResult, outDir string) error {
+ if err := os.MkdirAll(outDir, 0o755); err != nil {
+ return fmt.Errorf("create output dir: %w", err)
+ }
+ for _, r := range results {
+ path := filepath.Join(outDir, fmt.Sprintf("eval_%s_%s.json", r.Mode, r.SampleID))
+ data, err := json.MarshalIndent(r, "", " ")
+ if err != nil {
+ return fmt.Errorf("marshal result: %w", err)
+ }
+ if err := os.WriteFile(path, data, 0o644); err != nil {
+ return fmt.Errorf("write result: %w", err)
+ }
+ }
+ return nil
+}
+
+// SaveAggregated writes a combined results.json with all modes.
+func SaveAggregated(results []EvalResult, outDir string) error {
+ byMode := map[string][]EvalResult{}
+ for _, r := range results {
+ byMode[r.Mode] = append(byMode[r.Mode], r)
+ }
+
+ aggMap := map[string]AggMetrics{}
+ for mode, modeResults := range byMode {
+ aggMap[mode] = computeModeAgg(modeResults)
+ }
+
+ data, err := json.MarshalIndent(aggMap, "", " ")
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(filepath.Join(outDir, "results.json"), data, 0o644)
+}
+
+// computeModeAgg aggregates results for a single mode using weighted averaging
+// (weighted by question count per sample). All modes must have the same Mode field.
+func computeModeAgg(results []EvalResult) AggMetrics {
+ agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
+ for _, r := range results {
+ agg.OverallF1 += r.Agg.OverallF1 * float64(r.Agg.TotalQuestions)
+ agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
+ agg.TotalQuestions += r.Agg.TotalQuestions
+ for cat, cm := range r.Agg.ByCategory {
+ existing, ok := agg.ByCategory[cat]
+ if !ok {
+ existing = &CatMetrics{}
+ agg.ByCategory[cat] = existing
+ }
+ existing.F1 += cm.F1 * float64(cm.QuestionCount)
+ existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
+ existing.QuestionCount += cm.QuestionCount
+ }
+ }
+ if agg.TotalQuestions > 0 {
+ agg.OverallF1 /= float64(agg.TotalQuestions)
+ agg.OverallHitRate /= float64(agg.TotalQuestions)
+ }
+ for _, cat := range agg.ByCategory {
+ if cat.QuestionCount > 0 {
+ cat.F1 /= float64(cat.QuestionCount)
+ cat.HitRate /= float64(cat.QuestionCount)
+ }
+ }
+ return agg
+}
+
+// printSection prints a single comparison table section.
+func printSection(title string, results []EvalResult) {
+ fmt.Printf("\n--- %s ---\n", title)
+ byMode := map[string][]EvalResult{}
+ for _, r := range results {
+ byMode[r.Mode] = append(byMode[r.Mode], r)
+ }
+
+ modes := map[string]AggMetrics{}
+ for mode, modeResults := range byMode {
+ modes[mode] = computeModeAgg(modeResults)
+ }
+
+ modeKeys := make([]string, 0, len(modes))
+ for k := range modes {
+ modeKeys = append(modeKeys, k)
+ }
+ sort.Strings(modeKeys)
+
+ // Collect all category keys across modes
+ catSet := map[int]bool{}
+ for _, agg := range modes {
+ for cat := range agg.ByCategory {
+ catSet[cat] = true
+ }
+ }
+ cats := make([]int, 0, len(catSet))
+ for cat := range catSet {
+ cats = append(cats, cat)
+ }
+ sort.Ints(cats)
+
+ fmt.Printf("%-10s %-8s %-8s", "Mode", "HitRate", "F1")
+ for _, cat := range cats {
+ fmt.Printf(" %-7s", fmt.Sprintf("C%d", cat))
+ }
+ fmt.Println()
+ fmt.Println(strings.Repeat("-", 10+8+8+7*len(cats)+8))
+
+ for _, mode := range modeKeys {
+ agg := modes[mode]
+ fmt.Printf("%-10s %-8.4f %-8.4f", mode, agg.OverallHitRate, agg.OverallF1)
+ for _, cat := range cats {
+ if cm, ok := agg.ByCategory[cat]; ok {
+ fmt.Printf(" %-7.4f", cm.HitRate)
+ } else {
+ fmt.Printf(" %-7s", "N/A")
+ }
+ }
+ fmt.Println()
+ }
+}
+
+// PrintComparison outputs a human-readable comparison table to stdout.
+func PrintComparison(results []EvalResult, llmResults []EvalResult) {
+ printSection("No LLM generation", results)
+ if len(llmResults) > 0 {
+ printSection("With LLM", llmResults)
+ }
+}
diff --git a/cmd/membench/eval_test.go b/cmd/membench/eval_test.go
new file mode 100644
index 000000000..d500a38ca
--- /dev/null
+++ b/cmd/membench/eval_test.go
@@ -0,0 +1,104 @@
+package main
+
+import (
+ "math"
+ "testing"
+)
+
+func TestComputeModeAggAllCategories(t *testing.T) {
+ results := []EvalResult{
+ {
+ Mode: "test",
+ SampleID: "s1",
+ QAResults: []QAResult{
+ {Category: 1, TokenF1: 0.5, HitRate: 0.8},
+ {Category: 2, TokenF1: 0.3, HitRate: 0.6},
+ {Category: 3, TokenF1: 0.1, HitRate: 0.4},
+ {Category: 4, TokenF1: 0.7, HitRate: 0.9},
+ {Category: 5, TokenF1: 0.2, HitRate: 0.1},
+ },
+ },
+ }
+ for i := range results {
+ results[i].Agg = aggregateMetrics(results[i].QAResults)
+ }
+
+ got := computeModeAgg(results)
+
+ // Should have all 5 categories
+ for cat := 1; cat <= 5; cat++ {
+ cm, ok := got.ByCategory[cat]
+ if !ok {
+ t.Errorf("ByCategory missing category %d", cat)
+ continue
+ }
+ if cm.QuestionCount != 1 {
+ t.Errorf("ByCategory[%d].QuestionCount = %d, want 1", cat, cm.QuestionCount)
+ }
+ }
+
+ // Verify specific F1 values per category
+ wantF1 := map[int]float64{1: 0.5, 2: 0.3, 3: 0.1, 4: 0.7, 5: 0.2}
+ for cat, want := range wantF1 {
+ if cm, ok := got.ByCategory[cat]; ok {
+ if math.Abs(cm.F1-want) > 1e-9 {
+ t.Errorf("ByCategory[%d].F1 = %.4f, want %.4f", cat, cm.F1, want)
+ }
+ }
+ }
+}
+
+func TestComputeModeAgg(t *testing.T) {
+ // Two samples with different question counts:
+ // sample-a: 2 questions, F1 = [0.4, 0.6] → avg 0.5
+ // sample-b: 8 questions, F1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] → avg 0.1
+ //
+ // Unweighted (PrintComparison bug): (0.5 + 0.1) / 2 = 0.3
+ // Weighted (correct): (0.4+0.6 + 0.1*8) / 10 = 1.8 / 10 = 0.18
+ results := []EvalResult{
+ {
+ Mode: "test",
+ SampleID: "sample-a",
+ QAResults: []QAResult{
+ {TokenF1: 0.4, HitRate: 0.5},
+ {TokenF1: 0.6, HitRate: 0.7},
+ },
+ },
+ {
+ Mode: "test",
+ SampleID: "sample-b",
+ QAResults: []QAResult{
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ {TokenF1: 0.1, HitRate: 0.2},
+ },
+ },
+ }
+ // Compute per-sample aggregates
+ for i := range results {
+ results[i].Agg = aggregateMetrics(results[i].QAResults)
+ }
+
+ got := computeModeAgg(results)
+
+ // Weighted: (0.4+0.6+0.1*8) / 10 = 1.8/10 = 0.18
+ wantF1 := 0.18
+ if math.Abs(got.OverallF1-wantF1) > 1e-9 {
+ t.Errorf("OverallF1 = %.6f, want %.6f (weighted average)", got.OverallF1, wantF1)
+ }
+
+ // Weighted: (0.5+0.7+0.2*8) / 10 = 2.8/10 = 0.28
+ wantRecall := 0.28
+ if math.Abs(got.OverallHitRate-wantRecall) > 1e-9 {
+ t.Errorf("OverallHitRate = %.6f, want %.6f (weighted average)", got.OverallHitRate, wantRecall)
+ }
+
+ if got.TotalQuestions != 10 {
+ t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
+ }
+}
diff --git a/cmd/membench/ingest.go b/cmd/membench/ingest.go
new file mode 100644
index 000000000..70d559c2b
--- /dev/null
+++ b/cmd/membench/ingest.go
@@ -0,0 +1,85 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "log"
+
+ "github.com/sipeed/picoclaw/pkg/seahorse"
+)
+
+// ConvMap stores the mapping from sampleID to seahorse ConversationID.
+type ConvMap map[string]int64
+
+// SeahorseIngestResult holds the results of ingesting into seahorse.
+type SeahorseIngestResult struct {
+ Engine *seahorse.Engine
+ ConvMap ConvMap // sampleID → conversationID
+}
+
+// IngestSeahorse loads all LOCOMO samples into a seahorse Engine.
+// Returns the engine and a mapping from sampleID to conversationID for scoped retrieval.
+func IngestSeahorse(ctx context.Context, samples []LocomoSample, dbPath string) (*SeahorseIngestResult, error) {
+ noopFn := func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
+ return "", nil
+ }
+
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: dbPath,
+ }, noopFn)
+ if err != nil {
+ return nil, fmt.Errorf("create seahorse engine: %w", err)
+ }
+
+ store := engine.GetRetrieval().Store()
+ convMap := make(ConvMap)
+
+ for si := range samples {
+ sample := &samples[si]
+ sessionKey := "locomo-" + sample.SampleID
+
+ // Check if conversation already exists (idempotent)
+ existing, _ := store.GetConversationBySessionKey(ctx, sessionKey)
+ if existing != nil {
+ convMap[sample.SampleID] = existing.ConversationID
+ log.Printf("Skipping existing sample %s: convID=%d", sample.SampleID, existing.ConversationID)
+ continue
+ }
+
+ turns := GetTurns(sample)
+
+ // Convert turns to seahorse messages
+ msgs := make([]seahorse.Message, 0, len(turns))
+ for _, turn := range turns {
+ content := turn.Speaker + ": " + turn.Text
+ msgs = append(msgs, seahorse.Message{
+ Role: "user",
+ Content: content,
+ TokenCount: len(turn.Text) / 4,
+ })
+ }
+
+ // Ingest all turns for this sample
+ _, err := engine.Ingest(ctx, sessionKey, msgs)
+ if err != nil {
+ return nil, fmt.Errorf("ingest sample %s: %w", sample.SampleID, err)
+ }
+
+ // Get the conversation ID for scoped retrieval
+ conv, err := store.GetConversationBySessionKey(ctx, sessionKey)
+ if err != nil {
+ return nil, fmt.Errorf("get conversation for %s: %w", sample.SampleID, err)
+ }
+ if conv == nil {
+ return nil, fmt.Errorf("conversation not found for %s after ingest", sample.SampleID)
+ }
+ convMap[sample.SampleID] = conv.ConversationID
+ log.Printf("Ingested sample %s: %d turns, convID=%d", sample.SampleID, len(turns), conv.ConversationID)
+ }
+
+ log.Printf("Seahorse ingestion complete: %d samples, %d conversations", len(samples), len(convMap))
+ return &SeahorseIngestResult{
+ Engine: engine,
+ ConvMap: convMap,
+ }, nil
+}
diff --git a/cmd/membench/ingest_test.go b/cmd/membench/ingest_test.go
new file mode 100644
index 000000000..e8748deed
--- /dev/null
+++ b/cmd/membench/ingest_test.go
@@ -0,0 +1,79 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/seahorse"
+)
+
+func TestIngestSeahorseIdempotent(t *testing.T) {
+ ctx := context.Background()
+ tmpDir := t.TempDir()
+ dbPath := filepath.Join(tmpDir, "test.db")
+
+ // Minimal test data
+ samples := []LocomoSample{
+ {
+ SampleID: "test-1",
+ Conversation: map[string]json.RawMessage{
+ "session_1": json.RawMessage(`[
+ {"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message"},
+ {"speaker":"B","dia_id":"D1:2","text":"another message for testing purposes"}
+ ]`),
+ },
+ },
+ }
+
+ // First ingestion
+ result1, err := IngestSeahorse(ctx, samples, dbPath)
+ if err != nil {
+ t.Fatalf("first ingest failed: %v", err)
+ }
+ convCount1 := len(result1.ConvMap)
+ result1.Engine.Close()
+
+ // Second ingestion on same DB — should reuse existing data
+ result2, err := IngestSeahorse(ctx, samples, dbPath)
+ if err != nil {
+ t.Fatalf("second ingest failed: %v", err)
+ }
+ defer result2.Engine.Close()
+
+ // ConvMap should have same number of entries (no duplicates)
+ if len(result2.ConvMap) != convCount1 {
+ t.Errorf("second ingest convMap has %d entries, want %d (same as first)",
+ len(result2.ConvMap), convCount1)
+ }
+
+ // Verify conversation IDs are the same (reused, not new ones)
+ for id, cid1 := range result1.ConvMap {
+ cid2, ok := result2.ConvMap[id]
+ if !ok {
+ t.Errorf("sample %s missing from second ConvMap", id)
+ continue
+ }
+ if cid2 != cid1 {
+ t.Errorf("sample %s: second ingest got convID %d, want %d (reused)", id, cid2, cid1)
+ }
+ }
+
+ // Verify no duplicate messages by counting
+ store := result2.Engine.GetRetrieval().Store()
+ for _, convID := range result2.ConvMap {
+ msgs, err := store.SearchMessages(ctx, seahorse.SearchInput{
+ Pattern: "test",
+ ConversationID: convID,
+ Limit: 100,
+ })
+ if err != nil {
+ t.Fatalf("search failed: %v", err)
+ }
+ // Should find exactly 1 message containing "test" (the first turn)
+ if len(msgs) > 2 {
+ t.Errorf("found %d messages for 'test' in conv %d, expected ≤2 (no duplicates)", len(msgs), convID)
+ }
+ }
+}
diff --git a/cmd/membench/legacy_store.go b/cmd/membench/legacy_store.go
new file mode 100644
index 000000000..80cbd2704
--- /dev/null
+++ b/cmd/membench/legacy_store.go
@@ -0,0 +1,34 @@
+package main
+
+import (
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/session"
+)
+
+// LegacyStore wraps session.SessionManager for legacy baseline.
+type LegacyStore struct {
+ sm *session.SessionManager
+}
+
+// NewLegacyStore creates a new in-memory session manager.
+func NewLegacyStore() *LegacyStore {
+ return &LegacyStore{
+ sm: session.NewSessionManager(""),
+ }
+}
+
+// IngestSample loads all turns from a LOCOMO sample into the legacy session store.
+func (ls *LegacyStore) IngestSample(sample *LocomoSample) {
+ sessionKey := "locomo-" + sample.SampleID
+ turns := GetTurns(sample)
+ for _, turn := range turns {
+ content := turn.Speaker + ": " + turn.Text
+ ls.sm.AddMessage(sessionKey, "user", content)
+ }
+}
+
+// GetHistory returns all messages for a sample's session.
+func (ls *LegacyStore) GetHistory(sampleID string) []providers.Message {
+ sessionKey := "locomo-" + sampleID
+ return ls.sm.GetHistory(sessionKey)
+}
diff --git a/cmd/membench/locomo.go b/cmd/membench/locomo.go
new file mode 100644
index 000000000..28ace3680
--- /dev/null
+++ b/cmd/membench/locomo.go
@@ -0,0 +1,142 @@
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+// LocomoSample represents one conversation sample from the LOCOMO dataset.
+type LocomoSample struct {
+ SampleID string `json:"sample_id"`
+ Conversation map[string]json.RawMessage `json:"conversation"`
+ QA []LocomoQA `json:"qa"`
+}
+
+// LocomoTurn represents a single turn in a conversation.
+type LocomoTurn struct {
+ Speaker string `json:"speaker"`
+ DiaID string `json:"dia_id"`
+ Text string `json:"text"`
+}
+
+// LocomoQA represents a question-answer pair with evidence.
+type LocomoQA struct {
+ Question string `json:"question"`
+ Answer json.RawMessage `json:"answer"` // can be string or int (category 1-4)
+ AdversarialAnswer string `json:"adversarial_answer"` // category 5 only
+ Evidence []string `json:"evidence"`
+ Category int `json:"category"` // 1=single-hop, 2=multi-hop, 3=open-ended, 5=adversarial
+}
+
+// AnswerString returns the answer as a string, handling both string and int types.
+func (qa *LocomoQA) AnswerString() string {
+ // Prefer answer field (category 1-4)
+ if len(qa.Answer) > 0 {
+ var s string
+ if err := json.Unmarshal(qa.Answer, &s); err == nil {
+ return s
+ }
+ var n json.Number
+ if err := json.Unmarshal(qa.Answer, &n); err == nil {
+ return n.String()
+ }
+ return strings.Trim(string(qa.Answer), `"`)
+ }
+ // Fallback to adversarial_answer (category 5)
+ return qa.AdversarialAnswer
+}
+
+// LoadDataset reads all JSON files from dataDir and returns parsed samples.
+func LoadDataset(dataDir string) ([]LocomoSample, error) {
+ entries, err := os.ReadDir(dataDir)
+ if err != nil {
+ return nil, fmt.Errorf("read data dir %s: %w", dataDir, err)
+ }
+
+ var samples []LocomoSample
+ for _, entry := range entries {
+ if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".json") {
+ path := filepath.Join(dataDir, entry.Name())
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("read file %s: %w", path, err)
+ }
+ var batch []LocomoSample
+ if err := json.Unmarshal(data, &batch); err != nil {
+ return nil, fmt.Errorf("parse file %s: %w", path, err)
+ }
+ samples = append(samples, batch...)
+ }
+ }
+ return samples, nil
+}
+
+// GetSessionNames returns sorted session keys (session_1, session_2, ...) from conversation.
+func GetSessionNames(conv map[string]json.RawMessage) []string {
+ var names []string
+ for k := range conv {
+ if strings.HasPrefix(k, "session_") && !strings.Contains(k, "_date_time") {
+ names = append(names, k)
+ }
+ }
+ sort.Slice(names, func(i, j int) bool {
+ ni := sessionNum(names[i])
+ nj := sessionNum(names[j])
+ return ni < nj
+ })
+ return names
+}
+
+func sessionNum(key string) int {
+ // "session_1" → 1, "session_10" → 10
+ parts := strings.SplitN(key, "_", 2)
+ if len(parts) < 2 {
+ return 0
+ }
+ n, _ := strconv.Atoi(parts[1])
+ return n
+}
+
+// GetTurns flattens all sessions' turns in chronological order.
+func GetTurns(sample *LocomoSample) []LocomoTurn {
+ names := GetSessionNames(sample.Conversation)
+ var all []LocomoTurn
+ for _, name := range names {
+ raw, ok := sample.Conversation[name]
+ if !ok {
+ continue
+ }
+ var turns []LocomoTurn
+ if err := json.Unmarshal(raw, &turns); err != nil {
+ log.Printf("WARNING: unmarshal failed for session %q in sample %s: %v", name, sample.SampleID, err)
+ continue
+ }
+ all = append(all, turns...)
+ }
+ return all
+}
+
+// GetTurnByDiaID finds a specific turn by dia_id (e.g. "D1:3").
+func GetTurnByDiaID(sample *LocomoSample, diaID string) *LocomoTurn {
+ turns := GetTurns(sample)
+ for i := range turns {
+ if turns[i].DiaID == diaID {
+ return &turns[i]
+ }
+ }
+ return nil
+}
+
+// GetSpeakers returns the two speaker names from conversation metadata.
+func GetSpeakers(conv map[string]json.RawMessage) (string, string) {
+ var a, b string
+ json.Unmarshal(conv["speaker_a"], &a)
+ json.Unmarshal(conv["speaker_b"], &b)
+ return a, b
+}
diff --git a/cmd/membench/locomo_test.go b/cmd/membench/locomo_test.go
new file mode 100644
index 000000000..2d5170bc9
--- /dev/null
+++ b/cmd/membench/locomo_test.go
@@ -0,0 +1,67 @@
+package main
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestAnswerString(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ want string
+ }{
+ {
+ "string answer",
+ `{"question":"Q","answer":"Paris","evidence":[],"category":1}`,
+ "Paris",
+ },
+ {
+ "int answer",
+ `{"question":"Q","answer":42,"evidence":[],"category":1}`,
+ "42",
+ },
+ {
+ "adversarial answer (category 5)",
+ `{"question":"Q","evidence":[],"category":5,"adversarial_answer":"self-care is important"}`,
+ "self-care is important",
+ },
+ {
+ "both answer and adversarial_answer present",
+ `{"question":"Q","answer":"normal","evidence":[],"category":5,"adversarial_answer":"adversarial"}`,
+ "normal",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var qa LocomoQA
+ if err := json.Unmarshal([]byte(tt.json), &qa); err != nil {
+ t.Fatalf("unmarshal: %v", err)
+ }
+ got := qa.AnswerString()
+ if got != tt.want {
+ t.Errorf("AnswerString() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestGetSessionNames(t *testing.T) {
+ conv := map[string]json.RawMessage{
+ "session_2": {},
+ "session_1": {},
+ "session_10": {},
+ "session_1_date_time": {},
+ "speaker_a": {},
+ }
+ names := GetSessionNames(conv)
+ want := []string{"session_1", "session_2", "session_10"}
+ if len(names) != len(want) {
+ t.Fatalf("got %v, want %v", names, want)
+ }
+ for i, n := range names {
+ if n != want[i] {
+ t.Errorf("names[%d] = %q, want %q", i, n, want[i])
+ }
+ }
+}
diff --git a/cmd/membench/main.go b/cmd/membench/main.go
new file mode 100644
index 000000000..0c5a9387a
--- /dev/null
+++ b/cmd/membench/main.go
@@ -0,0 +1,208 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+var (
+ flagData string
+ flagOut string
+ flagMode string
+ flagBudget int
+)
+
+func main() {
+ // Suppress seahorse INFO logs during benchmark
+ logger.SetLevel(logger.WARN)
+
+ rootCmd := &cobra.Command{
+ Use: "membench",
+ Short: "Memory benchmark tool for picoclaw",
+ }
+
+ ingestCmd := &cobra.Command{
+ Use: "ingest",
+ Short: "Load LOCOMO data into storage backends",
+ RunE: runIngest,
+ }
+ ingestCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
+ ingestCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
+ ingestCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to ingest: legacy, seahorse, or all")
+
+ evalCmd := &cobra.Command{
+ Use: "eval",
+ Short: "Run QA evaluation against ingested data",
+ RunE: runEval,
+ }
+ evalCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
+ evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
+ evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
+ evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
+
+ reportCmd := &cobra.Command{
+ Use: "report",
+ Short: "Output comparison results from evaluation",
+ RunE: runReport,
+ }
+ reportCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
+
+ runCmd := &cobra.Command{
+ Use: "run",
+ Short: "Convenience: eval + report (ingestion is done inline)",
+ RunE: runAll,
+ }
+ runCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
+ runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
+ runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
+ runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
+
+ rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
+
+ if err := rootCmd.Execute(); err != nil {
+ os.Exit(1)
+ }
+}
+
+func modesFromFlag() []string {
+ switch strings.ToLower(flagMode) {
+ case "all":
+ return []string{"legacy", "seahorse"}
+ default:
+ return []string{strings.ToLower(flagMode)}
+ }
+}
+
+func runIngest(cmd *cobra.Command, args []string) error {
+ if flagData == "" {
+ return fmt.Errorf("--data is required")
+ }
+ modes := modesFromFlag()
+ if len(modes) == 0 {
+ return nil
+ }
+
+ ctx := context.Background()
+ samples, err := LoadDataset(flagData)
+ if err != nil {
+ return fmt.Errorf("load dataset: %w", err)
+ }
+ log.Printf("Loaded %d samples from %s", len(samples), flagData)
+
+ for _, mode := range modes {
+ switch mode {
+ case "legacy":
+ legacy := NewLegacyStore()
+ for i := range samples {
+ legacy.IngestSample(&samples[i])
+ }
+ log.Printf("legacy: ingested %d samples", len(samples))
+ case "seahorse":
+ dbPath := filepath.Join(flagOut, "seahorse.db")
+ if err := os.MkdirAll(flagOut, 0o755); err != nil {
+ return fmt.Errorf("create out dir: %w", err)
+ }
+ _, err := IngestSeahorse(ctx, samples, dbPath)
+ if err != nil {
+ return fmt.Errorf("ingest seahorse: %w", err)
+ }
+ }
+ }
+ return nil
+}
+
+func runEval(cmd *cobra.Command, args []string) error {
+ if flagData == "" {
+ return fmt.Errorf("--data is required")
+ }
+ modes := modesFromFlag()
+ if len(modes) == 0 {
+ return nil
+ }
+
+ ctx := context.Background()
+ samples, err := LoadDataset(flagData)
+ if err != nil {
+ return fmt.Errorf("load dataset: %w", err)
+ }
+ log.Printf("Loaded %d samples", len(samples))
+
+ var allResults []EvalResult
+
+ for _, mode := range modes {
+ switch mode {
+ case "legacy":
+ legacy := NewLegacyStore()
+ for i := range samples {
+ legacy.IngestSample(&samples[i])
+ }
+ results := EvalLegacy(ctx, samples, legacy, flagBudget)
+ allResults = append(allResults, results...)
+ log.Printf("legacy: evaluated %d samples", len(results))
+ case "seahorse":
+ dbPath := filepath.Join(flagOut, "seahorse.db")
+ ir, err := IngestSeahorse(ctx, samples, dbPath)
+ if err != nil {
+ return fmt.Errorf("ingest seahorse: %w", err)
+ }
+ results := EvalSeahorse(ctx, samples, ir, flagBudget)
+ allResults = append(allResults, results...)
+ log.Printf("seahorse: evaluated %d samples", len(results))
+ }
+ }
+
+ if err := SaveResults(allResults, flagOut); err != nil {
+ return fmt.Errorf("save results: %w", err)
+ }
+ if err := SaveAggregated(allResults, flagOut); err != nil {
+ return fmt.Errorf("save aggregated: %w", err)
+ }
+
+ PrintComparison(allResults, nil)
+ return nil
+}
+
+func runReport(cmd *cobra.Command, args []string) error {
+ entries, err := os.ReadDir(flagOut)
+ if err != nil {
+ return fmt.Errorf("read out dir: %w", err)
+ }
+
+ var allResults []EvalResult
+ for _, entry := range entries {
+ if !entry.IsDir() && strings.HasPrefix(entry.Name(), "eval_") && strings.HasSuffix(entry.Name(), ".json") {
+ path := filepath.Join(flagOut, entry.Name())
+ var r EvalResult
+ data, err := os.ReadFile(path)
+ if err != nil {
+ log.Printf("WARN: read %s: %v", path, err)
+ continue
+ }
+ if err := json.Unmarshal(data, &r); err != nil {
+ log.Printf("WARN: parse %s: %v", path, err)
+ continue
+ }
+ allResults = append(allResults, r)
+ }
+ }
+
+ if len(allResults) == 0 {
+ return fmt.Errorf("no eval results found in %s", flagOut)
+ }
+
+ PrintComparison(allResults, nil)
+ return nil
+}
+
+func runAll(cmd *cobra.Command, args []string) error {
+ return runEval(cmd, args)
+}
diff --git a/cmd/membench/metrics.go b/cmd/membench/metrics.go
new file mode 100644
index 000000000..7e3db2dde
--- /dev/null
+++ b/cmd/membench/metrics.go
@@ -0,0 +1,227 @@
+package main
+
+import (
+ "fmt"
+ "log"
+ "regexp"
+ "strconv"
+ "strings"
+ "unicode"
+)
+
+// diaIDRe matches valid dia_id patterns like "D1:3", "D30:5".
+var diaIDRe = regexp.MustCompile(`^D(\d+):(\d+)$`)
+
+// SplitEvidenceIDs splits an evidence string that may contain multiple
+// semicolon-separated or space-separated dia_ids. Only returns valid IDs.
+// Example: "D8:6; D9:17" → ["D8:6", "D9:17"]
+// Example: "D9:1 D4:4 D4:6" → ["D9:1", "D4:4", "D4:6"]
+func SplitEvidenceIDs(evidence string) []string {
+ if evidence == "" {
+ return nil
+ }
+ // Split on semicolons first, then spaces
+ parts := strings.Split(evidence, ";")
+ var ids []string
+ for _, part := range parts {
+ for _, token := range strings.Fields(strings.TrimSpace(part)) {
+ token = strings.TrimSpace(token)
+ if diaIDRe.MatchString(token) {
+ ids = append(ids, NormalizeDiaID(token))
+ }
+ }
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ return ids
+}
+
+// NormalizeDiaID strips leading zeros from the number parts of a dia_id.
+// "D30:05" → "D30:5", "D10:003" → "D10:3"
+func NormalizeDiaID(id string) string {
+ m := diaIDRe.FindStringSubmatch(id)
+ if m == nil {
+ return id
+ }
+ session, _ := strconv.Atoi(m[1])
+ turn, _ := strconv.Atoi(m[2])
+ return fmt.Sprintf("D%d:%d", session, turn)
+}
+
+// stopwords is a fixed English stopword list for deterministic keyword extraction.
+var stopwords = map[string]struct{}{
+ "a": {}, "an": {}, "the": {},
+ "is": {}, "are": {}, "was": {}, "were": {},
+ "did": {}, "does": {}, "do": {},
+ "when": {}, "where": {}, "what": {}, "who": {},
+ "how": {}, "why": {},
+ "to": {}, "of": {}, "in": {}, "on": {}, "at": {},
+ "for": {}, "and": {}, "or": {}, "but": {}, "not": {},
+ "it": {}, "this": {}, "that": {}, "with": {},
+ "from": {}, "by": {}, "as": {},
+ "if": {}, "then": {}, "than": {}, "so": {},
+ "no": {}, "yes": {},
+ "all": {}, "any": {}, "each": {}, "every": {},
+ "some": {}, "such": {},
+ "about": {}, "into": {}, "over": {},
+ "after": {}, "before": {}, "between": {},
+ "through": {}, "during": {}, "until": {},
+ "would": {}, "could": {}, "should": {},
+ "may": {}, "might": {}, "can": {},
+ "will": {}, "shall": {}, "must": {},
+ "have": {}, "has": {}, "had": {},
+ "been": {}, "being": {}, "be": {},
+ "go": {}, "went": {}, "gone": {},
+ "i": {}, "you": {}, "me": {}, "my": {}, "your": {},
+ "we": {}, "they": {}, "them": {}, "our": {},
+ "its": {}, "their": {}, "he": {}, "she": {},
+ "his": {}, "her": {},
+}
+
+// ExtractKeywords removes stopwords and punctuation, returns individual keywords.
+// Deterministic: uses fixed stopword list, no LLM.
+func ExtractKeywords(question string) []string {
+ // Lowercase and split on whitespace/punctuation
+ lower := strings.ToLower(question)
+ words := strings.FieldsFunc(lower, func(r rune) bool {
+ return !unicode.IsLetter(r) && !unicode.IsDigit(r)
+ })
+
+ var keywords []string
+ for _, w := range words {
+ if w == "" || len(w) < 2 {
+ continue
+ }
+ if _, ok := stopwords[w]; ok {
+ continue
+ }
+ keywords = append(keywords, w)
+ if len(keywords) >= 6 {
+ break
+ }
+ }
+ return keywords
+}
+
+// TokenOverlapF1 computes token-level F1 between prediction and reference.
+// Both strings are lowercased and split on whitespace.
+// NOTE: This metric underestimates quality for multi-hop (cat 2) and
+// open-ended (cat 3) questions where the gold answer uses different phrasing
+// than the source text. LLM-Judge scoring is a v2 follow-up.
+func TokenOverlapF1(prediction, reference string) float64 {
+ predTokens := tokenize(prediction)
+ refTokens := tokenize(reference)
+
+ if len(predTokens) == 0 && len(refTokens) == 0 {
+ return 1.0
+ }
+ if len(predTokens) == 0 || len(refTokens) == 0 {
+ return 0.0
+ }
+
+ // Count matches
+ refCount := map[string]int{}
+ for _, t := range refTokens {
+ refCount[t]++
+ }
+
+ predCount := map[string]int{}
+ for _, t := range predTokens {
+ predCount[t]++
+ }
+
+ var matches float64
+ for token, pc := range predCount {
+ if rc, ok := refCount[token]; ok {
+ matches += float64(min(pc, rc))
+ }
+ }
+
+ precision := matches / float64(len(predTokens))
+ recall := matches / float64(len(refTokens))
+
+ if precision+recall == 0 {
+ return 0.0
+ }
+ return 2 * precision * recall / (precision + recall)
+}
+
+func tokenize(s string) []string {
+ lower := strings.ToLower(s)
+ return strings.Fields(lower)
+}
+
+// RecallHitRate computes fraction of evidence IDs found in retrieved content.
+// For each evidence dia_id, looks up the turn text and checks substring match.
+// Logs a warning for turns with text < 20 chars (higher false-positive risk).
+func RecallHitRate(evidenceIDs []string, sample *LocomoSample, retrievedContent string) float64 {
+ if len(evidenceIDs) == 0 {
+ return 1.0 // no evidence required = perfect
+ }
+
+ // Expand any multi-ID evidence entries (e.g. "D8:6; D9:17" or "D9:1 D4:4")
+ var expanded []string
+ for _, id := range evidenceIDs {
+ split := SplitEvidenceIDs(id)
+ if split != nil {
+ expanded = append(expanded, split...)
+ }
+ }
+ if len(expanded) == 0 {
+ log.Printf("WARNING: no valid dia_ids after expanding evidence %v", evidenceIDs)
+ return float64(0) / float64(len(evidenceIDs))
+ }
+
+ // Build turn index once (avoids re-parsing JSON per ID)
+ turns := GetTurns(sample)
+ turnMap := make(map[string]*LocomoTurn, len(turns))
+ for i := range turns {
+ turnMap[turns[i].DiaID] = &turns[i]
+ }
+
+ lowerRetrieved := strings.ToLower(retrievedContent)
+ found := 0
+ resolvable := 0
+ for _, diaID := range expanded {
+ turn, ok := turnMap[diaID]
+ if !ok {
+ log.Printf("WARNING: dia_id %q not found in sample %s", diaID, sample.SampleID)
+ continue
+ }
+ resolvable++
+ if len(turn.Text) < 20 {
+ log.Printf("WARNING: short turn text (%d chars) for dia_id %s: %q",
+ len(turn.Text), diaID, turn.Text)
+ }
+ if strings.Contains(lowerRetrieved, strings.ToLower(turn.Text)) {
+ found++
+ }
+ }
+ if resolvable == 0 {
+ return 0.0 // no resolvable evidence = can't evaluate
+ }
+ return float64(found) / float64(resolvable)
+}
+
+// BudgetTruncate truncates messages to fit within a token budget.
+// Returns the truncated messages and total token count.
+func BudgetTruncate(messages []string, budgetTokens int) ([]string, int) {
+ var result []string
+ total := 0
+ // Walk from the front (best first) and keep until budget exhausted.
+ for i := 0; i < len(messages); i++ {
+ tokens := len(messages[i]) / 4
+ if total+tokens > budgetTokens && len(result) > 0 {
+ break
+ }
+ result = append(result, messages[i])
+ total += tokens
+ }
+ return result, total
+}
+
+// StringListToContent joins a list of strings into a single content string.
+func StringListToContent(parts []string) string {
+ return strings.Join(parts, "\n")
+}
diff --git a/cmd/membench/metrics_test.go b/cmd/membench/metrics_test.go
new file mode 100644
index 000000000..99e4ad6d4
--- /dev/null
+++ b/cmd/membench/metrics_test.go
@@ -0,0 +1,239 @@
+package main
+
+import (
+ "encoding/json"
+ "math"
+ "testing"
+)
+
+func TestSplitEvidenceIDs(t *testing.T) {
+ tests := []struct {
+ input string
+ want []string
+ }{
+ {"D1:3", []string{"D1:3"}},
+ {"D8:6; D9:17", []string{"D8:6", "D9:17"}},
+ {"D9:1 D4:4 D4:6", []string{"D9:1", "D4:4", "D4:6"}},
+ {"D22:1 D22:2 D9:10 D9:11", []string{"D22:1", "D22:2", "D9:10", "D9:11"}},
+ {"D21:18 D21:22 D11:15 D11:19", []string{"D21:18", "D21:22", "D11:15", "D11:19"}},
+ {"D30:05", []string{"D30:5"}},
+ {"D", nil},
+ {"D:", nil},
+ {"", nil},
+ }
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got := SplitEvidenceIDs(tt.input)
+ if len(got) != len(tt.want) {
+ t.Fatalf("SplitEvidenceIDs(%q) = %v, want %v", tt.input, got, tt.want)
+ }
+ for i := range got {
+ if got[i] != tt.want[i] {
+ t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
+
+func TestNormalizeDiaID(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"D1:3", "D1:3"},
+ {"D30:05", "D30:5"},
+ {"D10:003", "D10:3"},
+ {"D1:0", "D1:0"},
+ }
+ for _, tt := range tests {
+ got := NormalizeDiaID(tt.input)
+ if got != tt.want {
+ t.Errorf("NormalizeDiaID(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ }
+}
+
+func TestTokenOverlapF1(t *testing.T) {
+ tests := []struct {
+ name string
+ prediction string
+ reference string
+ want float64
+ }{
+ {"exact match", "hello world", "hello world", 1.0},
+ {"no overlap", "foo bar", "baz qux", 0.0},
+ {"empty both", "", "", 1.0},
+ {"empty prediction", "", "hello", 0.0},
+ {"empty reference", "hello", "", 0.0},
+ {"partial overlap", "the cat sat on the mat", "the cat on the floor", 8.0 / 11.0},
+ {"case insensitive", "Hello World", "hello world", 1.0},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := TokenOverlapF1(tt.prediction, tt.reference)
+ if math.Abs(got-tt.want) > 1e-9 {
+ t.Errorf("TokenOverlapF1(%q, %q) = %.4f, want %.4f",
+ tt.prediction, tt.reference, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestBudgetTruncate(t *testing.T) {
+ t.Run("within budget returns all", func(t *testing.T) {
+ msgs := []string{"short", "message", "here"}
+ result, total := BudgetTruncate(msgs, 1000)
+ if len(result) != 3 {
+ t.Errorf("expected 3 messages, got %d", len(result))
+ }
+ if total == 0 {
+ t.Error("expected non-zero token count")
+ }
+ })
+
+ t.Run("over budget keeps best first", func(t *testing.T) {
+ msgs := []string{
+ "best message that is quite long and takes up tokens",
+ "good message also fairly long content",
+ "worst short",
+ }
+ result, _ := BudgetTruncate(msgs, 5) // very small budget
+ if len(result) == 0 {
+ t.Fatal("expected at least one message")
+ }
+ // Best-ranked (first) should be kept
+ if result[0] != "best message that is quite long and takes up tokens" {
+ t.Errorf("expected best message kept first, got %q", result[0])
+ }
+ })
+
+ t.Run("over budget keeps best ranked first", func(t *testing.T) {
+ // Messages are sorted by bm25 rank ascending (best/most-negative first).
+ // When budget is insufficient, BudgetTruncate must keep the front
+ // (best-ranked) messages, not the tail (worst-ranked).
+ msgs := []string{
+ "best ranked message with some content here",
+ "second best message also has content",
+ "third message here too",
+ "worst ranked short",
+ }
+ // Budget only fits ~1 message (~10 tokens per message, budget=12)
+ result, _ := BudgetTruncate(msgs, 12)
+ if len(result) == 0 {
+ t.Fatal("expected at least one message")
+ }
+ if result[0] != "best ranked message with some content here" {
+ t.Errorf("expected best-ranked (first) message kept, got %q", result[0])
+ }
+ // Worst-ranked (last) must NOT appear
+ for _, m := range result {
+ if m == "worst ranked short" {
+ t.Error("worst-ranked message should have been truncated")
+ }
+ }
+ })
+
+ t.Run("preserves original order", func(t *testing.T) {
+ msgs := []string{"alpha", "beta", "gamma"}
+ result, _ := BudgetTruncate(msgs, 100)
+ for i, got := range result {
+ if got != msgs[i] {
+ t.Errorf("result[%d] = %q, want %q", i, got, msgs[i])
+ }
+ }
+ })
+
+ t.Run("empty input", func(t *testing.T) {
+ result, total := BudgetTruncate(nil, 100)
+ if len(result) != 0 {
+ t.Errorf("expected 0 messages, got %d", len(result))
+ }
+ if total != 0 {
+ t.Errorf("expected 0 tokens, got %d", total)
+ }
+ })
+}
+
+func TestRecallHitRate(t *testing.T) {
+ // Build a sample with known turns
+ sample := &LocomoSample{
+ SampleID: "test-sample",
+ Conversation: map[string]json.RawMessage{
+ "session_1": json.RawMessage(`[
+ {"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message with enough length"},
+ {"speaker":"B","dia_id":"D1:2","text":"another message for testing recall computation purposes here"},
+ {"speaker":"A","dia_id":"D1:3","text":"third turn with some more content to test"}
+ ]`),
+ },
+ }
+
+ t.Run("all evidence found", func(t *testing.T) {
+ retrieved := "hello world this is a test message with enough length another message for testing recall computation purposes here"
+ got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
+ if math.Abs(got-1.0) > 1e-9 {
+ t.Errorf("RecallHitRate all found = %.4f, want 1.0", got)
+ }
+ })
+
+ t.Run("partial evidence found", func(t *testing.T) {
+ retrieved := "hello world this is a test message with enough length"
+ got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
+ if math.Abs(got-0.5) > 1e-9 {
+ t.Errorf("RecallHitRate partial = %.4f, want 0.5", got)
+ }
+ })
+
+ t.Run("no evidence required", func(t *testing.T) {
+ got := RecallHitRate(nil, sample, "anything")
+ if got != 1.0 {
+ t.Errorf("RecallHitRate no evidence = %.4f, want 1.0", got)
+ }
+ })
+
+ t.Run("missing turn excluded from denominator", func(t *testing.T) {
+ // D1:1 is found, D99:1 does not exist in sample
+ // Should only count resolvable turns in denominator
+ retrieved := "hello world this is a test message with enough length"
+ got := RecallHitRate([]string{"D1:1", "D99:1"}, sample, retrieved)
+ if math.Abs(got-1.0) > 1e-9 {
+ t.Errorf("RecallHitRate missing turn = %.4f, want 1.0 (unresolvable excluded)", got)
+ }
+ })
+}
+
+func TestExtractKeywords(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want []string
+ }{
+ {"simple", "What is the capital of France", []string{"capital", "france"}},
+ {
+ "stops removed",
+ "Who is the president of the United States",
+ []string{"president", "united", "states"},
+ },
+ {
+ "max 6 keywords",
+ "one two three four five six seven eight nine ten",
+ []string{"one", "two", "three", "four", "five", "six"},
+ },
+ {"short words filtered", "I am a go to the store", []string{"am", "store"}},
+ {"empty", "", nil},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ExtractKeywords(tt.input)
+ if len(got) != len(tt.want) {
+ t.Fatalf("ExtractKeywords(%q) = %v (len %d), want %v (len %d)",
+ tt.input, got, len(got), tt.want, len(tt.want))
+ }
+ for i := range got {
+ if got[i] != tt.want[i] {
+ t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go
index bf9c0389f..543577e68 100644
--- a/cmd/picoclaw/main.go
+++ b/cmd/picoclaw/main.go
@@ -9,6 +9,7 @@ package main
import (
"fmt"
"os"
+ "time"
"github.com/spf13/cobra"
@@ -24,10 +25,11 @@ import (
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/status"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal/version"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/updater"
)
func NewPicoclawCommand() *cobra.Command {
- short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
+ short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
cmd := &cobra.Command{
Use: "picoclaw",
@@ -45,6 +47,7 @@ func NewPicoclawCommand() *cobra.Command {
migrate.NewMigrateCommand(),
skills.NewSkillsCommand(),
model.NewModelCommand(),
+ updater.NewUpdateCommand("picoclaw"),
version.NewVersionCommand(),
)
@@ -66,6 +69,21 @@ const (
func main() {
fmt.Printf("%s", banner)
+
+ tz_env := os.Getenv("TZ")
+ if tz_env != "" {
+ fmt.Println("TZ environment:", tz_env)
+ zoneinfo_env := os.Getenv("ZONEINFO")
+ fmt.Println("ZONEINFO environment:", zoneinfo_env)
+ loc, err := time.LoadLocation(tz_env)
+ if err != nil {
+ fmt.Println("Error loading time zone:", err)
+ } else {
+ fmt.Println("Time zone loaded successfully:", loc)
+ time.Local = loc //nolint:gosmopolitan // We intentionally set local timezone from TZ env
+ }
+ }
+
cmd := NewPicoclawCommand()
if err := cmd.Execute(); err != nil {
os.Exit(1)
diff --git a/cmd/picoclaw/main_test.go b/cmd/picoclaw/main_test.go
index ad18cb330..3e147cbfe 100644
--- a/cmd/picoclaw/main_test.go
+++ b/cmd/picoclaw/main_test.go
@@ -17,7 +17,7 @@ func TestNewPicoclawCommand(t *testing.T) {
require.NotNil(t, cmd)
- short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
+ short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
assert.Equal(t, "picoclaw", cmd.Use)
assert.Equal(t, short, cmd.Short)
@@ -43,6 +43,7 @@ func TestNewPicoclawCommand(t *testing.T) {
"onboard",
"skills",
"status",
+ "update",
"version",
}
diff --git a/config/config.example.json b/config/config.example.json
index 814c82503..f0cce6d72 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -48,6 +48,11 @@
"model": "deepseek/deepseek-chat",
"api_key": "sk-your-deepseek-key"
},
+ {
+ "model_name": "venice-uncensored",
+ "model": "venice/venice-uncensored",
+ "api_key": "your-venice-api-key"
+ },
{
"model_name": "lmstudio-local",
"model": "lmstudio/openai/gpt-oss-20b"
@@ -416,7 +421,11 @@
"enabled": true
},
"read_file": {
- "enabled": true
+ "enabled": true,
+ "mode": "bytes"
+ },
+ "send_tts": {
+ "enabled": false
},
"spawn": {
"enabled": true
diff --git a/docker/Dockerfile.goreleaser.launcher b/docker/Dockerfile.goreleaser.launcher
index 5d65576f7..0a20a90b3 100644
--- a/docker/Dockerfile.goreleaser.launcher
+++ b/docker/Dockerfile.goreleaser.launcher
@@ -9,4 +9,4 @@ COPY $TARGETPLATFORM/picoclaw-launcher /usr/local/bin/picoclaw-launcher
COPY $TARGETPLATFORM/picoclaw-launcher-tui /usr/local/bin/picoclaw-launcher-tui
ENTRYPOINT ["picoclaw-launcher"]
-CMD ["-public", "-no-browser"]
+CMD ["-console", "-public", "-no-browser"]
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 0bf46a2ae..7c940621f 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -45,8 +45,11 @@ services:
- launcher
environment:
- PICOCLAW_GATEWAY_HOST=0.0.0.0
+ # Set a fixed dashboard token instead of a random one each restart.
+ # If not set, a random token is generated and printed to the console on startup.
+ #- PICOCLAW_LAUNCHER_TOKEN=your-secret-token-here
ports:
- - "127.0.0.1:18800:18800"
- - "127.0.0.1:18790:18790"
+ - "18800:18800"
+ - "18790:18790"
volumes:
- ./data:/root/.picoclaw
diff --git a/docs/channels/telegram/README.fr.md b/docs/channels/telegram/README.fr.md
index d9ab0644f..17a73ad1c 100644
--- a/docs/channels/telegram/README.fr.md
+++ b/docs/channels/telegram/README.fr.md
@@ -13,18 +13,20 @@ Le canal Telegram utilise le long polling via l'API Bot Telegram pour une commun
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| Champ | Type | Requis | Description |
-| ---------- | ------ | ------ | ------------------------------------------------------------------------ |
-| enabled | bool | Oui | Activer ou non le canal Telegram |
-| token | string | Oui | Token de l'API Bot Telegram |
-| allow_from | array | Non | Liste blanche d'identifiants utilisateur ; vide signifie tous les utilisateurs |
-| proxy | string | Non | URL du proxy pour se connecter à l'API Telegram (ex. http://127.0.0.1:7890) |
+| Champ | Type | Requis | Description |
+| --------------- | ------ | ------ | ------------------------------------------------------------------------ |
+| enabled | bool | Oui | Activer ou non le canal Telegram |
+| token | string | Oui | Token de l'API Bot Telegram |
+| allow_from | array | Non | Liste blanche d'identifiants utilisateur ; vide signifie tous les utilisateurs |
+| proxy | string | Non | URL du proxy pour se connecter à l'API Telegram (ex. http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | Non | Activer le formatage Telegram MarkdownV2 |
## Configuration initiale
@@ -33,3 +35,20 @@ Le canal Telegram utilise le long polling via l'API Bot Telegram pour une commun
3. Obtenir le Token de l'API HTTP
4. Renseigner le Token dans le fichier de configuration
5. (Optionnel) Configurer `allow_from` pour restreindre les identifiants utilisateur autorisés à interagir (les IDs peuvent être obtenus via `@userinfobot`)
+
+## Formatage avancées
+
+Vous pouvez définir `use_markdown_v2: true` pour activer les options de formatage améliorées. Cela permet au bot d'utiliser toutes les fonctionnalités de Telegram MarkdownV2, y compris les styles imbriqués, les spoilers et les blocs de largeur fixe personnalisés.
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
+```
diff --git a/docs/channels/telegram/README.ja.md b/docs/channels/telegram/README.ja.md
index 03c48cb64..09209cc3c 100644
--- a/docs/channels/telegram/README.ja.md
+++ b/docs/channels/telegram/README.ja.md
@@ -13,18 +13,20 @@ Telegram チャンネルは、Telegram Bot API を使用したロングポーリ
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| フィールド | 型 | 必須 | 説明 |
-| ---------- | ------ | ---- | ----------------------------------------------------------------- |
-| enabled | bool | はい | Telegram チャンネルを有効にするかどうか |
-| token | string | はい | Telegram Bot API トークン |
-| allow_from | array | いいえ | 許可するユーザーIDのリスト。空の場合はすべてのユーザーを許可 |
-| proxy | string | いいえ | Telegram API への接続に使用するプロキシ URL (例: http://127.0.0.1:7890) |
+| フィールド | 型 | 必須 | 説明 |
+| --------------- | ------ | ---- | ----------------------------------------------------------------- |
+| enabled | bool | はい | Telegram チャンネルを有効にするかどうか |
+| token | string | はい | Telegram Bot API トークン |
+| allow_from | array | いいえ | 許可するユーザーIDのリスト。空の場合はすべてのユーザーを許可 |
+| proxy | string | いいえ | Telegram API への接続に使用するプロキシ URL (例: http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | いいえ | Telegram MarkdownV2 フォーマットを有効にする |
## セットアップ手順
@@ -33,3 +35,20 @@ Telegram チャンネルは、Telegram Bot API を使用したロングポーリ
3. HTTP API トークンを取得する
4. 設定ファイルにトークンを入力する
5. (任意) `allow_from` を設定して、対話を許可するユーザー ID を制限する(ID は `@userinfobot` で取得可能)
+
+## 高度なフォーマット
+
+`use_markdown_v2: true` を設定することで、增强されたフォーマットオプションを有効にできます。これにより、ボットは Telegram MarkdownV2 の全機能(ネストされたスタイル、スポイラー、カスタム固定幅ブロックなど)を利用できます。
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
+```
diff --git a/docs/channels/telegram/README.md b/docs/channels/telegram/README.md
index 86c016a5d..78368f5d2 100644
--- a/docs/channels/telegram/README.md
+++ b/docs/channels/telegram/README.md
@@ -13,18 +13,20 @@ The Telegram channel uses long polling via the Telegram Bot API for bot-based co
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| Field | Type | Required | Description |
-| ---------- | ------ | -------- | ------------------------------------------------------------------ |
-| enabled | bool | Yes | Whether to enable the Telegram channel |
-| token | string | Yes | Telegram Bot API Token |
-| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed |
-| proxy | string | No | Proxy URL for connecting to the Telegram API (e.g. http://127.0.0.1:7890) |
+| Field | Type | Required | Description |
+| ---------------- | ------ | -------- | ------------------------------------------------------------------ |
+| enabled | bool | Yes | Whether to enable the Telegram channel |
+| token | string | Yes | Telegram Bot API Token |
+| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed |
+| proxy | string | No | Proxy URL for connecting to the Telegram API (e.g. http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | No | Enable Telegram MarkdownV2 formatting |
## Setup
@@ -53,3 +55,20 @@ Examples:
/use git
explain how to squash the last 3 commits
```
+
+## Advanced Formatting
+
+You can set `use_markdown_v2: true` to enable enhanced formatting options. This allows the bot to utilize the full range of Telegram MarkdownV2 features, including nested styles, spoilers, and custom fixed-width blocks.
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
+```
diff --git a/docs/channels/telegram/README.pt-br.md b/docs/channels/telegram/README.pt-br.md
index 8d2c935b4..e86d51d8e 100644
--- a/docs/channels/telegram/README.pt-br.md
+++ b/docs/channels/telegram/README.pt-br.md
@@ -13,18 +13,20 @@ O canal Telegram utiliza long polling via a API de Bot do Telegram para comunica
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| Campo | Tipo | Obrigatório | Descrição |
-| ---------- | ------ | ----------- | -------------------------------------------------------------------------- |
-| enabled | bool | Sim | Se o canal Telegram deve ser habilitado |
-| token | string | Sim | Token da API de Bot do Telegram |
-| allow_from | array | Não | Lista de IDs de usuários permitidos; vazio significa todos os usuários |
-| proxy | string | Não | URL do proxy para conexão com a API do Telegram (ex. http://127.0.0.1:7890) |
+| Campo | Tipo | Obrigatório | Descrição |
+| --------------- | ------ | ----------- | -------------------------------------------------------------------------- |
+| enabled | bool | Sim | Se o canal Telegram deve ser habilitado |
+| token | string | Sim | Token da API de Bot do Telegram |
+| allow_from | array | Não | Lista de IDs de usuários permitidos; vazio significa todos os usuários |
+| proxy | string | Não | URL do proxy para conexão com a API do Telegram (ex. http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | Não | Habilitar formatação Telegram MarkdownV2 |
## Configuração inicial
@@ -33,3 +35,20 @@ O canal Telegram utiliza long polling via a API de Bot do Telegram para comunica
3. Obtenha o Token da API HTTP
4. Preencha o Token no arquivo de configuração
5. (Opcional) Configure `allow_from` para restringir quais IDs de usuário podem interagir (os IDs podem ser obtidos via `@userinfobot`)
+
+## Formatação Avançada
+
+Você pode definir `use_markdown_v2: true` para habilitar opções de formatação aprimoradas. Isso permite que o bot utilize todos os recursos do Telegram MarkdownV2, incluindo estilos aninhados, spoilers e blocos de largura fixa personalizados.
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
+```
diff --git a/docs/channels/telegram/README.vi.md b/docs/channels/telegram/README.vi.md
index 858a9fc41..70ee1f51b 100644
--- a/docs/channels/telegram/README.vi.md
+++ b/docs/channels/telegram/README.vi.md
@@ -13,18 +13,20 @@ Kênh Telegram sử dụng long polling qua Telegram Bot API để giao tiếp d
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| Trường | Kiểu | Bắt buộc | Mô tả |
-| ---------- | ------ | -------- | ------------------------------------------------------------------------ |
-| enabled | bool | Có | Có bật kênh Telegram hay không |
-| token | string | Có | Token API Bot Telegram |
-| allow_from | array | Không | Danh sách trắng ID người dùng; để trống nghĩa là cho phép tất cả |
-| proxy | string | Không | URL proxy để kết nối với Telegram API (ví dụ: http://127.0.0.1:7890) |
+| Trường | Kiểu | Bắt buộc | Mô tả |
+| -------------- | ------ | -------- | ------------------------------------------------------------------------ |
+| enabled | bool | Có | Có bật kênh Telegram hay không |
+| token | string | Có | Token API Bot Telegram |
+| allow_from | array | Không | Danh sách trắng ID người dùng; để trống nghĩa là cho phép tất cả |
+| proxy | string | Không | URL proxy để kết nối với Telegram API (ví dụ: http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | Không | Bật định dạng Telegram MarkdownV2 |
## Hướng dẫn thiết lập
@@ -33,3 +35,20 @@ Kênh Telegram sử dụng long polling qua Telegram Bot API để giao tiếp d
3. Lấy Token API HTTP
4. Điền Token vào file cấu hình
5. (Tùy chọn) Cấu hình `allow_from` để giới hạn ID người dùng được phép tương tác (có thể lấy ID qua `@userinfobot`)
+
+## Định dạng nâng cao
+
+Bạn có thể đặt `use_markdown_v2: true` để bật các tùy chọn định dạng nâng cao. Điều này cho phép bot sử dụng toàn bộ các tính năng của Telegram MarkdownV2, bao gồm các kiểu lồng nhau, spoiler và các khối chiều rộng cố định tùy chỉnh.
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
+```
diff --git a/docs/channels/telegram/README.zh.md b/docs/channels/telegram/README.zh.md
index 1d9dcc46e..fc544cd86 100644
--- a/docs/channels/telegram/README.zh.md
+++ b/docs/channels/telegram/README.zh.md
@@ -13,18 +13,20 @@ Telegram Channel 通过 Telegram 机器人 API 使用长轮询实现基于机器
"enabled": true,
"token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
"allow_from": ["123456789"],
- "proxy": ""
+ "proxy": "",
+ "use_markdown_v2": false
}
}
}
```
-| 字段 | 类型 | 必填 | 描述 |
-| ---------- | ------ | ---- | --------------------------------------------------------- |
-| enabled | bool | 是 | 是否启用 Telegram 频道 |
-| token | string | 是 | Telegram 机器人 API Token |
-| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
-| proxy | string | 否 | 连接 Telegram API 的代理 URL (例如 http://127.0.0.1:7890) |
+| 字段 | 类型 | 必填 | 描述 |
+| ---------------- | ------ | ---- | --------------------------------------------------------- |
+| enabled | bool | 是 | 是否启用 Telegram 频道 |
+| token | string | 是 | Telegram 机器人 API Token |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+| proxy | string | 否 | 连接 Telegram API 的代理 URL (例如 http://127.0.0.1:7890) |
+| use_markdown_v2 | bool | 否 | 启用 Telegram MarkdownV2 格式化 |
## 设置流程
@@ -50,6 +52,23 @@ Telegram 会在启动时自动注册 PicoClaw 的顶级 Bot 命令,包括 `/st
```text
/list skills
/use git explain how to squash the last 3 commits
-/use italiapersonalfinance
-dammi le ultime news
+/use git
+explain how to squash the last 3 commits
+```
+
+## 高级格式化
+
+您可以设置 `use_markdown_v2: true` 来启用增强的格式化选项。这允许机器人使用 Telegram MarkdownV2 的全部功能,包括嵌套样式、剧透和自定义等宽代码块。
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "use_markdown_v2": true
+ }
+ }
+}
```
diff --git a/docs/channels/vk/README.md b/docs/channels/vk/README.md
new file mode 100644
index 000000000..bfff084e6
--- /dev/null
+++ b/docs/channels/vk/README.md
@@ -0,0 +1,194 @@
+# VK (VKontakte)
+
+The VK channel uses Bots Long Poll API for bot-based communication with VK social network. It supports text messages, media attachments (photos, videos, audio, documents, stickers), and group chat interactions.
+
+## Configuration
+
+```json
+{
+ "channels": {
+ "vk": {
+ "enabled": true,
+ "token": "NOT_HERE",
+ "group_id": 123456789,
+ "allow_from": ["123456789"],
+ "group_trigger": {
+ "mention_only": false,
+ "prefixes": ["/bot", "!bot"]
+ }
+ }
+ }
+}
+```
+
+| Field | Type | Required | Description |
+| ---------------- | ------ | -------- | ------------------------------------------------------------------ |
+| enabled | bool | Yes | Whether to enable the VK channel |
+| token | string | Yes | Set to `NOT_HERE` - token is stored securely (see Token Storage) |
+| group_id | int | Yes | VK Community ID (Group ID) |
+| allow_from | array | No | Allowlist of user IDs; empty means all users are allowed |
+| group_trigger | object | No | Configuration for group chat triggers |
+
+### Token Storage
+
+For security reasons, the VK access token should not be stored directly in the configuration file. Instead:
+
+1. Set `token` to `"NOT_HERE"` in the configuration
+2. Store the actual token using one of these methods:
+ - **Environment variable**: Set `PICOCLAW_CHANNELS_VK_TOKEN` environment variable
+ - **Secure storage**: Use PicoClaw's secure token storage mechanism
+
+Example using environment variable:
+```bash
+export PICOCLAW_CHANNELS_VK_TOKEN="vk1.a.abc123..."
+```
+
+### Group Trigger Configuration
+
+| Field | Type | Description |
+| ------------ | -------- | ------------------------------------------------------------------ |
+| mention_only | bool | Only respond when bot is mentioned in group chats |
+| prefixes | []string | List of prefixes that trigger bot response in group chats |
+
+## Setup
+
+### 1. Create a VK Community
+
+1. Go to [VK](https://vk.com) and log in
+2. Create a new community or use an existing one
+3. Note your Community ID (found in the community URL, e.g., `public123456789`)
+
+### 2. Enable Messages
+
+1. Go to your community page
+2. Click "Manage" → "Messages" → "Community Messages"
+3. Enable community messages
+
+### 3. Create Access Token
+
+1. Go to "Manage" → "API usage" → "Access tokens"
+2. Click "Create token"
+3. Select the following permissions:
+ - `messages` - Access to messages
+ - `photos` - Access to photos (optional)
+ - `docs` - Access to documents (optional)
+4. Copy the generated access token
+5. Store the token securely (see Token Storage section below)
+
+### 4. Configure PicoClaw
+
+1. Add the token to your PicoClaw configuration
+2. Set the `group_id` to your community ID (numeric value)
+3. (Optional) Configure `allow_from` to restrict which user IDs can interact
+
+## Features
+
+### Supported Message Types
+
+- **Text messages**: Full support for text messages
+- **Photos**: Photos are displayed as `[photo]` placeholder
+- **Videos**: Videos are displayed as `[video]` placeholder
+- **Audio**: Audio files are displayed as `[audio]` placeholder
+- **Voice messages**: Voice messages are displayed as `[voice]` placeholder and support transcription
+- **Documents**: Documents are displayed as `[document: filename]`
+- **Stickers**: Stickers are displayed as `[sticker]` placeholder
+
+### Voice Support
+
+The VK channel supports both voice message reception and text-to-speech capabilities:
+
+- **ASR (Automatic Speech Recognition)**: Voice messages can be transcribed to text using configured voice models
+- **TTS (Text-to-Speech)**: Text responses can be converted to voice messages
+
+To enable voice transcription, configure a voice model in your providers setup. See [Voice Transcription](../../providers.md#voice-transcription) for details.
+
+### Group Chat Support
+
+The VK channel supports group chats with configurable triggers:
+
+- **Mention-only mode**: Bot only responds when mentioned
+- **Prefix mode**: Bot responds to messages starting with specified prefixes
+- **Permissive mode**: Bot responds to all messages (default)
+
+### Message Length
+
+VK has a maximum message length of 4000 characters. PicoClaw automatically splits longer messages into multiple parts.
+
+## Example Configuration
+
+### Basic Configuration
+
+```json
+{
+ "channels": {
+ "vk": {
+ "enabled": true,
+ "token": "NOT_HERE",
+ "group_id": 123456789
+ }
+ }
+}
+```
+
+### With User Whitelist
+
+```json
+{
+ "channels": {
+ "vk": {
+ "enabled": true,
+ "token": "NOT_HERE",
+ "group_id": 123456789,
+ "allow_from": ["123456789", "987654321"]
+ }
+ }
+}
+```
+
+### With Group Chat Triggers
+
+```json
+{
+ "channels": {
+ "vk": {
+ "enabled": true,
+ "token": "NOT_HERE",
+ "group_id": 123456789,
+ "group_trigger": {
+ "prefixes": ["/bot", "!bot"]
+ }
+ }
+ }
+}
+```
+
+## Troubleshooting
+
+### Bot Not Responding
+
+1. Check that the access token is valid
+2. Verify that the `group_id` is correct
+3. Ensure the user ID is in `allow_from` if configured
+4. Check PicoClaw logs for error messages
+
+### Permission Errors
+
+Make sure the access token has the necessary permissions:
+- `messages` - Required for sending and receiving messages
+- `photos` - Optional, for handling photo attachments
+- `docs` - Optional, for handling document attachments
+
+### Group Chat Issues
+
+If the bot doesn't respond in group chats:
+1. Check `group_trigger` configuration
+2. Try using a prefix to trigger the bot
+3. Check if the bot has permission to read group messages
+
+## API Reference
+
+The VK channel uses the [VK SDK for Go](https://github.com/SevereCloud/vksdk) library, which supports VK API version 5.199.
+
+For more information about VK API, see:
+- [VK API Documentation](https://dev.vk.com/en)
+- [VK Bots Long Poll API](https://dev.vk.com/en/api/bots-long-poll/getting-started)
diff --git a/docs/configuration.md b/docs/configuration.md
index 363b59690..6c5a9f776 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -246,6 +246,66 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous
| `tools.allow_read_paths` | string[] | `[]` | Additional paths allowed for reading outside workspace |
| `tools.allow_write_paths` | string[] | `[]` | Additional paths allowed for writing outside workspace |
+### Read File Mode
+
+`read_file` has two mutually exclusive implementations selected by config. PicoClaw registers exactly one of them at startup:
+
+| Config Key | Type | Default | Description |
+|------------|------|---------|-------------|
+| `tools.read_file.enabled` | bool | `true` | Enables the `read_file` tool |
+| `tools.read_file.mode` | string | `bytes` | Selects the `read_file` implementation: `bytes` or `lines` |
+| `tools.read_file.max_read_file_size` | int | `65536` | Maximum bytes returned by `read_file` |
+
+#### Mode: `bytes`
+
+Optimized for arbitrary files and binary-safe pagination.
+
+Parameters:
+
+* `path` (required): File path
+* `offset` (optional): Starting byte offset, default `0`
+* `length` (optional): Maximum number of bytes to read, default `max_read_file_size`
+
+Use `bytes` when:
+
+* You may read binary files
+* You want deterministic byte-range pagination
+
+#### Mode: `lines`
+
+Text-oriented behavior, optimized for source files, markdown, logs, and configs. The tool reads sequentially by line and stops when the configured byte budget is reached.
+
+Parameters:
+
+* `path` (required): File path
+* `start_line` (optional): Starting line number, 1-indexed and inclusive, default `1`
+* `max_lines` (optional): Maximum number of lines to read, default = all remaining lines until EOF or byte budget
+
+Behavior notes:
+
+* Binary-looking files are rejected with guidance to switch `read_file` to `mode = bytes`
+* Extremely long single lines are truncated rather than skipped
+
+Use `mode = lines` when:
+
+* The agent mostly reads text files
+* You want line-based pagination in prompts and tool calls
+* You want cleaner chunks for code review, logs, and documentation
+
+#### Example
+
+```json
+{
+ "tools": {
+ "read_file": {
+ "enabled": true,
+ "mode": "lines",
+ "max_read_file_size": 65536
+ }
+ }
+}
+```
+
### Exec Security
| Config Key | Type | Default | Description |
diff --git a/docs/fr/providers.md b/docs/fr/providers.md
index d0da81897..3305ec5ee 100644
--- a/docs/fr/providers.md
+++ b/docs/fr/providers.md
@@ -99,6 +99,24 @@ Cette conception permet également le **support multi-agents** avec une sélecti
}
```
+#### Champs d'entrée `model_list`
+
+| Champ | Type | Requis | Description |
+|-------|------|--------|-------------|
+| `model_name` | string | Oui | Nom unique pour référencer ce modèle dans la config agent |
+| `model` | string | Oui | Identifiant fournisseur/modèle (ex : `openai/gpt-5.4`, `azure/gpt-5.4`, `anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | Oui* | Clé(s) API pour l'authentification. Plusieurs clés permettent la rotation par requête. Non requis pour les fournisseurs locaux (Ollama, LM Studio, VLLM) |
+| `api_base` | string | Non | Remplace l'URL de base API par défaut |
+| `proxy` | string | Non | URL du proxy HTTP pour cette entrée de modèle |
+| `user_agent` | string | Non | En-tête `User-Agent` personnalisé pour les requêtes API (supporté par les providers OpenAI-compatible, Anthropic et Azure) |
+| `request_timeout` | int | Non | Délai d'expiration de la requête en secondes (la valeur par défaut varie selon le provider) |
+| `max_tokens_field` | string | Non | Remplace le nom du champ max tokens dans le corps de la requête (ex : `max_completion_tokens` pour les modèles o1) |
+| `thinking_level` | string | Non | Niveau de pensée étendue : `off`, `low`, `medium`, `high`, `xhigh` ou `adaptive` |
+| `extra_body` | object | Non | Champs supplémentaires à injecter dans chaque corps de requête |
+| `rpm` | int | Non | Limite de requêtes par minute |
+| `fallbacks` | string[] | Non | Noms des modèles de secours pour le basculement automatique |
+| `enabled` | bool | Non | Activer ou désactiver cette entrée de modèle (par défaut : `true`) |
+
#### Exemples par Vendor
**OpenAI**
@@ -190,6 +208,7 @@ Pour l'accès direct à l'API Anthropic ou les endpoints personnalisés qui ne p
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/docs/it/configuration.md b/docs/it/configuration.md
deleted file mode 100644
index 6a79a9543..000000000
--- a/docs/it/configuration.md
+++ /dev/null
@@ -1,219 +0,0 @@
-# ⚙️ Guida alla Configurazione
-
-> Torna al [README](../../README.md)
-
-## ⚙️ Configurazione
-
-File di configurazione: `~/.picoclaw/config.json`
-
-### Variabili d'Ambiente
-
-Puoi sovrascrivere i percorsi predefiniti usando variabili d'ambiente. Questo è utile per installazioni portatili, distribuzioni containerizzate, o per eseguire picoclaw come servizio di sistema. Queste variabili sono indipendenti e controllano percorsi diversi.
-
-| Variabile | Descrizione | Percorso Predefinito |
-|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
-| `PICOCLAW_CONFIG` | Sovrascrive il percorso al file di configurazione. Indica direttamente a picoclaw quale `config.json` caricare, ignorando tutte le altre posizioni. | `~/.picoclaw/config.json` |
-| `PICOCLAW_HOME` | Sovrascrive la directory radice per i dati di picoclaw. Modifica la posizione predefinita del `workspace` e delle altre directory dati. | `~/.picoclaw` |
-
-**Esempi:**
-
-```bash
-# Esegui picoclaw usando un file di configurazione specifico
-# Il percorso del workspace verrà letto da quel file di configurazione
-PICOCLAW_CONFIG=/etc/picoclaw/production.json picoclaw gateway
-
-# Esegui picoclaw con tutti i dati salvati in /opt/picoclaw
-# La configurazione verrà caricata dal percorso predefinito ~/.picoclaw/config.json
-# Il workspace verrà creato in /opt/picoclaw/workspace
-PICOCLAW_HOME=/opt/picoclaw picoclaw agent
-
-# Usa entrambi per un setup completamente personalizzato
-PICOCLAW_HOME=/srv/picoclaw PICOCLAW_CONFIG=/srv/picoclaw/main.json picoclaw gateway
-```
-
-### Struttura del Workspace
-
-PicoClaw salva i dati nel workspace configurato (predefinito: `~/.picoclaw/workspace`):
-
-```
-~/.picoclaw/workspace/
-├── sessions/ # Sessioni di conversazione e cronologia
-├── memory/ # Memoria a lungo termine (MEMORY.md)
-├── state/ # Stato persistente (ultimo canale, ecc.)
-├── cron/ # Database dei job pianificati
-├── skills/ # Skill personalizzate
-├── AGENTS.md # Guida al comportamento dell'agent
-├── HEARTBEAT.md # Prompt per task periodici (controllato ogni 30 min)
-├── IDENTITY.md # Identità dell'agent
-├── SOUL.md # Anima dell'agent
-└── USER.md # Preferenze dell'utente
-```
-
-> **Nota:** Le modifiche a `AGENTS.md`, `SOUL.md`, `USER.md`, `IDENTITY.md` e `memory/MEMORY.md` vengono rilevate automaticamente a runtime tramite il tracciamento della data di modifica (mtime). **Non è necessario riavviare il gateway** dopo aver modificato questi file — l'agent caricherà il nuovo contenuto alla prossima richiesta.
-
-### Sorgenti delle Skill
-
-Per impostazione predefinita, le skill vengono caricate da:
-
-1. `~/.picoclaw/workspace/skills` (workspace)
-2. `~/.picoclaw/skills` (globale)
-3. `/skills` (builtin)
-
-Per configurazioni avanzate/di test, puoi sovrascrivere la directory radice delle skill builtin con:
-
-```bash
-export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
-```
-
-### Politica Unificata di Esecuzione dei Comandi
-
-- I comandi slash generici vengono eseguiti tramite un unico percorso in `pkg/agent/loop.go` via `commands.Executor`.
-- Gli adattatori dei canali non consumano più localmente i comandi generici; inoltrano il testo in entrata al percorso bus/agent. Telegram registra ancora automaticamente i comandi supportati all'avvio.
-- Un comando slash sconosciuto (ad esempio `/foo`) viene passato all'elaborazione LLM come se fosse un messaggio dell'utente.
-- Un comando registrato ma non supportato sul canale corrente (ad esempio `/show` su WhatsApp) restituisce un errore esplicito all'utente e interrompe l'elaborazione.
-
-### 🔒 Sandbox di Sicurezza
-
-PicoClaw esegue in un ambiente sandboxed per impostazione predefinita. L'agent può accedere solo ai file ed eseguire comandi all'interno del workspace configurato.
-
-#### Configurazione Predefinita
-
-```json
-{
- "agents": {
- "defaults": {
- "workspace": "~/.picoclaw/workspace",
- "restrict_to_workspace": true
- }
- }
-}
-```
-
-| Opzione | Predefinito | Descrizione |
-| ----------------------- | ----------------------- | ---------------------------------------------------- |
-| `workspace` | `~/.picoclaw/workspace` | Directory di lavoro dell'agent |
-| `restrict_to_workspace` | `true` | Limita l'accesso a file/comandi al workspace |
-
-#### Strumenti Protetti
-
-Quando `restrict_to_workspace: true`, i seguenti strumenti sono in sandbox:
-
-| Strumento | Funzione | Restrizione |
-| ------------- | ------------------------- | ---------------------------------------------------- |
-| `read_file` | Legge file | Solo file all'interno del workspace |
-| `write_file` | Scrive file | Solo file all'interno del workspace |
-| `list_dir` | Elenca directory | Solo directory all'interno del workspace |
-| `edit_file` | Modifica file | Solo file all'interno del workspace |
-| `append_file` | Aggiunge ai file | Solo file all'interno del workspace |
-| `exec` | Esegue comandi | I percorsi dei comandi devono essere nel workspace |
-
-#### Protezione Exec Aggiuntiva
-
-Anche con `restrict_to_workspace: false`, lo strumento `exec` blocca questi comandi pericolosi:
-
-* `rm -rf`, `del /f`, `rmdir /s` — Cancellazione di massa
-* `format`, `mkfs`, `diskpart` — Formattazione del disco
-* `dd if=` — Imaging del disco
-* Scrittura su `/dev/sd[a-z]` — Scritture dirette su disco
-* `shutdown`, `reboot`, `poweroff` — Spegnimento del sistema
-* Fork bomb `:(){ :|:& };:`
-
-### Controllo Accesso ai File
-
-| Chiave di configurazione | Tipo | Predefinito | Descrizione |
-|--------------------------|------|-------------|-------------|
-| `tools.allow_read_paths` | string[] | `[]` | Percorsi aggiuntivi consentiti per la lettura al di fuori del workspace |
-| `tools.allow_write_paths` | string[] | `[]` | Percorsi aggiuntivi consentiti per la scrittura al di fuori del workspace |
-
-### Sicurezza Exec
-
-| Chiave di configurazione | Tipo | Predefinito | Descrizione |
-|--------------------------|------|-------------|-------------|
-| `tools.exec.allow_remote` | bool | `false` | Consente lo strumento exec da canali remoti (Telegram/Discord ecc.) |
-| `tools.exec.enable_deny_patterns` | bool | `true` | Abilita l'intercettazione dei comandi pericolosi |
-| `tools.exec.custom_deny_patterns` | string[] | `[]` | Pattern regex personalizzati da bloccare |
-| `tools.exec.custom_allow_patterns` | string[] | `[]` | Pattern regex personalizzati da consentire |
-
-> **Nota di sicurezza:** La protezione dei symlink è abilitata per impostazione predefinita — tutti i percorsi file vengono risolti tramite `filepath.EvalSymlinks` prima del confronto con la whitelist, prevenendo attacchi di escape tramite symlink.
-
-#### Limitazione Nota: Processi Figlio degli Strumenti di Build
-
-Il controllo di sicurezza exec ispeziona solo la riga di comando avviata direttamente da PicoClaw. Non ispeziona ricorsivamente i processi figlio generati da strumenti di sviluppo consentiti come `make`, `go run`, `cargo`, `npm run` o script di build personalizzati.
-
-Ciò significa che un comando di primo livello può comunque compilare o avviare altri binari dopo aver superato il controllo iniziale. In pratica, tratta gli script di build, i Makefile, gli script di pacchetti e i binari generati come codice eseguibile che richiede lo stesso livello di revisione di un comando shell diretto.
-
-Per ambienti ad alto rischio:
-
-* Esamina gli script di build prima dell'esecuzione.
-* Preferisci l'approvazione/revisione manuale per i workflow di compilazione ed esecuzione.
-* Esegui PicoClaw in un container o VM se hai bisogno di un isolamento più forte di quello fornito dal controllo integrato.
-
-#### Esempi di Errore
-
-```
-[ERROR] tool: Tool execution failed
-{tool=exec, error=Command blocked by safety guard (path outside working dir)}
-```
-
-```
-[ERROR] tool: Tool execution failed
-{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
-```
-
-#### Disabilitare le Restrizioni (Rischio di Sicurezza)
-
-Se hai bisogno che l'agent acceda a percorsi al di fuori del workspace:
-
-**Metodo 1: File di configurazione**
-
-```json
-{
- "agents": {
- "defaults": {
- "restrict_to_workspace": false
- }
- }
-}
-```
-
-**Metodo 2: Variabile d'ambiente**
-
-```bash
-export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
-```
-
-> ⚠️ **Attenzione**: Disabilitare questa restrizione consente all'agent di accedere a qualsiasi percorso sul tuo sistema. Usare con cautela solo in ambienti controllati.
-
-#### Coerenza dei Confini di Sicurezza
-
-L'impostazione `restrict_to_workspace` si applica in modo coerente a tutti i percorsi di esecuzione:
-
-| Percorso di esecuzione | Confine di sicurezza |
-| ---------------------- | --------------------------------- |
-| Main Agent | `restrict_to_workspace` ✅ |
-| Subagent / Spawn | Eredita la stessa restrizione ✅ |
-| Heartbeat tasks | Eredita la stessa restrizione ✅ |
-
-Tutti i percorsi condividono la stessa restrizione del workspace — non è possibile aggirare il confine di sicurezza tramite subagent o task pianificati.
-
-### Heartbeat (Task Periodici)
-
-PicoClaw può eseguire task periodici automaticamente. Crea un file `HEARTBEAT.md` nel tuo workspace:
-
-```markdown
-# Periodic Tasks
-
-- Check my email for important messages
-- Review my calendar for upcoming events
-- Check the weather forecast
-```
-
-L'agent leggerà questo file ogni 30 minuti (configurabile) ed eseguirà tutti i task usando gli strumenti disponibili.
-
-#### Task Asincroni con Spawn
-
-Per task di lunga durata (ricerca web, chiamate API), usa lo strumento `spawn` per creare un **subagent**:
-
-```markdown
-# Periodic Tasks
-```
diff --git a/docs/ja/providers.md b/docs/ja/providers.md
index e29c113f3..878530966 100644
--- a/docs/ja/providers.md
+++ b/docs/ja/providers.md
@@ -99,6 +99,24 @@
}
```
+#### `model_list` エントリフィールド
+
+| フィールド | 型 | 必須 | 説明 |
+|-----------|------|------|------|
+| `model_name` | string | はい | agent 設定でこのモデルを参照するための一意の名前 |
+| `model` | string | はい | ベンダー/モデル識別子(例:`openai/gpt-5.4`、`azure/gpt-5.4`、`anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | はい* | 認証キー。複数キーでリクエストごとのローテーションが可能。ローカル provider(Ollama、LM Studio、VLLM)には不要 |
+| `api_base` | string | いいえ | デフォルトの API エンドポイント URL を上書き |
+| `proxy` | string | いいえ | このモデルエントリの HTTP プロキシ URL |
+| `user_agent` | string | いいえ | カスタム `User-Agent` リクエストヘッダー(OpenAI 互換、Anthropic、Azure provider で対応) |
+| `request_timeout` | int | いいえ | リクエストタイムアウト(秒)。デフォルト値は provider により異なる |
+| `max_tokens_field` | string | いいえ | リクエストボディの max tokens フィールド名を上書き(例:o1 モデルでは `max_completion_tokens`) |
+| `thinking_level` | string | いいえ | 拡張思考レベル:`off`、`low`、`medium`、`high`、`xhigh`、`adaptive` |
+| `extra_body` | object | いいえ | 各リクエストボディに注入する追加フィールド |
+| `rpm` | int | いいえ | 1 分あたりのリクエストレート制限 |
+| `fallbacks` | string[] | いいえ | 自動フェイルオーバーのフォールバックモデル名 |
+| `enabled` | bool | いいえ | このモデルエントリを有効にするかどうか(デフォルト:`true`) |
+
#### ベンダー別設定例
**OpenAI**
@@ -201,6 +219,7 @@ Anthropic API への直接アクセスや、Anthropic のネイティブメッ
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/docs/providers.md b/docs/providers.md
index f45aa5f3b..d03fbab3e 100644
--- a/docs/providers.md
+++ b/docs/providers.md
@@ -16,6 +16,7 @@
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
+| `venice` | LLM (Venice AI direct) | [venice.ai](https://venice.ai) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `qwen` | LLM (Qwen direct) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -46,6 +47,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- |
| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
+| **Venice AI** | `venice/` | `https://api.venice.ai/api/v1` | OpenAI | [Get Key](https://venice.ai) |
| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
| **Z.AI Coding Plan** | `openai/` | `https://api.z.ai/api/coding/paas/v4` | OpenAI | [Get Key](https://z.ai/manage-apikey/apikey-list) |
@@ -106,6 +108,25 @@ This design also enables **multi-agent support** with flexible provider selectio
}
```
+#### `model_list` Entry Fields
+
+| Field | Type | Required | Description |
+|-------|------|----------|-------------|
+| `model_name` | string | Yes | Unique name used to reference this model in agent config |
+| `model` | string | Yes | Vendor/model identifier (e.g., `openai/gpt-5.4`, `azure/gpt-5.4`, `anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | Yes* | API key(s) for authentication. Multiple keys enable per-request rotation. Not required for local providers (Ollama, LM Studio, VLLM) |
+| `api_base` | string | No | Override the default API endpoint URL |
+| `proxy` | string | No | HTTP proxy URL for this model entry |
+| `user_agent` | string | No | Custom `User-Agent` header sent with API requests (supported by OpenAI-compatible, Anthropic, and Azure providers) |
+| `request_timeout` | int | No | Request timeout in seconds (default varies by provider) |
+| `max_tokens_field` | string | No | Override the max tokens field name in request body (e.g., `max_completion_tokens` for o1 models) |
+| `thinking_level` | string | No | Extended thinking level: `off`, `low`, `medium`, `high`, `xhigh`, or `adaptive` |
+| `extra_body` | object | No | Additional fields to inject into every request body |
+| `custom_headers` | object | No | Additional HTTP headers to inject into every request (e.g., `{"X-Source":"coding-plan"}`). If a key matches a built-in header, the custom value overrides the built-in one (e.g., `Authorization`, `User-Agent`, `Content-Type`, `Accept`). |
+| `rpm` | int | No | Per-minute request rate limit |
+| `fallbacks` | string[] | No | Fallback model names for automatic failover |
+| `enabled` | bool | No | Whether this model entry is active (default: `true`) |
+
#### Voice Transcription
You can configure a dedicated model for audio transcription with `voice.model_name`. This lets you reuse existing multimodal providers that support audio input instead of relying only on Groq.
@@ -247,6 +268,7 @@ PicoClaw sends OpenAI-compatible requests to LM Studio, and strips the `lmstudio
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/docs/pt-br/providers.md b/docs/pt-br/providers.md
index c7c6305e2..103490dc7 100644
--- a/docs/pt-br/providers.md
+++ b/docs/pt-br/providers.md
@@ -99,6 +99,24 @@ Este design também permite **suporte multi-agente** com seleção flexível de
}
```
+#### Campos de entrada `model_list`
+
+| Campo | Tipo | Obrigatório | Descrição |
+|-------|------|-------------|-----------|
+| `model_name` | string | Sim | Nome único para referenciar este modelo na config do agent |
+| `model` | string | Sim | Identificador fornecedor/modelo (ex: `openai/gpt-5.4`, `azure/gpt-5.4`, `anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | Sim* | Chave(s) API para autenticação. Múltiplas chaves permitem rotação por requisição. Não necessário para providers locais (Ollama, LM Studio, VLLM) |
+| `api_base` | string | Não | Substitui a URL base da API padrão |
+| `proxy` | string | Não | URL do proxy HTTP para esta entrada de modelo |
+| `user_agent` | string | Não | Cabeçalho `User-Agent` personalizado enviado com requisições API (suportado por providers OpenAI-compatible, Anthropic e Azure) |
+| `request_timeout` | int | Não | Timeout de requisição em segundos (o padrão varia por provider) |
+| `max_tokens_field` | string | Não | Substitui o nome do campo max tokens no corpo da requisição (ex: `max_completion_tokens` para modelos o1) |
+| `thinking_level` | string | Não | Nível de pensamento estendido: `off`, `low`, `medium`, `high`, `xhigh` ou `adaptive` |
+| `extra_body` | object | Não | Campos adicionais para injetar em cada corpo de requisição |
+| `rpm` | int | Não | Limite de requisições por minuto |
+| `fallbacks` | string[] | Não | Nomes dos modelos de fallback para failover automático |
+| `enabled` | bool | Não | Ativar ou desativar esta entrada de modelo (padrão: `true`) |
+
#### Exemplos por Vendor
**OpenAI**
@@ -190,6 +208,7 @@ Para acesso direto à API Anthropic ou endpoints personalizados que suportam ape
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md
new file mode 100644
index 000000000..b54c757f8
--- /dev/null
+++ b/docs/rate-limiting.md
@@ -0,0 +1,95 @@
+# Dynamic Rate Limiting
+
+PicoClaw prevents 429 errors from LLM provider APIs by enforcing configurable per-model request-rate limits **before** sending each request. Unlike the reactive cooldown/fallback system (which activates *after* a 429 is received), rate limiting is **proactive**: it keeps outbound QPS within the provider's free-tier or plan limits.
+
+## How it works
+
+### Token-bucket algorithm
+
+Each rate-limited model gets a token bucket:
+
+- **Capacity** = `rpm` (burst size equals the per-minute limit)
+- **Refill rate** = `rpm / 60` tokens per second
+- Tokens are consumed one per LLM call; if the bucket is empty, the call blocks until a token refills or the request context is cancelled
+
+### Call chain integration
+
+```
+AgentLoop.callLLM()
+ └─ FallbackChain.Execute() ← iterate candidates
+ ├─ CooldownTracker.IsAvailable() ← skip if post-429 cooldown active
+ ├─ RateLimiterRegistry.Wait() ← NEW: block until token available
+ └─ provider.Chat() ← actual LLM HTTP call
+```
+
+The rate limiter runs **after** the cooldown check and **before** the provider call, so:
+- Candidates already in cooldown are skipped entirely (no token consumed)
+- Candidates that are available get throttled to the configured RPM
+
+The same check applies in `ExecuteImage`.
+
+### Thread safety
+
+`RateLimiterRegistry` is safe for concurrent use. The per-limiter token bucket uses a fine-grained mutex so concurrent goroutines each acquire their own token independently.
+
+## Configuration
+
+Set `rpm` on any model in `model_list`:
+
+```yaml
+model_list:
+ - model_name: gpt-4o-free
+ model: openai/gpt-4o
+ api_base: https://api.openai.com/v1
+ rpm: 3 # max 3 requests per minute
+ api_keys:
+ - sk-...
+
+ - model_name: claude-haiku
+ model: anthropic/claude-haiku-4-5
+ rpm: 60 # 60 rpm (Anthropic free tier)
+ api_keys:
+ - sk-ant-...
+
+ - model_name: local-llm
+ model: openai/llama3
+ api_base: http://localhost:11434/v1
+ # no rpm → unrestricted
+```
+
+| Field | Type | Default | Description |
+|---|---|---|---|
+| `rpm` | `int` | `0` | Requests per minute. `0` means no limit. |
+
+### Interaction with fallbacks
+
+When a model has fallbacks configured, each candidate is rate-limited **independently**:
+
+```yaml
+model_list:
+ - model_name: gpt4-with-fallback
+ model: openai/gpt-4o
+ rpm: 5
+ fallbacks:
+ - gpt-4o-mini # must also be in model_list; its own rpm applies
+```
+
+If the current candidate's bucket is empty and there are more candidates available, PicoClaw skips the locally saturated candidate and tries the next fallback immediately. Only the last remaining candidate waits for a token to refill. If the context deadline is hit while waiting on that last candidate, the wait error propagates.
+
+For `model_list` aliases that resolve to the same underlying provider/model, rate limiting is keyed by the stable config identity (for example `model_name`) rather than the resolved runtime model string. This preserves distinct RPM settings for multi-key and alias-based configurations.
+
+### Burst behaviour
+
+The bucket starts **full** (burst = RPM). For `rpm: 3`, the first 3 requests fire instantly; subsequent requests are spaced ~20 s apart.
+
+To reduce burstiness for strict APIs, set a lower `rpm` and rely on the steady-state refill.
+
+## Files changed
+
+| File | What |
+|---|---|
+| `pkg/providers/ratelimiter.go` | `RateLimiter` (token bucket) + `RateLimiterRegistry` |
+| `pkg/providers/ratelimiter_test.go` | Unit tests for limiter and registry |
+| `pkg/providers/fallback.go` | `FallbackCandidate.RPM` field; `FallbackChain.rl`; `Wait()` call in `Execute`/`ExecuteImage` |
+| `pkg/agent/model_resolution.go` | Resolves candidates from `model_list`, preserving stable config identity and propagating `RPM` into `FallbackCandidate` |
+| `pkg/agent/loop.go` | Build `RateLimiterRegistry`, register all agents' candidates, pass to `NewFallbackChain` |
diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md
index 5a4b5bb28..adee9244a 100644
--- a/docs/tools_configuration.md
+++ b/docs/tools_configuration.md
@@ -528,6 +528,9 @@ For example:
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
- `PICOCLAW_TOOLS_MCP_ENABLED=true`
+- `PICOCLAW_TOOLS_MCP_MAX_INLINE_TEXT_CHARS=16384`
Note: Nested map-style config (for example `tools.mcp.servers..*`) is configured in `config.json` rather than
environment variables.
+
+For MCP tools, `tools.mcp.max_inline_text_chars` controls how much text result is kept inline in model context. The threshold is counted in Unicode characters (Go runes), not bytes. For example, `16384` means up to 16,384 characters inline, which may occupy more than 16 KB for multibyte text such as CJK. Above this threshold, PicoClaw saves the MCP text result as a local artifact in the agent workspace and gives the model a short note plus a structured `[file:...]` artifact path instead of injecting the full payload into context.
diff --git a/docs/vi/providers.md b/docs/vi/providers.md
index ffd992645..46c9de663 100644
--- a/docs/vi/providers.md
+++ b/docs/vi/providers.md
@@ -99,6 +99,24 @@ Thiết kế này cũng cho phép **hỗ trợ đa agent** với lựa chọn pr
}
```
+#### Các trường entry `model_list`
+
+| Trường | Kiểu | Bắt buộc | Mô tả |
+|--------|------|----------|------|
+| `model_name` | string | Có | Tên duy nhất để tham chiếu model này trong cấu hình agent |
+| `model` | string | Có | Định danh nhà cung cấp/model (ví dụ: `openai/gpt-5.4`, `azure/gpt-5.4`, `anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | Có* | Khóa API xác thực. Nhiều khóa cho phép xoay vòng theo yêu cầu. Không cần thiết cho provider nội bộ (Ollama, LM Studio, VLLM) |
+| `api_base` | string | Không | Ghi đè URL endpoint API mặc định |
+| `proxy` | string | Không | URL proxy HTTP cho entry model này |
+| `user_agent` | string | Không | Header `User-Agent` tùy chỉnh gửi với yêu cầu API (được hỗ trợ bởi provider OpenAI-compatible, Anthropic và Azure) |
+| `request_timeout` | int | Không | Timeout yêu cầu tính bằng giây (mặc định khác nhau tùy provider) |
+| `max_tokens_field` | string | Không | Ghi đè tên trường max tokens trong request body (ví dụ: `max_completion_tokens` cho model o1) |
+| `thinking_level` | string | Không | Mức độ tư duy mở rộng: `off`, `low`, `medium`, `high`, `xhigh` hoặc `adaptive` |
+| `extra_body` | object | Không | Các trường bổ sung để chèn vào mỗi request body |
+| `rpm` | int | Không | Giới hạn tốc độ yêu cầu mỗi phút |
+| `fallbacks` | string[] | Không | Tên model dự phòng cho failover tự động |
+| `enabled` | bool | Không | Kích hoạt hay vô hiệu hóa entry model này (mặc định: `true`) |
+
#### Ví Dụ Theo Vendor
**OpenAI**
@@ -190,6 +208,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa agent** với lựa chọn pr
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/docs/zh/providers.md b/docs/zh/providers.md
index 04b2f7a88..7b3930f6f 100644
--- a/docs/zh/providers.md
+++ b/docs/zh/providers.md
@@ -15,6 +15,7 @@
| `openrouter` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
+| `venice` | LLM (Venice AI 直连) | [venice.ai](https://venice.ai) |
| `deepseek` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
| `qwen` | LLM (通义千问) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -44,6 +45,7 @@
| 厂商 | `model` 前缀 | 默认 API Base | 协议 | 获取 API Key |
| ------------------- | ----------------- | --------------------------------------------------- | --------- | ----------------------------------------------------------------- |
| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [获取密钥](https://platform.openai.com) |
+| **Venice AI** | `venice/` | `https://api.venice.ai/api/v1` | OpenAI | [获取密钥](https://venice.ai) |
| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [获取密钥](https://console.anthropic.com) |
| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [获取密钥](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [获取密钥](https://platform.deepseek.com) |
@@ -102,6 +104,25 @@
}
```
+#### `model_list` 条目字段
+
+| 字段 | 类型 | 必填 | 说明 |
+|------|------|------|------|
+| `model_name` | string | 是 | 在 agent 配置中引用此模型的唯一名称 |
+| `model` | string | 是 | 厂商/模型标识符(如 `openai/gpt-5.4`、`azure/gpt-5.4`、`anthropic/claude-sonnet-4.6`) |
+| `api_keys` | string[] | 是* | 认证密钥。多个密钥可按请求轮换。本地 provider(Ollama、LM Studio、VLLM)不需要 |
+| `api_base` | string | 否 | 覆盖默认的 API 端点 URL |
+| `proxy` | string | 否 | 此模型条目的 HTTP 代理 URL |
+| `user_agent` | string | 否 | 自定义 `User-Agent` 请求头(支持 OpenAI 兼容、Anthropic 和 Azure provider) |
+| `request_timeout` | int | 否 | 请求超时时间(秒),默认值因 provider 而异 |
+| `max_tokens_field` | string | 否 | 覆盖请求体中 max tokens 的字段名(如 o1 模型使用 `max_completion_tokens`) |
+| `thinking_level` | string | 否 | 扩展思考级别:`off`、`low`、`medium`、`high`、`xhigh` 或 `adaptive` |
+| `extra_body` | object | 否 | 注入到每个请求体中的额外字段 |
+| `custom_headers` | object | 否 | 注入到每个请求中的额外 HTTP 请求头(例如 `{"X-Source":"coding-plan"}`)。若键名与内置请求头同名,会覆盖内置值(如 `Authorization`、`User-Agent`、`Content-Type`、`Accept`)。 |
+| `rpm` | int | 否 | 每分钟请求速率限制 |
+| `fallbacks` | string[] | 否 | 自动故障转移的备用模型名称 |
+| `enabled` | bool | 否 | 是否启用此模型条目(默认:`true`) |
+
#### 语音转录
你可以通过 `voice.model_name` 为语音转录指定一个专用模型。这样可以直接复用已经配置好的、支持音频输入的多模态 provider,而不必只依赖 Groq。
@@ -232,6 +253,7 @@ PicoClaw 向 LM Studio 的 OpenAI 兼容终结点发送请求,且将移除首
"model": "openai/custom-model",
"api_base": "https://my-proxy.com/v1",
"api_keys": ["sk-..."],
+ "user_agent": "MyApp/1.0",
"request_timeout": 300
}
```
diff --git a/go.mod b/go.mod
index 7d242d498..a9f4bb7cb 100644
--- a/go.mod
+++ b/go.mod
@@ -5,8 +5,10 @@ go 1.25.8
require (
fyne.io/systray v1.12.0
github.com/BurntSushi/toml v1.6.0
+ github.com/SevereCloud/vksdk/v3 v3.3.1
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.26.0
+ github.com/atc0005/go-teams-notify/v2 v2.14.0
github.com/atotto/clipboard v0.1.4
github.com/aws/aws-sdk-go-v2 v1.41.5
github.com/aws/aws-sdk-go-v2/config v1.32.12
@@ -23,12 +25,15 @@ require (
github.com/h2non/filetype v1.1.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mdp/qrterminal/v3 v3.2.1
+ github.com/minio/selfupdate v0.6.0
github.com/modelcontextprotocol/go-sdk v1.4.1
github.com/mymmrac/telego v1.7.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
+ github.com/pion/rtp v1.10.1
+ github.com/pion/webrtc/v3 v3.3.6
github.com/rivo/tview v0.42.0
- github.com/rs/zerolog v1.34.0
+ github.com/rs/zerolog v1.35.0
github.com/slack-go/slack v0.17.3
github.com/spf13/cobra v1.10.2
github.com/stretchr/testify v1.11.1
@@ -41,11 +46,12 @@ require (
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mautrix v0.26.4
- modernc.org/sqlite v1.47.0
+ modernc.org/sqlite v1.48.0
rsc.io/qr v0.2.0
)
require (
+ aead.dev/minisign v0.2.0 // indirect
filippo.io/edwards25519 v1.2.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
@@ -61,6 +67,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/beeper/argo-go v1.1.2 // indirect
+ github.com/cloudflare/circl v1.6.3 // indirect
github.com/coder/websocket v1.8.14 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
@@ -76,6 +83,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.34 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 // indirect
+ github.com/pion/randutil v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
@@ -83,6 +91,8 @@ require (
github.com/segmentio/encoding v0.5.4 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/vektah/gqlparser/v2 v2.5.27 // indirect
+ github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
+ github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.mau.fi/libsignal v0.2.1 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel v1.35.0 // indirect
@@ -120,6 +130,8 @@ require (
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.49.0
golang.org/x/net v0.52.0
- golang.org/x/sync v0.20.0 // indirect
+ golang.org/x/sync v0.20.0
golang.org/x/sys v0.42.0
)
+
+replace github.com/bwmarrin/discordgo => github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532
diff --git a/go.sum b/go.sum
index 76d1b46c7..765a3211a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,5 @@
+aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
+aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
@@ -7,6 +9,8 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
+github.com/SevereCloud/vksdk/v3 v3.3.1 h1:O86zsp5LQnHE+O5acvuXM/s6S1LyxzVTkF6+Lup0Jyg=
+github.com/SevereCloud/vksdk/v3 v3.3.1/go.mod h1:c6WaA5aocUYsXfkcUbg2qy45V9M1VDcqHHmHIN14NAw=
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM=
@@ -17,6 +21,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
+github.com/atc0005/go-teams-notify/v2 v2.14.0 h1:7N+xw+COnYANLREaAveQ65rsNQ12nIZJED9nMLyscCo=
+github.com/atc0005/go-teams-notify/v2 v2.14.0/go.mod h1:EECsWM2b0Hvoz7O+QdlsvyN2KCUOFQCGj8bUBXv3A3Q=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
@@ -53,8 +59,6 @@ github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
-github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
-github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
@@ -65,11 +69,12 @@ github.com/caarlos0/env/v11 v11.4.0 h1:Kcb6t5kIIr4XkoQC9AF2j+8E1Jsrl3Wz/hhm1LtoG
github.com/caarlos0/env/v11 v11.4.0/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
+github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
-github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
@@ -108,7 +113,6 @@ github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2m
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
-github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@@ -173,17 +177,16 @@ github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJi
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
-github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
-github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
-github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
+github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
+github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM=
github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc=
github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s=
github.com/mymmrac/telego v1.7.0 h1:yRO/l00tFGG4nY66ufUKb4ARqv7qx9+LsjQv/b0NEyo=
@@ -204,8 +207,13 @@ github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixi
github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 h1:rh2lKw/P/EqHa724vYH2+VVQ1YnW4u6EOXl0PMAovZE=
github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
+github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
+github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
+github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
+github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
+github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE=
+github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@@ -218,9 +226,8 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
-github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
-github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
-github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
+github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
+github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
@@ -271,8 +278,14 @@ github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADT
github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE=
github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTdwFp0s=
github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
+github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
+github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532 h1:gxFHYeUDGziRb0zXYEqBFohC+NJbIW9L0tddaXMWr2o=
+github.com/yeongaori/discordgo-fork v0.0.0-20260319072544-e8e546f5d532/go.mod h1:A0FcMFJKJ9fRjgSuZ2o+pIQ6mPS81SVuiLN2vYTa7Ao=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -300,8 +313,9 @@ golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
+golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
@@ -322,6 +336,7 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
+golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
@@ -344,24 +359,25 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -442,8 +458,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
-modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
-modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
+modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
+modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
diff --git a/pkg/agent/context.go b/pkg/agent/context.go
index b5c68650a..c2921294b 100644
--- a/pkg/agent/context.go
+++ b/pkg/agent/context.go
@@ -602,14 +602,16 @@ func (cb *ContextBuilder) BuildMessages(
// Add conversation history
messages = append(messages, history...)
- // Add current user message
- if strings.TrimSpace(currentMessage) != "" {
+ // Add current user message. Media-only turns must still be preserved so
+ // multimodal providers receive the uploaded image even when the user sends
+ // no accompanying text.
+ if strings.TrimSpace(currentMessage) != "" || len(media) > 0 {
msg := providers.Message{
Role: "user",
Content: currentMessage,
}
if len(media) > 0 {
- msg.Media = media
+ msg.Media = append([]string(nil), media...)
}
messages = append(messages, msg)
}
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
index 3398d7863..72f80382a 100644
--- a/pkg/agent/context_budget.go
+++ b/pkg/agent/context_budget.go
@@ -6,10 +6,8 @@
package agent
import (
- "encoding/json"
- "unicode/utf8"
-
"github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tokenizer"
)
// parseTurnBoundaries returns the starting index of each Turn in the history.
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
return 0
}
-// estimateMessageTokens estimates the token count for a single message,
-// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
-// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
-func estimateMessageTokens(msg providers.Message) int {
- contentChars := utf8.RuneCountInString(msg.Content)
-
- // SystemParts are structured system blocks used for cache-aware adapters.
- // They carry the same content as Content, but in multiple blocks.
- // We estimate them as an alternative representation, not additive.
- systemPartsChars := 0
- if len(msg.SystemParts) > 0 {
- for _, part := range msg.SystemParts {
- systemPartsChars += utf8.RuneCountInString(part.Text)
- }
- // Per-part overhead for JSON structure (type, text, cache_control).
- const perPartOverhead = 20
- systemPartsChars += len(msg.SystemParts) * perPartOverhead
- }
-
- // Use the larger of the two representations to stay conservative.
- chars := contentChars
- if systemPartsChars > chars {
- chars = systemPartsChars
- }
-
- chars += utf8.RuneCountInString(msg.ReasoningContent)
-
- for _, tc := range msg.ToolCalls {
- chars += len(tc.ID) + len(tc.Type)
- if tc.Function != nil {
- // Count function name + arguments (the wire format for most providers).
- // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
- chars += len(tc.Function.Name) + len(tc.Function.Arguments)
- } else {
- // Fallback: some provider formats use top-level Name without Function.
- chars += len(tc.Name)
- }
- }
-
- if msg.ToolCallID != "" {
- chars += len(msg.ToolCallID)
- }
-
- // Per-message overhead for role label, JSON structure, separators.
- const messageOverhead = 12
- chars += messageOverhead
-
- tokens := chars * 2 / 5
-
- // Media items (images, files) are serialized by provider adapters into
- // multipart or image_url payloads. Add a fixed per-item token estimate
- // directly (not through the chars heuristic) since actual cost depends
- // on resolution and provider-specific image tokenization.
- const mediaTokensPerItem = 256
- tokens += len(msg.Media) * mediaTokensPerItem
-
- return tokens
+// EstimateMessageTokens estimates the token count for a single message.
+// Delegates to the shared tokenizer package for consistency across agent and seahorse.
+func EstimateMessageTokens(msg providers.Message) int {
+ return tokenizer.EstimateMessageTokens(msg)
}
-// estimateToolDefsTokens estimates the total token cost of tool definitions
-// as they appear in the LLM request. Each tool's name, description, and
-// JSON schema parameters contribute to the context window budget.
-func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
- if len(defs) == 0 {
- return 0
- }
-
- totalChars := 0
- for _, d := range defs {
- totalChars += len(d.Function.Name) + len(d.Function.Description)
-
- if d.Function.Parameters != nil {
- if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
- totalChars += len(paramJSON)
- }
- }
-
- // Per-tool overhead: type field, JSON structure, separators.
- totalChars += 20
- }
-
- return totalChars * 2 / 5
+// EstimateToolDefsTokens estimates the total token cost of tool definitions
+// as they appear in the LLM request. Delegates to the shared tokenizer package.
+func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
+ return tokenizer.EstimateToolDefsTokens(defs)
}
// isOverContextBudget checks whether the assembled messages plus tool definitions
@@ -181,10 +107,10 @@ func isOverContextBudget(
) bool {
msgTokens := 0
for _, m := range messages {
- msgTokens += estimateMessageTokens(m)
+ msgTokens += EstimateMessageTokens(m)
}
- toolTokens := estimateToolDefsTokens(toolDefs)
+ toolTokens := EstimateToolDefsTokens(toolDefs)
total := msgTokens + toolTokens + maxTokens
return total > contextWindow
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 22cbdc0db..9de1707ec 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := estimateMessageTokens(tt.msg)
+ got := EstimateMessageTokens(tt.msg)
if got < tt.want {
- t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
+ t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
},
}
- plainTokens := estimateMessageTokens(plain)
- withTCTokens := estimateMessageTokens(withTC)
+ plainTokens := EstimateMessageTokens(plain)
+ withTCTokens := EstimateMessageTokens(withTC)
if withTCTokens <= plainTokens {
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
- tokens := estimateMessageTokens(msg)
+ tokens := EstimateMessageTokens(msg)
if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
}
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
},
}
- tokens := estimateMessageTokens(msg)
+ tokens := EstimateMessageTokens(msg)
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
if tokens < 2000 {
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
ReasoningContent: strings.Repeat("thinking step ", 200),
}
- plainTokens := estimateMessageTokens(plain)
- reasoningTokens := estimateMessageTokens(withReasoning)
+ plainTokens := EstimateMessageTokens(plain)
+ reasoningTokens := EstimateMessageTokens(withReasoning)
if reasoningTokens <= plainTokens {
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
Media: []string{"media://img1.png", "media://img2.png"},
}
- plainTokens := estimateMessageTokens(plain)
- mediaTokens := estimateMessageTokens(withMedia)
+ plainTokens := EstimateMessageTokens(plain)
+ mediaTokens := EstimateMessageTokens(withMedia)
if mediaTokens <= plainTokens {
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
},
}
- plainTokens := estimateMessageTokens(plain)
- partsTokens := estimateMessageTokens(withParts)
+ plainTokens := EstimateMessageTokens(plain)
+ partsTokens := EstimateMessageTokens(withParts)
if partsTokens <= plainTokens {
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
}
}
-// --- estimateToolDefsTokens tests ---
+// --- EstimateToolDefsTokens tests ---
func TestEstimateToolDefsTokens(t *testing.T) {
tests := []struct {
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := estimateToolDefsTokens(tt.defs)
+ got := EstimateToolDefsTokens(tt.defs)
if got < tt.want {
- t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
+ t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
}
}
- one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
- three := estimateToolDefsTokens([]providers.ToolDefinition{
+ one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
+ three := EstimateToolDefsTokens([]providers.ToolDefinition{
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
})
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
},
}
- tokens := estimateMessageTokens(msg)
+ tokens := EstimateMessageTokens(msg)
// ReasoningContent alone is ~1700 chars → ~680 tokens.
// Content + TC + overhead adds more. Should be well above 500.
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
// Compare without reasoning to ensure it's counted.
msgNoReasoning := msg
msgNoReasoning.ReasoningContent = ""
- tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
+ tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
if tokens <= tokensNoReasoning {
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go
index 81a1534b9..ef5e6c5de 100644
--- a/pkg/agent/context_cache_test.go
+++ b/pkg/agent/context_cache_test.go
@@ -707,6 +707,38 @@ func TestEmptyWorkspaceBaselineDetectsNewFiles(t *testing.T) {
}
}
+func TestBuildMessages_IncludesMediaOnlyCurrentMessage(t *testing.T) {
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ msgs := cb.BuildMessages(
+ nil,
+ "",
+ "",
+ []string{"data:image/png;base64,abc123"},
+ "pico",
+ "chat-1",
+ "",
+ "",
+ )
+
+ if len(msgs) != 2 {
+ t.Fatalf("len(msgs) = %d, want 2", len(msgs))
+ }
+
+ userMsg := msgs[1]
+ if userMsg.Role != "user" {
+ t.Fatalf("userMsg.Role = %q, want %q", userMsg.Role, "user")
+ }
+ if userMsg.Content != "" {
+ t.Fatalf("userMsg.Content = %q, want empty string", userMsg.Content)
+ }
+ if len(userMsg.Media) != 1 || userMsg.Media[0] != "data:image/png;base64,abc123" {
+ t.Fatalf("userMsg.Media = %#v, want image payload", userMsg.Media)
+ }
+}
+
// BenchmarkBuildMessagesWithCache measures caching performance.
func BenchmarkBuildMessagesWithCache(b *testing.B) {
tmpDir, _ := os.MkdirTemp("", "picoclaw-bench-*")
diff --git a/pkg/agent/context_legacy.go b/pkg/agent/context_legacy.go
new file mode 100644
index 000000000..51aff44f8
--- /dev/null
+++ b/pkg/agent/context_legacy.go
@@ -0,0 +1,379 @@
+package agent
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// legacyContextManager wraps the existing summarization/compression logic
+// as a ContextManager implementation. It is the default when no other
+// ContextManager is configured.
+type legacyContextManager struct {
+ al *AgentLoop
+ summarizing sync.Map // dedup for async Compact (post-turn)
+}
+
+func (m *legacyContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
+ // Legacy: read history from session, return as-is.
+ // Budget enforcement happens in BuildMessages caller via
+ // isOverContextBudget + forceCompression.
+ agent := m.al.registry.GetDefaultAgent()
+ if agent == nil {
+ return &AssembleResponse{}, nil
+ }
+ history := agent.Sessions.GetHistory(req.SessionKey)
+ summary := agent.Sessions.GetSummary(req.SessionKey)
+ return &AssembleResponse{
+ History: history,
+ Summary: summary,
+ }, nil
+}
+
+func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) error {
+ switch req.Reason {
+ case ContextCompressReasonProactive, ContextCompressReasonRetry:
+ // Sync emergency compression — budget exceeded.
+ if result, ok := m.forceCompression(req.SessionKey); ok {
+ m.al.emitEvent(
+ EventKindContextCompress,
+ m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"),
+ ContextCompressPayload{
+ Reason: req.Reason,
+ DroppedMessages: result.DroppedMessages,
+ RemainingMessages: result.RemainingMessages,
+ },
+ )
+ }
+ case ContextCompressReasonSummarize:
+ m.maybeSummarize(req.SessionKey)
+ }
+ return nil
+}
+
+func (m *legacyContextManager) Ingest(_ context.Context, _ *IngestRequest) error {
+ // Legacy: no-op. Messages are persisted by Sessions JSONL.
+ return nil
+}
+
+// maybeSummarize triggers summarization if the session history exceeds thresholds.
+// It runs asynchronously in a goroutine.
+func (m *legacyContextManager) maybeSummarize(sessionKey string) {
+ agent := m.al.registry.GetDefaultAgent()
+ if agent == nil {
+ return
+ }
+
+ newHistory := agent.Sessions.GetHistory(sessionKey)
+ tokenEstimate := m.estimateTokens(newHistory)
+ threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
+
+ if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
+ summarizeKey := agent.ID + ":" + sessionKey
+ if _, loading := m.summarizing.LoadOrStore(summarizeKey, true); !loading {
+ go func() {
+ defer m.summarizing.Delete(summarizeKey)
+ defer func() {
+ if r := recover(); r != nil {
+ logger.WarnCF("agent", "Summarization panic recovered", map[string]any{
+ "session_key": sessionKey,
+ "panic": r,
+ })
+ }
+ }()
+ logger.Debug("Memory threshold reached. Optimizing conversation history...")
+ m.summarizeSession(agent, sessionKey)
+ }()
+ }
+ }
+}
+
+type compressionResult struct {
+ DroppedMessages int
+ RemainingMessages int
+}
+
+// forceCompression aggressively reduces context when the limit is hit.
+// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
+// cycle, as defined in #1316), so tool-call sequences are never split.
+func (m *legacyContextManager) forceCompression(sessionKey string) (compressionResult, bool) {
+ agent := m.al.registry.GetDefaultAgent()
+ if agent == nil {
+ return compressionResult{}, false
+ }
+
+ history := agent.Sessions.GetHistory(sessionKey)
+ if len(history) <= 2 {
+ return compressionResult{}, false
+ }
+
+ turns := parseTurnBoundaries(history)
+ var mid int
+ if len(turns) >= 2 {
+ mid = turns[len(turns)/2]
+ } else {
+ mid = findSafeBoundary(history, len(history)/2)
+ }
+ var keptHistory []providers.Message
+ if mid <= 0 {
+ for i := len(history) - 1; i >= 0; i-- {
+ if history[i].Role == "user" {
+ keptHistory = []providers.Message{history[i]}
+ break
+ }
+ }
+ } else {
+ keptHistory = history[mid:]
+ }
+
+ droppedCount := len(history) - len(keptHistory)
+
+ existingSummary := agent.Sessions.GetSummary(sessionKey)
+ compressionNote := fmt.Sprintf(
+ "[Emergency compression dropped %d oldest messages due to context limit]",
+ droppedCount,
+ )
+ if existingSummary != "" {
+ compressionNote = existingSummary + "\n\n" + compressionNote
+ }
+ agent.Sessions.SetSummary(sessionKey, compressionNote)
+
+ agent.Sessions.SetHistory(sessionKey, keptHistory)
+ agent.Sessions.Save(sessionKey)
+
+ logger.WarnCF("agent", "Forced compression executed", map[string]any{
+ "session_key": sessionKey,
+ "dropped_msgs": droppedCount,
+ "new_count": len(keptHistory),
+ })
+
+ return compressionResult{
+ DroppedMessages: droppedCount,
+ RemainingMessages: len(keptHistory),
+ }, true
+}
+
+func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey string) {
+ ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+ defer cancel()
+
+ history := agent.Sessions.GetHistory(sessionKey)
+ summary := agent.Sessions.GetSummary(sessionKey)
+
+ if len(history) <= 4 {
+ return
+ }
+
+ safeCut := findSafeBoundary(history, len(history)-4)
+ if safeCut <= 0 {
+ return
+ }
+ keepCount := len(history) - safeCut
+ toSummarize := history[:safeCut]
+
+ maxMessageTokens := agent.ContextWindow / 2
+ validMessages := make([]providers.Message, 0)
+ omitted := false
+
+ for _, msg := range toSummarize {
+ if msg.Role != "user" && msg.Role != "assistant" {
+ continue
+ }
+ msgTokens := len(msg.Content) / 2
+ if msgTokens > maxMessageTokens {
+ omitted = true
+ continue
+ }
+ validMessages = append(validMessages, msg)
+ }
+
+ if len(validMessages) == 0 {
+ return
+ }
+
+ const (
+ maxSummarizationMessages = 10
+ llmMaxRetries = 3
+ )
+
+ var finalSummary string
+ if len(validMessages) > maxSummarizationMessages {
+ mid := len(validMessages) / 2
+ mid = m.findNearestUserMessage(validMessages, mid)
+
+ part1 := validMessages[:mid]
+ part2 := validMessages[mid:]
+
+ s1, _ := m.summarizeBatch(ctx, agent, part1, "")
+ s2, _ := m.summarizeBatch(ctx, agent, part2, "")
+
+ mergePrompt := fmt.Sprintf(
+ "Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
+ s1, s2,
+ )
+
+ resp, err := m.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
+ if err == nil && resp.Content != "" {
+ finalSummary = resp.Content
+ } else {
+ finalSummary = s1 + " " + s2
+ }
+ } else {
+ finalSummary, _ = m.summarizeBatch(ctx, agent, validMessages, summary)
+ }
+
+ if omitted && finalSummary != "" {
+ finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
+ }
+
+ if finalSummary != "" {
+ agent.Sessions.SetSummary(sessionKey, finalSummary)
+ agent.Sessions.TruncateHistory(sessionKey, keepCount)
+ agent.Sessions.Save(sessionKey)
+ m.al.emitEvent(
+ EventKindSessionSummarize,
+ m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"),
+ SessionSummarizePayload{
+ SummarizedMessages: len(validMessages),
+ KeptMessages: keepCount,
+ SummaryLen: len(finalSummary),
+ OmittedOversized: omitted,
+ },
+ )
+ }
+}
+
+func (m *legacyContextManager) findNearestUserMessage(messages []providers.Message, mid int) int {
+ originalMid := mid
+
+ for mid > 0 && messages[mid].Role != "user" {
+ mid--
+ }
+
+ if messages[mid].Role == "user" {
+ return mid
+ }
+
+ mid = originalMid
+ for mid < len(messages) && messages[mid].Role != "user" {
+ mid++
+ }
+
+ if mid < len(messages) {
+ return mid
+ }
+
+ return originalMid
+}
+
+func (m *legacyContextManager) retryLLMCall(
+ ctx context.Context,
+ agent *AgentInstance,
+ prompt string,
+ maxRetries int,
+) (*providers.LLMResponse, error) {
+ const llmTemperature = 0.3
+
+ var resp *providers.LLMResponse
+ var err error
+
+ for attempt := 0; attempt < maxRetries; attempt++ {
+ m.al.activeRequests.Add(1)
+ resp, err = func() (*providers.LLMResponse, error) {
+ defer m.al.activeRequests.Done()
+ return agent.Provider.Chat(
+ ctx,
+ []providers.Message{{Role: "user", Content: prompt}},
+ nil,
+ agent.Model,
+ map[string]any{
+ "max_tokens": agent.MaxTokens,
+ "temperature": llmTemperature,
+ "prompt_cache_key": agent.ID,
+ },
+ )
+ }()
+
+ if err == nil && resp != nil && resp.Content != "" {
+ return resp, nil
+ }
+ if attempt < maxRetries-1 {
+ time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
+ }
+ }
+
+ return resp, err
+}
+
+func (m *legacyContextManager) summarizeBatch(
+ ctx context.Context,
+ agent *AgentInstance,
+ batch []providers.Message,
+ existingSummary string,
+) (string, error) {
+ const (
+ llmMaxRetries = 3
+ fallbackMinContentLength = 200
+ fallbackMaxContentPercent = 10
+ )
+
+ var sb strings.Builder
+ sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
+ if existingSummary != "" {
+ sb.WriteString("Existing context: ")
+ sb.WriteString(existingSummary)
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\nCONVERSATION:\n")
+ for _, msg := range batch {
+ fmt.Fprintf(&sb, "%s: %s\n", msg.Role, msg.Content)
+ }
+ prompt := sb.String()
+
+ response, err := m.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
+ if err == nil && response.Content != "" {
+ return strings.TrimSpace(response.Content), nil
+ }
+
+ var fallback strings.Builder
+ fallback.WriteString("Conversation summary: ")
+ for i, msg := range batch {
+ if i > 0 {
+ fallback.WriteString(" | ")
+ }
+ content := strings.TrimSpace(msg.Content)
+ runes := []rune(content)
+ if len(runes) == 0 {
+ fallback.WriteString(fmt.Sprintf("%s: ", msg.Role))
+ continue
+ }
+
+ keepLength := len(runes) * fallbackMaxContentPercent / 100
+ if keepLength < fallbackMinContentLength {
+ keepLength = fallbackMinContentLength
+ }
+ if keepLength > len(runes) {
+ keepLength = len(runes)
+ }
+
+ content = string(runes[:keepLength])
+ if keepLength < len(runes) {
+ content += "..."
+ }
+ fallback.WriteString(fmt.Sprintf("%s: %s", msg.Role, content))
+ }
+ return fallback.String(), nil
+}
+
+func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
+ total := 0
+ for _, msg := range messages {
+ total += EstimateMessageTokens(msg)
+ }
+ return total
+}
diff --git a/pkg/agent/context_manager.go b/pkg/agent/context_manager.go
new file mode 100644
index 000000000..5f8701812
--- /dev/null
+++ b/pkg/agent/context_manager.go
@@ -0,0 +1,90 @@
+package agent
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// ContextManager manages conversation context via a pluggable strategy.
+// Exactly ONE ContextManager is active per AgentLoop, selected by config.
+// The default ("legacy") preserves current summarization behavior.
+type ContextManager interface {
+ // Assemble builds budget-aware context from the ContextManager's own storage.
+ // Called before BuildMessages. Returns assembled messages ready for LLM.
+ Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error)
+
+ // Compact compresses conversation history.
+ // Called after turn completes (may be async internally) and on context overflow (sync).
+ Compact(ctx context.Context, req *CompactRequest) error
+
+ // Ingest records a message into the ContextManager's own storage.
+ // Called after each message is persisted to session JSONL.
+ Ingest(ctx context.Context, req *IngestRequest) error
+}
+
+// AssembleRequest is the input to Assemble.
+type AssembleRequest struct {
+ SessionKey string // session identifier
+ Budget int // context window in tokens
+ MaxTokens int // max response tokens
+}
+
+// AssembleResponse is the output of Assemble.
+type AssembleResponse struct {
+ History []providers.Message // assembled conversation history for BuildMessages
+ Summary string // conversation summary embedded into system prompt by BuildMessages
+}
+
+// CompactRequest is the input to Compact.
+type CompactRequest struct {
+ SessionKey string // session identifier
+ Reason ContextCompressReason // proactive_budget | llm_retry | summarize
+ Budget int // context window budget (used for retry aggressive compaction)
+}
+
+// IngestRequest is the input to Ingest.
+type IngestRequest struct {
+ SessionKey string // session identifier
+ Message providers.Message // the message just persisted
+}
+
+// ContextManagerFactory constructs a ContextManager from config.
+// al provides access to the AgentLoop's runtime resources (provider, model, workspace, etc.)
+// cfg is the raw JSON configuration from config.json (may be nil).
+type ContextManagerFactory func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error)
+
+var (
+ cmRegistryMu sync.RWMutex
+ cmRegistry = map[string]ContextManagerFactory{}
+)
+
+// RegisterContextManager registers a named ContextManager factory.
+func RegisterContextManager(name string, factory ContextManagerFactory) error {
+ if name == "" {
+ return fmt.Errorf("context manager name is required")
+ }
+ if factory == nil {
+ return fmt.Errorf("context manager %q factory is nil", name)
+ }
+
+ cmRegistryMu.Lock()
+ defer cmRegistryMu.Unlock()
+
+ if _, exists := cmRegistry[name]; exists {
+ return fmt.Errorf("context manager %q is already registered", name)
+ }
+ cmRegistry[name] = factory
+ return nil
+}
+
+func lookupContextManager(name string) (ContextManagerFactory, bool) {
+ cmRegistryMu.RLock()
+ defer cmRegistryMu.RUnlock()
+
+ f, ok := cmRegistry[name]
+ return f, ok
+}
diff --git a/pkg/agent/context_manager_test.go b/pkg/agent/context_manager_test.go
new file mode 100644
index 000000000..6bde5e1a9
--- /dev/null
+++ b/pkg/agent/context_manager_test.go
@@ -0,0 +1,764 @@
+package agent
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// ---------------------------------------------------------------------------
+// Factory registry tests
+// ---------------------------------------------------------------------------
+
+func TestRegisterContextManager_Success(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return &noopContextManager{}, nil
+ }
+ if err := RegisterContextManager("test_cm", factory); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ f, ok := lookupContextManager("test_cm")
+ if !ok {
+ t.Fatal("expected factory to be registered")
+ }
+ if f == nil {
+ t.Fatal("expected non-nil factory")
+ }
+}
+
+func TestRegisterContextManager_EmptyName(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ err := RegisterContextManager("", func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return &noopContextManager{}, nil
+ })
+ if err == nil {
+ t.Fatal("expected error for empty name")
+ }
+ if !strings.Contains(err.Error(), "name is required") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestRegisterContextManager_NilFactory(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ err := RegisterContextManager("nil_factory", nil)
+ if err == nil {
+ t.Fatal("expected error for nil factory")
+ }
+ if !strings.Contains(err.Error(), "factory is nil") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestRegisterContextManager_Duplicate(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return &noopContextManager{}, nil
+ }
+ if err := RegisterContextManager("dup_cm", factory); err != nil {
+ t.Fatalf("first registration failed: %v", err)
+ }
+ err := RegisterContextManager("dup_cm", factory)
+ if err == nil {
+ t.Fatal("expected error for duplicate registration")
+ }
+ if !strings.Contains(err.Error(), "already registered") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestLookupContextManager_Unknown(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ _, ok := lookupContextManager("nonexistent")
+ if ok {
+ t.Fatal("expected lookup to fail for unknown name")
+ }
+}
+
+// ---------------------------------------------------------------------------
+// resolveContextManager tests
+// ---------------------------------------------------------------------------
+
+func TestResolveContextManager_Default(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "", // default → legacy
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ cm := al.contextManager
+ if cm == nil {
+ t.Fatal("expected non-nil context manager")
+ }
+ if _, ok := cm.(*legacyContextManager); !ok {
+ t.Fatalf("expected *legacyContextManager, got %T", cm)
+ }
+}
+
+func TestResolveContextManager_ExplicitLegacy(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "legacy",
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ if _, ok := al.contextManager.(*legacyContextManager); !ok {
+ t.Fatalf("expected *legacyContextManager, got %T", al.contextManager)
+ }
+}
+
+func TestResolveContextManager_UnknownFallsBackToLegacy(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "unknown_cm",
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ if _, ok := al.contextManager.(*legacyContextManager); !ok {
+ t.Fatalf("expected fallback to *legacyContextManager, got %T", al.contextManager)
+ }
+}
+
+func TestResolveContextManager_RegisteredFactory(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return &noopContextManager{}, nil
+ }
+ if err := RegisterContextManager("custom_cm", factory); err != nil {
+ t.Fatalf("register failed: %v", err)
+ }
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "custom_cm",
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ if _, ok := al.contextManager.(*noopContextManager); !ok {
+ t.Fatalf("expected *noopContextManager, got %T", al.contextManager)
+ }
+}
+
+func TestResolveContextManager_FactoryError(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return nil, os.ErrPermission
+ }
+ if err := RegisterContextManager("broken_cm", factory); err != nil {
+ t.Fatalf("register failed: %v", err)
+ }
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "broken_cm",
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ // Should fall back to legacy when factory returns error
+ if _, ok := al.contextManager.(*legacyContextManager); !ok {
+ t.Fatalf("expected fallback to *legacyContextManager on factory error, got %T", al.contextManager)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Legacy Assemble tests
+// ---------------------------------------------------------------------------
+
+func TestLegacyAssemble_Passthrough(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ agent := al.registry.GetDefaultAgent()
+ if agent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ history := []providers.Message{
+ {Role: "user", Content: "hello"},
+ {Role: "assistant", Content: "hi there"},
+ }
+ agent.Sessions.SetHistory("test-session", history)
+
+ resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
+ SessionKey: "test-session",
+ Budget: 8000,
+ MaxTokens: 4096,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(resp.History) != len(history) {
+ t.Fatalf("expected %d messages, got %d", len(history), len(resp.History))
+ }
+ for i, msg := range resp.History {
+ if msg.Content != history[i].Content || msg.Role != history[i].Role {
+ t.Fatalf("message %d mismatch: want %+v, got %+v", i, history[i], msg)
+ }
+ }
+}
+
+func TestLegacyAssemble_EmptyHistory(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ resp, err := al.contextManager.Assemble(context.Background(), &AssembleRequest{
+ SessionKey: "test-session",
+ Budget: 8000,
+ MaxTokens: 4096,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(resp.History) != 0 {
+ t.Fatalf("expected empty messages, got %d", len(resp.History))
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Legacy Compact overflow tests
+// ---------------------------------------------------------------------------
+
+func TestLegacyCompact_Overflow(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ history := []providers.Message{
+ {Role: "user", Content: "msg 1"},
+ {Role: "assistant", Content: "resp 1"},
+ {Role: "user", Content: "msg 2"},
+ {Role: "assistant", Content: "resp 2"},
+ {Role: "user", Content: "msg 3"},
+ }
+ defaultAgent.Sessions.SetHistory("session-overflow", history)
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-overflow",
+ Reason: ContextCompressReasonRetry,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // After overflow compression, history should be shorter
+ newHistory := defaultAgent.Sessions.GetHistory("session-overflow")
+ if len(newHistory) >= len(history) {
+ t.Fatalf("expected compressed history, got %d messages (was %d)", len(newHistory), len(history))
+ }
+
+ // Summary should contain compression note
+ summary := defaultAgent.Sessions.GetSummary("session-overflow")
+ if !strings.Contains(summary, "Emergency compression") {
+ t.Fatalf("expected compression note in summary, got %q", summary)
+ }
+
+ // Event should carry the proactive reason
+ events := collectEventStream(sub.C)
+ compressEvt, ok := findEvent(events, EventKindContextCompress)
+ if !ok {
+ t.Fatal("expected context compress event")
+ }
+ payload, ok := compressEvt.Payload.(ContextCompressPayload)
+ if !ok {
+ t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
+ }
+ if payload.Reason != ContextCompressReasonRetry {
+ t.Fatalf("expected retry reason, got %q", payload.Reason)
+ }
+}
+
+func TestLegacyCompact_Overflow_ProactiveReason(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ history := []providers.Message{
+ {Role: "user", Content: "msg 1"},
+ {Role: "assistant", Content: "resp 1"},
+ {Role: "user", Content: "msg 2"},
+ {Role: "assistant", Content: "resp 2"},
+ {Role: "user", Content: "msg 3"},
+ }
+ defaultAgent.Sessions.SetHistory("session-proactive", history)
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-proactive",
+ Reason: ContextCompressReasonProactive,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ events := collectEventStream(sub.C)
+ compressEvt, ok := findEvent(events, EventKindContextCompress)
+ if !ok {
+ t.Fatal("expected context compress event")
+ }
+ payload, ok := compressEvt.Payload.(ContextCompressPayload)
+ if !ok {
+ t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
+ }
+ if payload.Reason != ContextCompressReasonProactive {
+ t.Fatalf("expected proactive reason, got %q", payload.Reason)
+ }
+}
+
+func TestLegacyCompact_Overflow_TooShortToCompress(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ history := []providers.Message{
+ {Role: "user", Content: "only one"},
+ }
+ defaultAgent.Sessions.SetHistory("session-tiny", history)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-tiny",
+ Reason: ContextCompressReasonRetry,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // History should be unchanged (too short to compress)
+ newHistory := defaultAgent.Sessions.GetHistory("session-tiny")
+ if len(newHistory) != len(history) {
+ t.Fatalf("expected history unchanged, got %d messages (was %d)", len(newHistory), len(history))
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Legacy Compact post-turn tests
+// ---------------------------------------------------------------------------
+
+func TestLegacyCompact_PostTurn_BelowThreshold(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ // Small history, below summarization thresholds
+ history := []providers.Message{
+ {Role: "user", Content: "hi"},
+ {Role: "assistant", Content: "hello"},
+ }
+ defaultAgent.Sessions.SetHistory("session-small", history)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-small",
+ Reason: ContextCompressReasonSummarize,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // History should remain unchanged
+ newHistory := defaultAgent.Sessions.GetHistory("session-small")
+ if len(newHistory) != len(history) {
+ t.Fatalf("expected unchanged history, got %d messages (was %d)", len(newHistory), len(history))
+ }
+}
+
+func TestLegacyCompact_PostTurn_ExceedsMessageThreshold(t *testing.T) {
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextWindow: 8000,
+ SummarizeMessageThreshold: 2,
+ SummarizeTokenPercent: 75,
+ },
+ },
+ }
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary"})
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ // 6 messages > threshold of 2
+ history := []providers.Message{
+ {Role: "user", Content: "q1"},
+ {Role: "assistant", Content: "a1"},
+ {Role: "user", Content: "q2"},
+ {Role: "assistant", Content: "a2"},
+ {Role: "user", Content: "q3"},
+ {Role: "assistant", Content: "a3"},
+ }
+ defaultAgent.Sessions.SetHistory("session-threshold", history)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-threshold",
+ Reason: ContextCompressReasonSummarize,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Wait for async summarization to complete via event
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ waitForEvent(t, sub.C, 5*time.Second, func(evt Event) bool {
+ return evt.Kind == EventKindSessionSummarize
+ })
+
+ newHistory := defaultAgent.Sessions.GetHistory("session-threshold")
+ if len(newHistory) >= len(history) {
+ t.Fatalf("expected summarization to reduce history from %d messages, got %d", len(history), len(newHistory))
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Legacy Ingest tests
+// ---------------------------------------------------------------------------
+
+func TestLegacyIngest_NoOp(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ err := al.contextManager.Ingest(context.Background(), &IngestRequest{
+ SessionKey: "session-ingest",
+ Message: providers.Message{Role: "user", Content: "test"},
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Mock ContextManager — verifies dispatch through AgentLoop
+// ---------------------------------------------------------------------------
+
+func TestAgentLoop_UsesCustomContextManager(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ mock := &trackingContextManager{}
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return mock, nil
+ }
+ if err := RegisterContextManager("tracking_cm", factory); err != nil {
+ t.Fatalf("register failed: %v", err)
+ }
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "tracking_cm",
+ },
+ },
+ }
+ al := newCMTestAgentLoop(cfg)
+
+ // Verify the mock was installed
+ if al.contextManager != mock {
+ t.Fatalf("expected mock context manager, got %T", al.contextManager)
+ }
+
+ // Direct method calls
+ _, err := mock.Assemble(context.Background(), &AssembleRequest{
+ SessionKey: "s1",
+ Budget: 8000,
+ MaxTokens: 4096,
+ })
+ if err != nil {
+ t.Fatalf("Assemble error: %v", err)
+ }
+ if mock.assembleCalls.Load() != 1 {
+ t.Fatalf("expected 1 assemble call, got %d", mock.assembleCalls.Load())
+ }
+
+ err = mock.Compact(context.Background(), &CompactRequest{
+ SessionKey: "s1",
+ Reason: ContextCompressReasonRetry,
+ })
+ if err != nil {
+ t.Fatalf("Compact error: %v", err)
+ }
+ if mock.compactCalls.Load() != 1 {
+ t.Fatalf("expected 1 compact call, got %d", mock.compactCalls.Load())
+ }
+
+ err = mock.Ingest(context.Background(), &IngestRequest{
+ SessionKey: "s1",
+ Message: providers.Message{Role: "user", Content: "test"},
+ })
+ if err != nil {
+ t.Fatalf("Ingest error: %v", err)
+ }
+ if mock.ingestCalls.Load() != 1 {
+ t.Fatalf("expected 1 ingest call, got %d", mock.ingestCalls.Load())
+ }
+}
+
+func TestIngestCalledDuringTurn(t *testing.T) {
+ cleanup := resetCMRegistry()
+ defer cleanup()
+
+ mock := &trackingContextManager{}
+ factory := func(cfg json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ return mock, nil
+ }
+ if err := RegisterContextManager("ingest_track_cm", factory); err != nil {
+ t.Fatalf("register failed: %v", err)
+ }
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "ingest_track_cm",
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "done"})
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ // Run a turn — ingestMessage is called for user message and final assistant message
+ _, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
+ SessionKey: "session-ingest-turn",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "test ingest",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+
+ // Should have at least 2 ingest calls: user message + final assistant message
+ if mock.ingestCalls.Load() < 2 {
+ t.Fatalf("expected >= 2 ingest calls during turn, got %d", mock.ingestCalls.Load())
+ }
+}
+
+// ---------------------------------------------------------------------------
+// forceCompression edge cases (via legacy Compact)
+// ---------------------------------------------------------------------------
+
+func TestLegacyCompact_Overflow_SingleTurnKeepsLastUserMessage(t *testing.T) {
+ cfg := testConfig(t)
+ al := newCMTestAgentLoop(cfg)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ // History with only 2 messages — forceCompression should still handle it
+ history := []providers.Message{
+ {Role: "user", Content: "first question"},
+ {Role: "assistant", Content: "first answer"},
+ }
+ defaultAgent.Sessions.SetHistory("session-2msg", history)
+
+ err := al.contextManager.Compact(context.Background(), &CompactRequest{
+ SessionKey: "session-2msg",
+ Reason: ContextCompressReasonRetry,
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ newHistory := defaultAgent.Sessions.GetHistory("session-2msg")
+ // With 2 messages, forceCompression returns false (len <= 2), so no compression
+ if len(newHistory) != len(history) {
+ t.Fatalf("expected no compression for 2-message history, got %d", len(newHistory))
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Test helpers
+// ---------------------------------------------------------------------------
+
+// noopContextManager is a minimal ContextManager that does nothing.
+type noopContextManager struct{}
+
+func (m *noopContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
+ return &AssembleResponse{}, nil
+}
+func (m *noopContextManager) Compact(_ context.Context, _ *CompactRequest) error { return nil }
+func (m *noopContextManager) Ingest(_ context.Context, _ *IngestRequest) error { return nil }
+
+// trackingContextManager tracks call counts for each method.
+type trackingContextManager struct {
+ assembleCalls atomic.Int64
+ compactCalls atomic.Int64
+ ingestCalls atomic.Int64
+ mu sync.Mutex
+ lastAssemble *AssembleRequest
+ lastCompact *CompactRequest
+ lastIngest *IngestRequest
+}
+
+func (m *trackingContextManager) Assemble(_ context.Context, req *AssembleRequest) (*AssembleResponse, error) {
+ m.assembleCalls.Add(1)
+ m.mu.Lock()
+ m.lastAssemble = req
+ m.mu.Unlock()
+ return &AssembleResponse{}, nil
+}
+
+func (m *trackingContextManager) Compact(_ context.Context, req *CompactRequest) error {
+ m.compactCalls.Add(1)
+ m.mu.Lock()
+ m.lastCompact = req
+ m.mu.Unlock()
+ return nil
+}
+
+func (m *trackingContextManager) Ingest(_ context.Context, req *IngestRequest) error {
+ m.ingestCalls.Add(1)
+ m.mu.Lock()
+ m.lastIngest = req
+ m.mu.Unlock()
+ return nil
+}
+
+// resetCMRegistry clears the global factory registry and returns a cleanup
+// function that restores the original state after the test.
+func resetCMRegistry() func() {
+ cmRegistryMu.Lock()
+ original := make(map[string]ContextManagerFactory, len(cmRegistry))
+ for k, v := range cmRegistry {
+ original[k] = v
+ }
+ cmRegistry = make(map[string]ContextManagerFactory)
+ cmRegistryMu.Unlock()
+
+ return func() {
+ cmRegistryMu.Lock()
+ cmRegistry = original
+ cmRegistryMu.Unlock()
+ }
+}
+
+func testConfig(t *testing.T) *config.Config {
+ t.Helper()
+ return &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+}
+
+func newCMTestAgentLoop(cfg *config.Config) *AgentLoop {
+ msgBus := bus.NewMessageBus()
+ return NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "test"})
+}
diff --git a/pkg/agent/context_seahorse.go b/pkg/agent/context_seahorse.go
new file mode 100644
index 000000000..a2e09095a
--- /dev/null
+++ b/pkg/agent/context_seahorse.go
@@ -0,0 +1,269 @@
+//go:build !mipsle && !netbsd
+
+package agent
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
+ "github.com/sipeed/picoclaw/pkg/seahorse"
+ "github.com/sipeed/picoclaw/pkg/session"
+ "github.com/sipeed/picoclaw/pkg/tokenizer"
+)
+
+// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
+type seahorseContextManager struct {
+ engine *seahorse.Engine
+ sessions session.SessionStore // for startup bootstrap
+}
+
+// newSeahorseContextManager creates a seahorse-backed ContextManager.
+func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
+ if al == nil {
+ return nil, fmt.Errorf("seahorse: AgentLoop is required")
+ }
+
+ // Resolve workspace for DB path
+ // DB stores session data, so it goes in sessions/ directory
+ agent := al.registry.GetDefaultAgent()
+ dbPath := agent.Workspace + "/sessions/seahorse.db"
+
+ // Create CompleteFn from provider
+ completeFn := providerToCompleteFn(agent.Provider, agent.Model)
+
+ // Create engine
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: dbPath,
+ }, completeFn)
+ if err != nil {
+ return nil, fmt.Errorf("seahorse: create engine: %w", err)
+ }
+
+ mgr := &seahorseContextManager{
+ engine: engine,
+ sessions: agent.Sessions,
+ }
+
+ // Register seahorse tools with the agent's tool registry
+ retrieval := mgr.engine.GetRetrieval()
+ al.RegisterTool(seahorse.NewGrepTool(retrieval))
+ al.RegisterTool(seahorse.NewExpandTool(retrieval))
+
+ // Bootstrap all existing sessions at startup
+ if agent.Sessions != nil {
+ ctx := context.Background()
+ for _, sessionKey := range agent.Sessions.ListSessions() {
+ mgr.bootstrapSession(ctx, sessionKey)
+ }
+ }
+
+ return mgr, nil
+}
+
+// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
+func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
+ return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
+ resp, err := provider.Chat(
+ ctx,
+ []providers.Message{{Role: "user", Content: prompt}},
+ nil, // no tools for summarization
+ model,
+ map[string]any{
+ "max_tokens": opts.MaxTokens,
+ "temperature": opts.Temperature,
+ "prompt_cache_key": "seahorse",
+ },
+ )
+ if err != nil {
+ return "", err
+ }
+ return resp.Content, nil
+ }
+}
+
+// Assemble builds budget-aware context from seahorse SQLite.
+func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
+ if req == nil {
+ return nil, fmt.Errorf("seahorse assemble: nil request")
+ }
+
+ budget := req.Budget
+ if budget <= 0 {
+ budget = 100000
+ }
+
+ // Reserve space for model response (spec lines 1400-1410)
+ effectiveBudget := budget - req.MaxTokens
+ if effectiveBudget <= 0 {
+ // MaxTokens >= budget is a configuration problem
+ // Use 50% as minimum to avoid guaranteed overflow
+ logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
+ map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
+ effectiveBudget = budget / 2
+ }
+
+ result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
+ Budget: effectiveBudget,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("seahorse assemble: %w", err)
+ }
+
+ history := seahorseToProviderMessages(result)
+
+ // Summary is already formatted as XML with system prompt addition by assembler
+ return &AssembleResponse{
+ History: history,
+ Summary: result.Summary,
+ }, nil
+}
+
+// Compact compresses conversation history via seahorse summarization.
+func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
+ if req == nil {
+ return nil
+ }
+
+ // For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
+ // context shrinks below budget (spec lines ~1410).
+ if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
+ _, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
+ return err
+ }
+
+ _, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
+ Force: req.Reason == ContextCompressReasonRetry,
+ Budget: &req.Budget,
+ })
+ return err
+}
+
+// Ingest records a message into seahorse SQLite.
+// All existing sessions are bootstrapped at startup, so this only ingests new messages.
+func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
+ if req == nil {
+ return nil
+ }
+
+ msg := providerToSeahorseMessage(req.Message)
+ _, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
+ return err
+}
+
+// bootstrapSession reconciles JSONL session history into seahorse SQLite.
+func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
+ if m.sessions == nil {
+ return
+ }
+
+ history := m.sessions.GetHistory(sessionKey)
+ if len(history) == 0 {
+ return
+ }
+
+ // Convert provider messages to seahorse messages
+ msgs := make([]seahorse.Message, len(history))
+ for i, h := range history {
+ msgs[i] = providerToSeahorseMessage(h)
+ }
+
+ if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
+ logger.WarnCF("seahorse", "bootstrap", map[string]any{
+ "session": sessionKey,
+ "error": err.Error(),
+ })
+ }
+}
+
+// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
+func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
+ result := seahorse.Message{
+ Role: msg.Role,
+ Content: msg.Content,
+ ReasoningContent: msg.ReasoningContent,
+ TokenCount: tokenizer.EstimateMessageTokens(msg),
+ }
+
+ // Convert ToolCalls → MessageParts
+ for _, tc := range msg.ToolCalls {
+ part := seahorse.MessagePart{
+ Type: "tool_use",
+ Name: tc.Function.Name,
+ Arguments: tc.Function.Arguments,
+ ToolCallID: tc.ID,
+ }
+ result.Parts = append(result.Parts, part)
+ }
+
+ // Convert tool result
+ if msg.ToolCallID != "" {
+ part := seahorse.MessagePart{
+ Type: "tool_result",
+ ToolCallID: msg.ToolCallID,
+ Text: msg.Content,
+ }
+ result.Parts = append(result.Parts, part)
+ }
+
+ // Convert media attachments
+ for _, mediaURI := range msg.Media {
+ part := seahorse.MessagePart{
+ Type: "media",
+ MediaURI: mediaURI,
+ }
+ result.Parts = append(result.Parts, part)
+ }
+
+ return result
+}
+
+// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
+func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
+ messages := make([]protocoltypes.Message, 0, len(result.Messages))
+
+ // Convert assembled messages (which already include summary XML messages)
+ for _, msg := range result.Messages {
+ pm := protocoltypes.Message{
+ Role: msg.Role,
+ Content: msg.Content,
+ ReasoningContent: msg.ReasoningContent,
+ }
+
+ // Reconstruct ToolCalls from parts
+ for _, part := range msg.Parts {
+ if part.Type == "tool_use" {
+ pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
+ ID: part.ToolCallID,
+ Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
+ Function: &protocoltypes.FunctionCall{
+ Name: part.Name,
+ Arguments: part.Arguments,
+ },
+ })
+ }
+ if part.Type == "tool_result" {
+ pm.ToolCallID = part.ToolCallID
+ if pm.Content == "" && part.Text != "" {
+ pm.Content = part.Text
+ }
+ }
+ if part.Type == "media" && part.MediaURI != "" {
+ pm.Media = append(pm.Media, part.MediaURI)
+ }
+ }
+
+ messages = append(messages, pm)
+ }
+
+ return messages
+}
+
+func init() {
+ if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
+ panic(fmt.Sprintf("register seahorse context manager: %v", err))
+ }
+}
diff --git a/pkg/agent/context_seahorse_test.go b/pkg/agent/context_seahorse_test.go
new file mode 100644
index 000000000..e405ef944
--- /dev/null
+++ b/pkg/agent/context_seahorse_test.go
@@ -0,0 +1,1086 @@
+package agent
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
+ "github.com/sipeed/picoclaw/pkg/seahorse"
+)
+
+// seahorseTestProvider implements providers.LLMProvider for seahorse tests.
+type seahorseTestProvider struct {
+ chatFn func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error)
+}
+
+func (m *seahorseTestProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ options map[string]any,
+) (*providers.LLMResponse, error) {
+ if m.chatFn != nil {
+ return m.chatFn(ctx, messages, tools, model, options)
+ }
+ return &providers.LLMResponse{Content: "mock response"}, nil
+}
+
+func (m *seahorseTestProvider) GetDefaultModel() string {
+ return "mock-model"
+}
+
+func TestSeahorseCMRegistration(t *testing.T) {
+ factory, ok := lookupContextManager("seahorse")
+ if !ok {
+ t.Error("expected 'seahorse' context manager to be registered")
+ }
+ if factory == nil {
+ t.Error("expected non-nil factory")
+ }
+}
+
+func TestProviderToSeahorseMessage(t *testing.T) {
+ tests := []struct {
+ name string
+ input protocoltypes.Message
+ wantRole string
+ wantContent string
+ }{
+ {
+ name: "simple user message",
+ input: protocoltypes.Message{Role: "user", Content: "hello world"},
+ wantRole: "user",
+ wantContent: "hello world",
+ },
+ {
+ name: "assistant message",
+ input: protocoltypes.Message{Role: "assistant", Content: "response text"},
+ wantRole: "assistant",
+ wantContent: "response text",
+ },
+ {
+ name: "tool result message",
+ input: protocoltypes.Message{Role: "tool", Content: "tool output", ToolCallID: "tc_123"},
+ wantRole: "tool",
+ wantContent: "tool output",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := providerToSeahorseMessage(tt.input)
+ if result.Role != tt.wantRole {
+ t.Errorf("Role = %q, want %q", result.Role, tt.wantRole)
+ }
+ if result.Content != tt.wantContent {
+ t.Errorf("Content = %q, want %q", result.Content, tt.wantContent)
+ }
+ })
+ }
+}
+
+func TestProviderToSeahorseMessageWithToolCalls(t *testing.T) {
+ msg := protocoltypes.Message{
+ Role: "assistant",
+ Content: "",
+ ToolCalls: []protocoltypes.ToolCall{
+ {
+ ID: "tc_1",
+ Function: &protocoltypes.FunctionCall{
+ Name: "read_file",
+ Arguments: `{"path":"/tmp/test"}`,
+ },
+ },
+ },
+ }
+
+ result := providerToSeahorseMessage(msg)
+ if result.Role != "assistant" {
+ t.Errorf("Role = %q, want assistant", result.Role)
+ }
+ if len(result.Parts) == 0 {
+ t.Fatal("expected at least 1 part from tool calls")
+ }
+ if result.Parts[0].Type != "tool_use" {
+ t.Errorf("Part type = %q, want tool_use", result.Parts[0].Type)
+ }
+ if result.Parts[0].Name != "read_file" {
+ t.Errorf("Part name = %q, want read_file", result.Parts[0].Name)
+ }
+ if result.Parts[0].ToolCallID != "tc_1" {
+ t.Errorf("Part ToolCallID = %q, want tc_1", result.Parts[0].ToolCallID)
+ }
+}
+
+func TestProviderToSeahorseMessageWithToolResult(t *testing.T) {
+ msg := protocoltypes.Message{
+ Role: "tool",
+ Content: "file contents here",
+ ToolCallID: "tc_456",
+ }
+
+ result := providerToSeahorseMessage(msg)
+ if result.Role != "tool" {
+ t.Errorf("Role = %q, want tool", result.Role)
+ }
+ found := false
+ for _, p := range result.Parts {
+ if p.Type == "tool_result" && p.ToolCallID == "tc_456" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected tool_result part with ToolCallID tc_456")
+ }
+}
+
+func TestProviderToSeahorseMessageWithMedia(t *testing.T) {
+ msg := protocoltypes.Message{
+ Role: "user",
+ Content: "Here is an image",
+ Media: []string{"data:image/png;base64,abc123"},
+ }
+
+ result := providerToSeahorseMessage(msg)
+ if result.Role != "user" {
+ t.Errorf("Role = %q, want user", result.Role)
+ }
+
+ // Should have a media part
+ found := false
+ for _, p := range result.Parts {
+ if p.Type == "media" {
+ found = true
+ if p.MediaURI != "data:image/png;base64,abc123" {
+ t.Errorf("MediaURI = %q, want data:image/png;base64,abc123", p.MediaURI)
+ }
+ break
+ }
+ }
+ if !found {
+ t.Error("expected media part in converted message")
+ }
+}
+
+func TestProviderToSeahorseMessageWithReasoning(t *testing.T) {
+ msg := protocoltypes.Message{
+ Role: "assistant",
+ Content: "response text",
+ ReasoningContent: "I thought about this carefully",
+ }
+
+ result := providerToSeahorseMessage(msg)
+ if result.ReasoningContent != "I thought about this carefully" {
+ t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent)
+ }
+}
+
+func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) {
+ result := &seahorse.AssembleResult{
+ Messages: []seahorse.Message{
+ {
+ Role: "assistant",
+ Content: "response",
+ ReasoningContent: "thinking process",
+ },
+ },
+ }
+
+ messages := seahorseToProviderMessages(result)
+ if len(messages) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(messages))
+ }
+ if messages[0].ReasoningContent != "thinking process" {
+ t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent)
+ }
+}
+
+func TestSeahorseToProviderMessages(t *testing.T) {
+ // Summaries should NOT be double-injected.
+ // The assembler already includes summaries as XML-formatted messages in Messages slice.
+ // seahorseToProviderMessages should only convert Messages, not Summaries.
+ summaryXML := `
+
+ test summary content
+
+`
+ summaryMsg := seahorse.Message{
+ Role: "user",
+ Content: summaryXML,
+ TokenCount: 50,
+ }
+ rawMsg := seahorse.Message{
+ Role: "user",
+ Content: "hello",
+ TokenCount: 5,
+ }
+
+ result := seahorseToProviderMessages(&seahorse.AssembleResult{
+ Messages: []seahorse.Message{summaryMsg, rawMsg},
+ })
+
+ // Should have exactly 2 messages (from Messages slice only)
+ // NOT 3 (which would happen if Summaries were also converted)
+ if len(result) != 2 {
+ t.Fatalf("expected exactly 2 messages (no double injection), got %d", len(result))
+ }
+ // First should be the XML summary message
+ if result[0].Content != summaryXML {
+ t.Errorf("first message content = %q, want summary XML", result[0].Content)
+ }
+ // Second should be the raw message
+ if result[1].Content != "hello" {
+ t.Errorf("second message content = %q, want 'hello'", result[1].Content)
+ }
+}
+
+func TestSeahorseToProviderMessagesWithToolCalls(t *testing.T) {
+ msg := seahorse.Message{
+ Role: "assistant",
+ Content: "",
+ TokenCount: 10,
+ Parts: []seahorse.MessagePart{
+ {
+ Type: "tool_use",
+ Name: "read_file",
+ Arguments: `{"path":"/tmp"}`,
+ ToolCallID: "tc_1",
+ },
+ },
+ }
+
+ result := seahorseToProviderMessages(&seahorse.AssembleResult{
+ Messages: []seahorse.Message{msg},
+ })
+
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+ if result[0].Role != "assistant" {
+ t.Errorf("Role = %q, want assistant", result[0].Role)
+ }
+ if len(result[0].ToolCalls) != 1 {
+ t.Fatalf("ToolCalls = %d, want 1", len(result[0].ToolCalls))
+ }
+ if result[0].ToolCalls[0].Function.Name != "read_file" {
+ t.Errorf("ToolCall name = %q, want read_file", result[0].ToolCalls[0].Function.Name)
+ }
+ // GLM API and other OpenAI-compatible APIs require Type: "function"
+ if result[0].ToolCalls[0].Type != "function" {
+ t.Errorf("ToolCall Type = %q, want 'function' (required by GLM/OpenAI APIs)",
+ result[0].ToolCalls[0].Type)
+ }
+}
+
+func TestSeahorseToProviderMessagesToolResult(t *testing.T) {
+ msg := seahorse.Message{
+ Role: "tool",
+ Content: "file output",
+ TokenCount: 5,
+ Parts: []seahorse.MessagePart{
+ {
+ Type: "tool_result",
+ ToolCallID: "tc_99",
+ Text: "file output",
+ },
+ },
+ }
+
+ result := seahorseToProviderMessages(&seahorse.AssembleResult{
+ Messages: []seahorse.Message{msg},
+ })
+
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+ if result[0].ToolCallID != "tc_99" {
+ t.Errorf("ToolCallID = %q, want tc_99", result[0].ToolCallID)
+ }
+}
+
+// --- providerToCompleteFn tests ---
+
+func TestProviderToCompleteFn(t *testing.T) {
+ var capturedMessages []providers.Message
+ var capturedModel string
+ var capturedOptions map[string]any
+
+ mp := &seahorseTestProvider{
+ chatFn: func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error) {
+ capturedMessages = messages
+ capturedModel = model
+ capturedOptions = options
+ return &providers.LLMResponse{Content: "summary of conversation"}, nil
+ },
+ }
+
+ completeFn := providerToCompleteFn(mp, "test-model-v1")
+ result, err := completeFn(context.Background(), "Summarize this text", seahorse.CompleteOptions{
+ MaxTokens: 500,
+ Temperature: 0.3,
+ })
+ if err != nil {
+ t.Fatalf("completeFn: %v", err)
+ }
+ if result != "summary of conversation" {
+ t.Errorf("result = %q, want 'summary of conversation'", result)
+ }
+
+ // Verify prompt passed as user message
+ if len(capturedMessages) != 1 {
+ t.Fatalf("captured messages = %d, want 1", len(capturedMessages))
+ }
+ if capturedMessages[0].Role != "user" {
+ t.Errorf("message role = %q, want user", capturedMessages[0].Role)
+ }
+ if capturedMessages[0].Content != "Summarize this text" {
+ t.Errorf("message content = %q, want 'Summarize this text'", capturedMessages[0].Content)
+ }
+
+ // Verify model
+ if capturedModel != "test-model-v1" {
+ t.Errorf("model = %q, want 'test-model-v1'", capturedModel)
+ }
+
+ // Verify options
+ if capturedOptions["max_tokens"] != 500 {
+ t.Errorf("max_tokens = %v, want 500", capturedOptions["max_tokens"])
+ }
+ if capturedOptions["temperature"] != 0.3 {
+ t.Errorf("temperature = %v, want 0.3", capturedOptions["temperature"])
+ }
+ if capturedOptions["prompt_cache_key"] != "seahorse" {
+ t.Errorf("prompt_cache_key = %v, want 'seahorse'", capturedOptions["prompt_cache_key"])
+ }
+}
+
+func TestSeahorseIgnoreHeartbeat(t *testing.T) {
+ // Verify that "heartbeat" sessions are ignored by default
+ // This tests the hardcoded ignore pattern from spec lines 1326-1328
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: t.TempDir() + "/test.db",
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer engine.Close()
+
+ ctx := context.Background()
+ result, err := engine.Ingest(ctx, "heartbeat", []seahorse.Message{
+ {Role: "user", Content: "heartbeat msg", TokenCount: 5},
+ })
+ if err != nil {
+ t.Fatalf("Ingest: %v", err)
+ }
+ // Should return nil nil for ignored sessions
+ if result != nil {
+ t.Errorf("expected nil result for heartbeat session, got %+v", result)
+ }
+}
+
+func TestProviderToCompleteFnError(t *testing.T) {
+ mp := &seahorseTestProvider{
+ chatFn: func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error) {
+ return nil, context.Canceled
+ },
+ }
+
+ completeFn := providerToCompleteFn(mp, "test-model")
+ _, err := completeFn(context.Background(), "test prompt", seahorse.CompleteOptions{})
+ if err == nil {
+ t.Error("expected error from canceled context")
+ }
+}
+
+func TestSeahorseAdapterAssembleSubtractsMaxTokens(t *testing.T) {
+ // Create a real seahorse engine with temp DB
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: t.TempDir() + "/test.db",
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer engine.Close()
+
+ ctx := context.Background()
+ mgr := &seahorseContextManager{engine: engine}
+
+ // Ingest lots of large messages (~35 tokens each, 120 total = ~4200 tokens)
+ for i := 0; i < 60; i++ {
+ content := fmt.Sprintf(
+ "This is message number %d. It contains enough text to represent a meaningful conversation turn with the user asking about various topics in software engineering and system design principles that require careful consideration.",
+ i,
+ )
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: "budget-sub",
+ Message: protocoltypes.Message{Role: "user", Content: content},
+ })
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: "budget-sub",
+ Message: protocoltypes.Message{Role: "assistant", Content: "Response"},
+ })
+ }
+
+ // Call adapter Assemble with Budget=5000, MaxTokens=2000
+ // Should use effective budget = 5000 - 2000 = 3000
+ resp, err := mgr.Assemble(ctx, &AssembleRequest{
+ SessionKey: "budget-sub",
+ Budget: 5000,
+ MaxTokens: 2000,
+ })
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+ if resp == nil {
+ t.Fatal("expected non-nil response")
+ }
+
+ // Directly call engine with budget=3000 to get baseline
+ baseline, err := engine.Assemble(ctx, "budget-sub", seahorse.AssembleInput{Budget: 3000})
+ if err != nil {
+ t.Fatalf("engine.Assemble baseline: %v", err)
+ }
+
+ // The adapter result should have same message count as engine with budget 3000
+ if len(resp.History) != len(baseline.Messages) {
+ t.Errorf("adapter Budget=5000 MaxTokens=2000 gave %d messages, engine Budget=3000 gave %d",
+ len(resp.History), len(baseline.Messages))
+ }
+}
+
+func TestSeahorseCompactRetryUsesCompactUntilUnder(t *testing.T) {
+ // Track which engine method was called
+ var compactCalled, compactUntilCalled bool
+
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: t.TempDir() + "/test.db",
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer engine.Close()
+
+ // Wrap engine to track calls
+ _ = compactCalled // track via adapter behavior
+ _ = compactUntilCalled
+
+ mgr := &seahorseContextManager{engine: engine}
+
+ ctx := context.Background()
+
+ // Ingest messages so there's something to compact
+ for i := 0; i < 40; i++ {
+ content := fmt.Sprintf(
+ "message %d with enough text to have meaningful token count that fills up the budget nicely",
+ i,
+ )
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: "compact-test",
+ Message: protocoltypes.Message{Role: "user", Content: content},
+ })
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: "compact-test",
+ Message: protocoltypes.Message{Role: "assistant", Content: "ok"},
+ })
+ }
+
+ // Compact with retry reason and budget should succeed
+ err = mgr.Compact(ctx, &CompactRequest{
+ SessionKey: "compact-test",
+ Reason: ContextCompressReasonRetry,
+ Budget: 5000,
+ })
+ if err != nil {
+ t.Fatalf("Compact retry: %v", err)
+ }
+
+ // Verify context was actually compacted (should have fewer tokens)
+ result, err := engine.Assemble(ctx, "compact-test", seahorse.AssembleInput{Budget: 5000})
+ if err != nil {
+ t.Fatalf("Assemble after compact: %v", err)
+ }
+ if result == nil {
+ t.Fatal("expected non-nil assemble result")
+ }
+ // Compaction attempted — no assertion on exact count since no LLM
+ _ = result.Summary
+}
+
+// TestSeahorseRealLoopNoDuplicateMessages tests the real-world scenario:
+// 1. Start AgentLoop with seahorse context manager
+// 2. Run a turn (user message -> LLM response)
+// 3. Check DB for duplicate messages
+// This test verifies that bootstrapping at startup (not during first Ingest) prevents duplicates.
+func TestSeahorseRealLoopNoDuplicateMessages(t *testing.T) {
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "seahorse",
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ mockProvider := &simpleMockProvider{response: "I received your message."}
+ al := NewAgentLoop(cfg, msgBus, mockProvider)
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ ctx := context.Background()
+ sessionKey := "test-real-loop-dup"
+
+ // Run a turn: user message -> LLM response
+ _, err := al.runAgentLoop(ctx, defaultAgent, processOptions{
+ SessionKey: sessionKey,
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "hello",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+
+ // Get the seahorse engine from context manager
+ seahorseCM, ok := al.contextManager.(*seahorseContextManager)
+ if !ok {
+ t.Fatal("expected seahorseContextManager")
+ }
+
+ // Check DB for messages via RetrievalEngine.Store()
+ store := seahorseCM.engine.GetRetrieval().Store()
+ conv, err := store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+
+ stored, err := store.GetMessages(ctx, conv.ConversationID, 20, 0)
+ if err != nil {
+ t.Fatalf("GetMessages: %v", err)
+ }
+
+ t.Logf("DB has %d messages:", len(stored))
+ for i, msg := range stored {
+ content := msg.Content
+ if len(content) > 40 {
+ content = content[:40] + "..."
+ }
+ t.Logf(" msg[%d]: role=%s content=%q", i, msg.Role, content)
+ }
+
+ // Count duplicates by (role, content)
+ seen := make(map[string]int)
+ for _, msg := range stored {
+ key := msg.Role + ":" + msg.Content
+ seen[key]++
+ }
+ for key, count := range seen {
+ if count > 1 {
+ t.Errorf("DUPLICATE BUG: %q appears %d times in DB", key, count)
+ }
+ }
+
+ // Expected: 2 messages (user "hello" + assistant response)
+ if len(stored) != 2 {
+ t.Errorf("expected 2 messages in DB (user + assistant), got %d", len(stored))
+ }
+}
+
+// TestSeahorseAssembleReturnsAllSummaries verifies that Assemble returns ALL summaries,
+// not just the latest one. This is important because summaries represent compressed
+// conversation history at different points in time.
+func TestSeahorseAssembleReturnsAllSummaries(t *testing.T) {
+ // Create a real seahorse engine with temp DB
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: t.TempDir() + "/test.db",
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer engine.Close()
+
+ ctx := context.Background()
+ mgr := &seahorseContextManager{engine: engine}
+ sessionKey := "test-multi-summary"
+
+ // Get the store to directly create summaries
+ store := engine.GetRetrieval().Store()
+
+ // Get conversation ID
+ conv, err := store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+
+ // Create some messages first
+ for i := 0; i < 20; i++ {
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: sessionKey,
+ Message: protocoltypes.Message{Role: "user", Content: fmt.Sprintf("Message %d", i)},
+ })
+ }
+
+ // Directly create multiple summaries in the database to simulate multi-level compaction
+ testSummaries := []struct {
+ content string
+ kind seahorse.SummaryKind
+ depth int
+ token int
+ }{
+ {"First summary about early conversation discussing topics A and B", seahorse.SummaryKindLeaf, 0, 100},
+ {"Second summary covering middle conversation about topics C and D", seahorse.SummaryKindLeaf, 0, 150},
+ {"Third summary is condensed from first two summaries about topics A-D", seahorse.SummaryKindCondensed, 1, 200},
+ }
+
+ summaryIDs := make([]string, 0, len(testSummaries))
+ for _, s := range testSummaries {
+ input := seahorse.CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: s.kind,
+ Depth: s.depth,
+ Content: s.content,
+ TokenCount: s.token,
+ }
+ summary, createErr := store.CreateSummary(ctx, input)
+ if createErr != nil {
+ t.Fatalf("CreateSummary: %v", createErr)
+ }
+ summaryIDs = append(summaryIDs, summary.SummaryID)
+
+ // Add summary to context_items
+ err = store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("AppendContextSummary: %v", err)
+ }
+ }
+
+ t.Logf("Created %d summaries directly in store", len(summaryIDs))
+
+ // Assemble and check summaries
+ resp, err := mgr.Assemble(ctx, &AssembleRequest{
+ SessionKey: sessionKey,
+ Budget: 50000,
+ MaxTokens: 4096,
+ })
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Check seahorse engine directly for how many summaries exist
+ result, err := engine.Assemble(ctx, sessionKey, seahorse.AssembleInput{Budget: 50000})
+ if err != nil {
+ t.Fatalf("engine.Assemble: %v", err)
+ }
+
+ t.Logf("Seahorse returned Summary with %d chars", len(result.Summary))
+
+ // The Summary field should contain XML summaries with metadata (depth, kind)
+ // The assembler generates this from the Summaries list
+ if len(resp.Summary) > 0 {
+ // Should contain XML tag
+ if !strings.Contains(resp.Summary, " Content-only = %d",
+ resultWithToolCalls.TokenCount, resultContentOnly.TokenCount)
+ }
+
+ // Message with ToolCallID
+ msgWithToolResult := protocoltypes.Message{
+ Role: "tool",
+ Content: "This is a simple response with some text content.",
+ ToolCallID: "tc_456",
+ }
+ resultWithToolResult := providerToSeahorseMessage(msgWithToolResult)
+
+ if resultWithToolResult.TokenCount <= resultContentOnly.TokenCount {
+ t.Errorf("TokenCount with ToolCallID = %d, should be > Content-only = %d",
+ resultWithToolResult.TokenCount, resultContentOnly.TokenCount)
+ }
+
+ // Message with Media
+ msgWithMedia := protocoltypes.Message{
+ Role: "user",
+ Content: "This is a simple response with some text content.",
+ Media: []string{"data:image/png;base64,abc123"},
+ }
+ resultWithMedia := providerToSeahorseMessage(msgWithMedia)
+
+ if resultWithMedia.TokenCount <= resultContentOnly.TokenCount {
+ t.Errorf("TokenCount with Media = %d, should be > Content-only = %d",
+ resultWithMedia.TokenCount, resultContentOnly.TokenCount)
+ }
+}
+
+func TestSeahorseToProviderMessagesRebuildsContentFromParts(t *testing.T) {
+ msg := seahorse.Message{
+ Role: "tool",
+ Content: "",
+ TokenCount: 50,
+ Parts: []seahorse.MessagePart{
+ {
+ Type: "tool_result",
+ ToolCallID: "tc_999",
+ Text: "This is the actual tool output that should be in Content",
+ },
+ },
+ }
+
+ result := seahorseToProviderMessages(&seahorse.AssembleResult{
+ Messages: []seahorse.Message{msg},
+ })
+
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(result))
+ }
+
+ if result[0].Content == "" {
+ t.Error("Content is empty - tool_result text was not rebuilt into Content")
+ }
+ if result[0].Content != "This is the actual tool output that should be in Content" {
+ t.Errorf("Content = %q, want tool output text from Parts", result[0].Content)
+ }
+}
+
+func TestSeahorseAssembleSummaryNotInMessages(t *testing.T) {
+ engine, err := seahorse.NewEngine(seahorse.Config{
+ DBPath: t.TempDir() + "/test.db",
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer engine.Close()
+
+ ctx := context.Background()
+ mgr := &seahorseContextManager{engine: engine}
+ sessionKey := "test-no-dup-summary"
+
+ // Get the store to directly create a summary
+ store := engine.GetRetrieval().Store()
+ conv, err := store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+
+ // Ingest some messages first
+ for i := 0; i < 10; i++ {
+ _ = mgr.Ingest(ctx, &IngestRequest{
+ SessionKey: sessionKey,
+ Message: protocoltypes.Message{Role: "user", Content: fmt.Sprintf("Message %d", i)},
+ })
+ }
+
+ // Create a summary
+ input := seahorse.CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: seahorse.SummaryKindLeaf,
+ Depth: 0,
+ Content: "This is a test summary about the conversation",
+ TokenCount: 50,
+ }
+ summary, err := store.CreateSummary(ctx, input)
+ if err != nil {
+ t.Fatalf("CreateSummary: %v", err)
+ }
+ err = store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("AppendContextSummary: %v", err)
+ }
+
+ // Assemble
+ resp, err := mgr.Assemble(ctx, &AssembleRequest{
+ SessionKey: sessionKey,
+ Budget: 50000,
+ MaxTokens: 4096,
+ })
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Count how many times the summary content appears
+ summaryContent := "This is a test summary"
+ countInHistory := 0
+ for _, msg := range resp.History {
+ if strings.Contains(msg.Content, summaryContent) {
+ countInHistory++
+ }
+ }
+
+ if countInHistory > 0 {
+ t.Errorf("Summary content appears %d times in History - should be 0", countInHistory)
+ }
+
+ // Summary should appear in Summary field
+ if !strings.Contains(resp.Summary, summaryContent) {
+ t.Error("Summary content should appear in response.Summary field")
+ }
+}
+
+// TestSeahorseSteeringMessageIngested verifies that steering messages are ingested
+// into seahorse SQLite, not just session JSONL.
+func TestSeahorseSteeringMessageIngested(t *testing.T) {
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "seahorse",
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ mockProvider := &simpleMockProvider{response: "I received your message."}
+ al := NewAgentLoop(cfg, msgBus, mockProvider)
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ ctx := context.Background()
+ sessionKey := "test-steering-ingest"
+
+ // First turn: establish conversation
+ _, err := al.runAgentLoop(ctx, defaultAgent, processOptions{
+ SessionKey: sessionKey,
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "hello",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("first runAgentLoop failed: %v", err)
+ }
+
+ // Inject a steering message
+ steerErr := al.InjectSteering(providers.Message{
+ Role: "user",
+ Content: "steering message content",
+ })
+ if steerErr != nil {
+ t.Fatalf("InjectSteering failed: %v", steerErr)
+ }
+
+ // Second turn: should process steering message
+ _, err = al.runAgentLoop(ctx, defaultAgent, processOptions{
+ SessionKey: sessionKey,
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "continue",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("second runAgentLoop failed: %v", err)
+ }
+
+ // Get the seahorse engine from context manager
+ seahorseCM, ok := al.contextManager.(*seahorseContextManager)
+ if !ok {
+ t.Fatal("expected seahorseContextManager")
+ }
+
+ // Check DB for steering message
+ store := seahorseCM.engine.GetRetrieval().Store()
+ conv, err := store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+
+ stored, err := store.GetMessages(ctx, conv.ConversationID, 20, 0)
+ if err != nil {
+ t.Fatalf("GetMessages: %v", err)
+ }
+
+ t.Logf("DB has %d messages:", len(stored))
+ for i, msg := range stored {
+ content := msg.Content
+ if len(content) > 40 {
+ content = content[:40] + "..."
+ }
+ t.Logf(" msg[%d]: role=%s content=%q", i, msg.Role, content)
+ }
+
+ // Find steering message in stored messages
+ foundSteering := false
+ for _, msg := range stored {
+ if msg.Content == "steering message content" {
+ foundSteering = true
+ break
+ }
+ }
+
+ if !foundSteering {
+ t.Error("STEERING MESSAGE NOT IN SEAHORSE DB: steering message should be ingested into SQLite")
+ }
+}
+
+// TestSeahorseSummarizeSkipsCondensedWhenBelowThreshold verifies that when
+// Summarize is triggered but tokens are below ContextWindow threshold,
+// condensed compaction should NOT run.
+func TestSeahorseSummarizeSkipsCondensedWhenBelowThreshold(t *testing.T) {
+ contextWindow := 1000
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ ModelName: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextManager: "seahorse",
+ ContextWindow: contextWindow,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &seahorseTestProvider{}
+ al := NewAgentLoop(cfg, msgBus, provider)
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ ctx := context.Background()
+ sessionKey := "test-summarize-skip-condensed"
+
+ seahorseCM, ok := al.contextManager.(*seahorseContextManager)
+ if !ok {
+ t.Fatal("expected seahorseContextManager")
+ }
+ store := seahorseCM.engine.GetRetrieval().Store()
+
+ conv, err := store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+
+ // Insert leaf summaries directly (bypass leaf compaction requirement)
+ for i := 0; i < seahorse.CondensedMinFanout; i++ {
+ now := time.Now().UTC()
+ summary, sumErr := store.CreateSummary(ctx, seahorse.CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: seahorse.SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf summary %d", i),
+ TokenCount: 50,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ if sumErr != nil {
+ t.Fatalf("CreateSummary %d: %v", i, sumErr)
+ }
+ if appendErr := store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID); appendErr != nil {
+ t.Fatalf("AppendContextSummary %d: %v", i, appendErr)
+ }
+ }
+
+ // Add fresh messages (required for condensation candidates)
+ for i := 0; i < seahorse.FreshTailCount+1; i++ {
+ m, msgErr := store.AddMessage(ctx, conv.ConversationID, "user", "fresh", 5)
+ if msgErr != nil {
+ t.Fatalf("AddMessage %d: %v", i, msgErr)
+ }
+ if appendErr := store.AppendContextMessage(ctx, conv.ConversationID, m.ID); appendErr != nil {
+ t.Fatalf("AppendContextMessage %d: %v", i, appendErr)
+ }
+ }
+
+ tokensBefore, err := store.GetContextTokenCount(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetContextTokenCount: %v", err)
+ }
+ threshold := int(float64(contextWindow) * seahorse.ContextThreshold)
+ t.Logf("Tokens before: %d, threshold: %d", tokensBefore, threshold)
+
+ // Trigger Summarize
+ _, err = al.runAgentLoop(ctx, defaultAgent, processOptions{
+ SessionKey: sessionKey,
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "trigger",
+ DefaultResponse: defaultResponse,
+ EnableSummary: true,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop: %v", err)
+ }
+
+ time.Sleep(500 * time.Millisecond)
+
+ summaries, err := store.GetSummariesByConversation(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetSummariesByConversation: %v", err)
+ }
+
+ condensedCount := 0
+ for _, sum := range summaries {
+ if sum.Kind == seahorse.SummaryKindCondensed {
+ condensedCount++
+ }
+ }
+
+ t.Logf("Condensed summaries: %d", condensedCount)
+
+ if tokensBefore < threshold && condensedCount > 0 {
+ t.Errorf("BUG: condensed created when tokens (%d) < threshold (%d)", tokensBefore, threshold)
+ }
+}
diff --git a/pkg/agent/context_seahorse_unsupported.go b/pkg/agent/context_seahorse_unsupported.go
new file mode 100644
index 000000000..882a973b9
--- /dev/null
+++ b/pkg/agent/context_seahorse_unsupported.go
@@ -0,0 +1,20 @@
+//go:build mipsle || netbsd
+
+package agent
+
+import (
+ "encoding/json"
+ "fmt"
+)
+
+// newSeahorseContextManager is unavailable on platforms where modernc sqlite/libc
+// currently has no stable build path for this project.
+func newSeahorseContextManager(_ json.RawMessage, _ *AgentLoop) (ContextManager, error) {
+ return nil, fmt.Errorf("seahorse context manager is unavailable on this platform")
+}
+
+func init() {
+ if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
+ panic(fmt.Sprintf("register seahorse context manager: %v", err))
+ }
+}
diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go
index 66046f87b..31b996260 100644
--- a/pkg/agent/eventbus_test.go
+++ b/pkg/agent/eventbus_test.go
@@ -511,8 +511,8 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
- turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1", nil)
- al.summarizeSession(defaultAgent, "session-1", turnScope)
+ lcm := &legacyContextManager{al: al}
+ lcm.summarizeSession(defaultAgent, "session-1")
events := collectEventStream(sub.C)
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
diff --git a/pkg/agent/events.go b/pkg/agent/events.go
index 6741d0053..f68d3eab5 100644
--- a/pkg/agent/events.go
+++ b/pkg/agent/events.go
@@ -167,6 +167,8 @@ const (
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
+ // ContextCompressReasonSummarize indicates post-turn async summarization.
+ ContextCompressReasonSummarize ContextCompressReason = "summarize"
)
// ContextCompressPayload describes a forced history compression.
diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go
index 880725660..48e5aa625 100644
--- a/pkg/agent/instance.go
+++ b/pkg/agent/instance.go
@@ -51,6 +51,10 @@ type AgentInstance struct {
// LightProvider is the concrete provider instance for the configured light model.
// It is only used when routing selects the light tier for a turn.
LightProvider providers.LLMProvider
+ // CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
+ // instances. This allows each fallback model to use its own api_base and api_key
+ // from model_list, instead of inheriting the primary model's provider config.
+ CandidateProviders map[string]providers.LLMProvider
}
// NewAgentInstance creates an agent instance from config.
@@ -77,7 +81,12 @@ func NewAgentInstance(
if cfg.Tools.IsToolEnabled("read_file") {
maxReadFileSize := cfg.Tools.ReadFile.MaxReadFileSize
- toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
+ switch cfg.Tools.ReadFile.EffectiveMode() {
+ case config.ReadFileModeLines:
+ toolsRegistry.Register(tools.NewReadFileLinesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
+ default:
+ toolsRegistry.Register(tools.NewReadFileBytesTool(workspace, readRestrict, maxReadFileSize, allowReadPaths))
+ }
}
if cfg.Tools.IsToolEnabled("write_file") {
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths))
@@ -170,6 +179,9 @@ func NewAgentInstance(
// Resolve fallback candidates
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
+ candidateProviders := make(map[string]providers.LLMProvider)
+ populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
+
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
@@ -194,6 +206,7 @@ func NewAgentInstance(
})
lightCandidates = resolved
lightProvider = lp
+ populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
}
}
} else {
@@ -225,6 +238,43 @@ func NewAgentInstance(
Router: router,
LightCandidates: lightCandidates,
LightProvider: lightProvider,
+ CandidateProviders: candidateProviders,
+ }
+}
+
+// populateCandidateProvidersFromNames resolves each model name (alias or
+// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
+// for it. This reuses the canonical config resolution path (GetModelConfig) so
+// alias handling and load-balancing stay consistent with the rest of the codebase.
+func populateCandidateProvidersFromNames(
+ cfg *config.Config,
+ workspace string,
+ names []string,
+ out map[string]providers.LLMProvider,
+) {
+ if cfg == nil || len(names) == 0 {
+ return
+ }
+ for _, name := range names {
+ mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
+ if err != nil {
+ logger.WarnCF("agent",
+ "fallback provider: no model_list entry found; will inherit primary provider credentials",
+ map[string]any{"name": name, "error": err.Error()})
+ continue
+ }
+ protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
+ key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
+ if _, exists := out[key]; exists {
+ continue
+ }
+ p, _, err := providers.CreateProviderFromConfig(mc)
+ if err != nil {
+ logger.WarnCF("agent", "fallback provider: failed to create provider",
+ map[string]any{"model": mc.Model, "error": err.Error()})
+ continue
+ }
+ out[key] = p
}
}
diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go
index e296a18cb..8c71296ed 100644
--- a/pkg/agent/instance_test.go
+++ b/pkg/agent/instance_test.go
@@ -9,6 +9,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/providers"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -165,6 +166,58 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
}
}
+func TestNewAgentInstance_PreservesDistinctLimiterIdentityForSharedResolvedModel(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ ModelName: "glm-4.7",
+ ModelFallbacks: []string{"glm-4.7__key_1"},
+ },
+ },
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "glm-4.7",
+ Model: "zhipu/glm-4.7",
+ RPM: 1,
+ },
+ {
+ ModelName: "glm-4.7__key_1",
+ Model: "zhipu/glm-4.7",
+ RPM: 3,
+ },
+ },
+ }
+
+ agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
+ if len(agent.Candidates) != 2 {
+ t.Fatalf("len(Candidates) = %d, want 2", len(agent.Candidates))
+ }
+
+ first := agent.Candidates[0]
+ second := agent.Candidates[1]
+ if first.Provider != "zhipu" || first.Model != "glm-4.7" {
+ t.Fatalf("first candidate = %s/%s, want zhipu/glm-4.7", first.Provider, first.Model)
+ }
+ if second.Provider != "zhipu" || second.Model != "glm-4.7" {
+ t.Fatalf("second candidate = %s/%s, want zhipu/glm-4.7", second.Provider, second.Model)
+ }
+ if first.IdentityKey != "model_name:glm-4.7" {
+ t.Fatalf("first identity key = %q, want %q", first.IdentityKey, "model_name:glm-4.7")
+ }
+ if second.IdentityKey != "model_name:glm-4.7__key_1" {
+ t.Fatalf("second identity key = %q, want %q", second.IdentityKey, "model_name:glm-4.7__key_1")
+ }
+ if first.RPM != 1 {
+ t.Fatalf("first RPM = %d, want 1", first.RPM)
+ }
+ if second.RPM != 3 {
+ t.Fatalf("second RPM = %d, want 3", second.RPM)
+ }
+}
+
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
@@ -248,6 +301,240 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
}
}
+// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
+// config does not panic and leaves the output map empty.
+func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
+ out := map[string]providers.LLMProvider{}
+ populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
+ if len(out) != 0 {
+ t.Fatalf("expected empty map, got %d entries", len(out))
+ }
+}
+
+// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
+// present in the output map is not overwritten.
+func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
+ existing := &mockProvider{}
+ key := providers.ModelKey("openai", "gpt-4o")
+ out := map[string]providers.LLMProvider{key: existing}
+
+ cfg := &config.Config{
+ ModelList: []*config.ModelConfig{
+ {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
+ },
+ }
+ populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
+
+ if out[key] != existing {
+ t.Fatal("existing provider entry was overwritten; expected it to be preserved")
+ }
+}
+
+// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
+// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
+// is created using the underlying model's config.
+func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
+ workspace := t.TempDir()
+ out := map[string]providers.LLMProvider{}
+
+ cfg := &config.Config{
+ ModelList: []*config.ModelConfig{
+ {ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
+ },
+ }
+ populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
+
+ key := providers.ModelKey("openai", "gpt-4o")
+ if out[key] == nil {
+ t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
+ }
+}
+
+// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
+// model_list entry using full "provider/model" notation (e.g.
+// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
+func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
+ workspace := t.TempDir()
+ out := map[string]providers.LLMProvider{}
+
+ cfg := &config.Config{
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "gemma",
+ Model: "gemini/gemma-3-27b-it",
+ APIKeys: config.SimpleSecureStrings("gemini-test-key"),
+ Workspace: workspace,
+ },
+ },
+ }
+ populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
+
+ key := providers.ModelKey("gemini", "gemma-3-27b-it")
+ if out[key] == nil {
+ t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
+ }
+}
+
+// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
+// path when the names slice is empty.
+func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
+ out := map[string]providers.LLMProvider{}
+ cfg := &config.Config{
+ ModelList: []*config.ModelConfig{
+ {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
+ },
+ }
+ populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
+ if len(out) != 0 {
+ t.Fatalf("expected empty map, got %d entries", len(out))
+ }
+}
+
+// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
+// path when model_list is empty — no provider can be created.
+func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
+ out := map[string]providers.LLMProvider{}
+ cfg := &config.Config{}
+ populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
+ if len(out) != 0 {
+ t.Fatalf("expected empty map, got %d entries", len(out))
+ }
+}
+
+// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
+// name with no matching model_list entry is skipped and does not
+// cause a panic or leave a nil entry in the map.
+func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
+ out := map[string]providers.LLMProvider{}
+ cfg := &config.Config{
+ ModelList: []*config.ModelConfig{
+ {ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
+ },
+ }
+ populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
+
+ if len(out) != 0 {
+ t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
+ }
+}
+
+// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
+// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
+// Gemini fallbacks. Each entry must get its own provider instance so that
+// fallback requests go to the correct API endpoint, not the primary's.
+func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
+ workspace := t.TempDir()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: workspace,
+ ModelName: "mistral-small-3.1",
+ ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
+ },
+ },
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "mistral-small-3.1",
+ Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
+ APIBase: "https://openrouter.ai/api/v1",
+ APIKeys: config.SimpleSecureStrings("sk-or-test"),
+ Workspace: workspace,
+ },
+ {
+ ModelName: "gemma-3-27b",
+ Model: "gemini/gemma-3-27b-it",
+ APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
+ Workspace: workspace,
+ },
+ {
+ ModelName: "gemini-images",
+ Model: "gemini/gemini-2.5-flash-lite",
+ APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
+ Workspace: workspace,
+ },
+ },
+ }
+
+ primaryProvider := &mockProvider{}
+ agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
+
+ // Only fallback models need entries — the primary uses the injected provider directly.
+ wantKeys := []string{
+ providers.ModelKey("gemini", "gemma-3-27b-it"),
+ providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
+ }
+
+ for _, key := range wantKeys {
+ p, ok := agent.CandidateProviders[key]
+ if !ok {
+ t.Errorf("CandidateProviders missing key %q", key)
+ continue
+ }
+ if p == nil {
+ t.Errorf("CandidateProviders[%q] is nil", key)
+ }
+ // Each fallback must use its own provider, not the injected primary.
+ if p == primaryProvider {
+ t.Errorf(
+ "CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
+ key,
+ )
+ }
+ }
+
+ if t.Failed() {
+ t.Logf("CandidateProviders keys present: %v", func() []string {
+ keys := make([]string, 0, len(agent.CandidateProviders))
+ for k := range agent.CandidateProviders {
+ keys = append(keys, k)
+ }
+ return keys
+ }())
+ }
+}
+
+func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
+ workspace := t.TempDir()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: workspace,
+ ModelName: "test-model",
+ },
+ },
+ Tools: config.ToolsConfig{
+ ReadFile: config.ReadFileToolConfig{
+ Enabled: true,
+ Mode: config.ReadFileModeLines,
+ MaxReadFileSize: 4096,
+ },
+ },
+ }
+
+ agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
+ readTool, ok := agent.Tools.Get("read_file")
+ if !ok {
+ t.Fatal("read_file tool not registered")
+ }
+
+ params := readTool.Parameters()
+ props, _ := params["properties"].(map[string]any)
+ if _, ok := props["start_line"]; !ok {
+ t.Fatalf("expected line-mode schema to expose start_line, got %#v", props)
+ }
+ if _, ok := props["max_lines"]; !ok {
+ t.Fatalf("expected line-mode schema to expose max_lines, got %#v", props)
+ }
+ if _, ok := props["offset"]; ok {
+ t.Fatalf("did not expect line-mode schema to expose offset, got %#v", props)
+ }
+ if _, ok := props["length"]; ok {
+ t.Fatalf("did not expect line-mode schema to expose length, got %#v", props)
+ }
+}
+
func TestNewAgentInstance_InvalidExecConfigDoesNotExit(t *testing.T) {
workspace := t.TempDir()
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index a7dcb0b9f..4b75f6e1b 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -18,6 +18,8 @@ import (
"sync/atomic"
"time"
+ "github.com/sipeed/picoclaw/pkg/audio/asr"
+ "github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/commands"
@@ -32,7 +34,6 @@ import (
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
- "github.com/sipeed/picoclaw/pkg/voice"
)
type AgentLoop struct {
@@ -48,11 +49,11 @@ type AgentLoop struct {
// Runtime state
running atomic.Bool
- summarizing sync.Map
+ contextManager ContextManager
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
- transcriber voice.Transcriber
+ transcriber asr.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
hookRuntime hookRuntime
@@ -116,9 +117,18 @@ func NewAgentLoop(
) *AgentLoop {
registry := NewAgentRegistry(cfg, provider)
- // Set up shared fallback chain
+ // Set up shared fallback chain with rate limiting.
cooldown := providers.NewCooldownTracker()
- fallbackChain := providers.NewFallbackChain(cooldown)
+ rl := providers.NewRateLimiterRegistry()
+ // Register rate limiters for all agents' candidates so that RPM limits
+ // configured in ModelConfig are enforced before each LLM call.
+ for _, agentID := range registry.ListAgentIDs() {
+ if agent, ok := registry.GetAgent(agentID); ok {
+ rl.RegisterCandidates(agent.Candidates)
+ rl.RegisterCandidates(agent.LightCandidates)
+ }
+ }
+ fallbackChain := providers.NewFallbackChain(cooldown, rl)
// Create state manager using default agent's workspace for channel recording
defaultAgent := registry.GetDefaultAgent()
@@ -134,13 +144,13 @@ func NewAgentLoop(
registry: registry,
state: stateManager,
eventBus: eventBus,
- summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.hooks = NewHookManager(eventBus)
configureHookManagerFromConfig(al.hooks, cfg)
+ al.contextManager = al.resolveContextManager()
// Register shared tools to all agents (now that al is created)
registerSharedTools(al, cfg, msgBus, registry, provider)
@@ -157,6 +167,13 @@ func registerSharedTools(
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
+ var ttsProvider tts.TTSProvider
+ if cfg.Tools.IsToolEnabled("send_tts") {
+ ttsProvider = tts.DetectTTS(cfg)
+ if ttsProvider == nil {
+ logger.WarnCF("voice-tts", "send_tts enabled but no TTS provider configured", nil)
+ }
+ }
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
@@ -267,6 +284,21 @@ func registerSharedTools(
agent.Tools.Register(sendFileTool)
}
+ if ttsProvider != nil {
+ agent.Tools.Register(tools.NewSendTTSTool(ttsProvider, nil))
+ }
+
+ if cfg.Tools.IsToolEnabled("load_image") {
+ loadImageTool := tools.NewLoadImageTool(
+ agent.Workspace,
+ cfg.Agents.Defaults.RestrictToWorkspace,
+ cfg.Agents.Defaults.GetMaxMediaSize(),
+ nil,
+ allowReadPaths,
+ )
+ agent.Tools.Register(loadImageTool)
+ }
+
// Skill discovery and installation tools
skills_enabled := cfg.Tools.IsToolEnabled("skills")
find_skills_enable := cfg.Tools.IsToolEnabled("find_skills")
@@ -309,6 +341,14 @@ func registerSharedTools(
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
+ // Inject a media resolver so the legacy RunToolLoop fallback path can
+ // resolve media:// refs in the same way the main AgentLoop does.
+ // This keeps subagent vision support working even when the optimized
+ // sub-turn spawner path is unavailable.
+ subagentManager.SetMediaResolver(func(msgs []providers.Message) []providers.Message {
+ return resolveMediaRefs(msgs, al.mediaStore, cfg.Agents.Defaults.GetMaxMediaSize())
+ })
+
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
@@ -1075,6 +1115,7 @@ func (al *AgentLoop) ReloadProviderAndConfig(
go func() {
defer func() {
if r := recover(); r != nil {
+ logger.RecoverPanicNoExit(r)
panicErr = fmt.Errorf("panic during registry creation: %v", r)
logger.ErrorCF("agent", "Panic during registry creation",
map[string]any{"panic": r})
@@ -1115,8 +1156,15 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.cfg = cfg
al.registry = registry
- // Also update fallback chain with new config
- al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
+ // Also update fallback chain with new config; rebuild rate limiter registry.
+ newRL := providers.NewRateLimiterRegistry()
+ for _, agentID := range registry.ListAgentIDs() {
+ if agent, ok := registry.GetAgent(agentID); ok {
+ newRL.RegisterCandidates(agent.Candidates)
+ newRL.RegisterCandidates(agent.LightCandidates)
+ }
+ }
+ al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), newRL)
al.mu.Unlock()
@@ -1174,10 +1222,15 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
agent.Tools.SetMediaStore(s)
}
}
+ registry.ForEachTool("send_tts", func(t tools.Tool) {
+ if st, ok := t.(*tools.SendTTSTool); ok {
+ st.SetMediaStore(s)
+ }
+ })
}
// SetTranscriber injects a voice transcriber for agent-level audio transcription.
-func (al *AgentLoop) SetTranscriber(t voice.Transcriber) {
+func (al *AgentLoop) SetTranscriber(t asr.Transcriber) {
al.transcriber = t
}
@@ -1198,19 +1251,23 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
// Transcribe each audio media ref in order.
var transcriptions []string
+ var keptMedia []string
for _, ref := range msg.Media {
path, meta, err := al.mediaStore.ResolveWithMeta(ref)
if err != nil {
logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err})
+ keptMedia = append(keptMedia, ref)
continue
}
if !utils.IsAudioFile(meta.Filename, meta.ContentType) {
+ keptMedia = append(keptMedia, ref)
continue
}
result, err := al.transcriber.Transcribe(ctx, path)
if err != nil {
logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err})
transcriptions = append(transcriptions, "")
+ keptMedia = append(keptMedia, ref)
continue
}
transcriptions = append(transcriptions, result.Text)
@@ -1230,15 +1287,21 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou
}
text := transcriptions[idx]
idx++
+ if text == "" {
+ return match
+ }
return "[voice: " + text + "]"
})
// Append any remaining transcriptions not matched by an annotation.
for ; idx < len(transcriptions); idx++ {
- newContent += "\n[voice: " + transcriptions[idx] + "]"
+ if transcriptions[idx] != "" {
+ newContent += "\n[voice: " + transcriptions[idx] + "]"
+ }
}
msg.Content = newContent
+ msg.Media = keptMedia
return msg, true
}
@@ -1825,8 +1888,15 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
var history []providers.Message
var summary string
if !ts.opts.NoHistory {
- history = ts.agent.Sessions.GetHistory(ts.sessionKey)
- summary = ts.agent.Sessions.GetSummary(ts.sessionKey)
+ // ContextManager assembles budget-aware history and summary.
+ if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
+ SessionKey: ts.sessionKey,
+ Budget: ts.agent.ContextWindow,
+ MaxTokens: ts.agent.MaxTokens,
+ }); err == nil && resp != nil {
+ history = resp.History
+ summary = resp.Summary
+ }
}
ts.captureRestorePoint(history, summary)
@@ -1851,22 +1921,28 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
map[string]any{"session_key": ts.sessionKey})
- if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
- al.emitEvent(
- EventKindContextCompress,
- ts.eventMeta("runTurn", "turn.context.compress"),
- ContextCompressPayload{
- Reason: ContextCompressReasonProactive,
- DroppedMessages: compression.DroppedMessages,
- RemainingMessages: compression.RemainingMessages,
- },
- )
- ts.refreshRestorePointFromSession(ts.agent)
+ if err := al.contextManager.Compact(turnCtx, &CompactRequest{
+ SessionKey: ts.sessionKey,
+ Reason: ContextCompressReasonProactive,
+ Budget: ts.agent.ContextWindow,
+ }); err != nil {
+ logger.WarnCF("agent", "Proactive compact failed", map[string]any{
+ "session_key": ts.sessionKey,
+ "error": err.Error(),
+ })
+ }
+ ts.refreshRestorePointFromSession(ts.agent)
+ // Re-assemble from CM after compact.
+ if resp, err := al.contextManager.Assemble(turnCtx, &AssembleRequest{
+ SessionKey: ts.sessionKey,
+ Budget: ts.agent.ContextWindow,
+ MaxTokens: ts.agent.MaxTokens,
+ }); err == nil && resp != nil {
+ history = resp.History
+ summary = resp.Summary
}
- newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
- newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
messages = ts.agent.ContextBuilder.BuildMessages(
- newHistory, newSummary, ts.userMessage,
+ history, summary, ts.userMessage,
ts.media, ts.channel, ts.chatID,
ts.opts.SenderID, ts.opts.SenderDisplayName,
activeSkillNames(ts.agent, ts.opts)...,
@@ -1888,6 +1964,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
}
ts.recordPersistedMessage(rootMsg)
+ ts.ingestMessage(turnCtx, al, rootMsg)
}
activeCandidates, activeModel, usedLight := al.selectCandidates(ts.agent, ts.userMessage, messages)
@@ -1963,6 +2040,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm)
+ ts.ingestMessage(turnCtx, al, pm)
}
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
@@ -2016,6 +2094,14 @@ turnLoop:
providerToolDefs = filtered
}
+ // Resolve media:// refs produced by tool results (e.g. load_image).
+ // Skipped on iteration 1 because inbound user media is already resolved
+ // before entering the loop; only subsequent iterations can contain new
+ // tool-generated media refs that need base64 encoding.
+ if iteration > 1 {
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+ }
+
callMessages := messages
if gracefulTerminal {
callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage())
@@ -2115,7 +2201,11 @@ turnLoop:
providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
- return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
+ candidateProvider := activeProvider
+ if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
+ candidateProvider = cp
+ }
+ return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
@@ -2221,23 +2311,28 @@ turnLoop:
))
}
- if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
- al.emitEvent(
- EventKindContextCompress,
- ts.eventMeta("runTurn", "turn.context.compress"),
- ContextCompressPayload{
- Reason: ContextCompressReasonRetry,
- DroppedMessages: compression.DroppedMessages,
- RemainingMessages: compression.RemainingMessages,
- },
- )
- ts.refreshRestorePointFromSession(ts.agent)
+ if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
+ SessionKey: ts.sessionKey,
+ Reason: ContextCompressReasonRetry,
+ Budget: ts.agent.ContextWindow,
+ }); compactErr != nil {
+ logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
+ "session_key": ts.sessionKey,
+ "error": compactErr.Error(),
+ })
+ }
+ ts.refreshRestorePointFromSession(ts.agent)
+ // Re-assemble from CM after compact.
+ if asmResp, asmErr := al.contextManager.Assemble(turnCtx, &AssembleRequest{
+ SessionKey: ts.sessionKey,
+ Budget: ts.agent.ContextWindow,
+ MaxTokens: ts.agent.MaxTokens,
+ }); asmErr == nil && asmResp != nil {
+ history = asmResp.History
+ summary = asmResp.Summary
}
-
- newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
- newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
messages = ts.agent.ContextBuilder.BuildMessages(
- newHistory, newSummary, "",
+ history, summary, "",
nil, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName,
activeSkillNames(ts.agent, ts.opts)...,
)
@@ -2409,6 +2504,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
ts.recordPersistedMessage(assistantMsg)
+ ts.ingestMessage(turnCtx, al, assistantMsg)
}
ts.setPhase(TurnPhaseTools)
@@ -2633,6 +2729,7 @@ turnLoop:
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
+
if len(toolResult.Media) > 0 && toolResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(toolResult.Media))
for _, ref := range toolResult.Media {
@@ -2675,6 +2772,13 @@ turnLoop:
}
if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
+ // For tools like load_image that produce media refs without sending them
+ // to the user channel (ResponseHandled == false), both Media and ArtifactTags
+ // coexist on the result:
+ // - Media: carries media:// refs that resolveMediaRefs will base64-encode
+ // into image_url parts in the next LLM iteration (enabling vision).
+ // - ArtifactTags: exposes the local file path as a structured [file:…] tag
+ // in the tool result text, so the LLM knows an artifact was produced.
toolResult.ArtifactTags = buildArtifactTags(al.mediaStore, toolResult.Media)
}
@@ -2693,7 +2797,6 @@ turnLoop:
"content_len": len(toolResult.ForUser),
})
}
-
contentForLLM := toolResult.ContentForLLM()
// Filter sensitive data (API keys, tokens, secrets) before sending to LLM
@@ -2706,6 +2809,9 @@ turnLoop:
Content: contentForLLM,
ToolCallID: toolCallID,
}
+ if len(toolResult.Media) > 0 && !toolResult.ResponseHandled {
+ toolResultMsg.Media = append(toolResultMsg.Media, toolResult.Media...)
+ }
al.emitEvent(
EventKindToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
@@ -2722,6 +2828,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
ts.recordPersistedMessage(toolResultMsg)
+ ts.ingestMessage(turnCtx, al, toolResultMsg)
}
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
@@ -2821,6 +2928,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddMessage(ts.sessionKey, summaryMsg.Role, summaryMsg.Content)
ts.recordPersistedMessage(summaryMsg)
+ ts.ingestMessage(turnCtx, al, summaryMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
turnStatus = TurnEndStatusError
al.emitEvent(
@@ -2835,7 +2943,7 @@ turnLoop:
}
}
if ts.opts.EnableSummary {
- al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
+ al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
}
ts.setPhase(TurnPhaseCompleted)
@@ -2890,6 +2998,7 @@ turnLoop:
finalMsg := providers.Message{Role: "assistant", Content: finalContent}
ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
ts.recordPersistedMessage(finalMsg)
+ ts.ingestMessage(turnCtx, al, finalMsg)
if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
turnStatus = TurnEndStatusError
al.emitEvent(
@@ -2905,7 +3014,14 @@ turnLoop:
}
if ts.opts.EnableSummary {
- al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
+ al.contextManager.Compact(
+ turnCtx,
+ &CompactRequest{
+ SessionKey: ts.sessionKey,
+ Reason: ContextCompressReasonSummarize,
+ Budget: ts.agent.ContextWindow,
+ },
+ )
}
ts.setPhase(TurnPhaseCompleted)
@@ -2984,103 +3100,28 @@ func (al *AgentLoop) selectCandidates(
return agent.LightCandidates, resolvedCandidateModel(agent.LightCandidates, agent.Router.LightModel()), true
}
-// maybeSummarize triggers summarization if the session history exceeds thresholds.
-func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
- newHistory := agent.Sessions.GetHistory(sessionKey)
- tokenEstimate := al.estimateTokens(newHistory)
- threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
-
- if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold {
- summarizeKey := agent.ID + ":" + sessionKey
- if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
- go func() {
- defer al.summarizing.Delete(summarizeKey)
- logger.Debug("Memory threshold reached. Optimizing conversation history...")
- al.summarizeSession(agent, sessionKey, turnScope)
- }()
- }
+// resolveContextManager selects the ContextManager implementation based on config.
+func (al *AgentLoop) resolveContextManager() ContextManager {
+ name := al.cfg.Agents.Defaults.ContextManager
+ if name == "" || name == "legacy" {
+ return &legacyContextManager{al: al}
}
-}
-
-type compressionResult struct {
- DroppedMessages int
- RemainingMessages int
-}
-
-// forceCompression aggressively reduces context when the limit is hit.
-// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
-// cycle, as defined in #1316), so tool-call sequences are never split.
-//
-// If the history is a single Turn with no safe split point, the function
-// falls back to keeping only the most recent user message. This breaks
-// Turn atomicity as a last resort to avoid a context-exceeded loop.
-//
-// Session history contains only user/assistant/tool messages — the system
-// prompt is built dynamically by BuildMessages and is NOT stored here.
-// The compression note is recorded in the session summary so that
-// BuildMessages can include it in the next system prompt.
-func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
- history := agent.Sessions.GetHistory(sessionKey)
- if len(history) <= 2 {
- return compressionResult{}, false
+ factory, ok := lookupContextManager(name)
+ if !ok {
+ logger.WarnCF("agent", "Unknown context manager, falling back to legacy", map[string]any{
+ "name": name,
+ })
+ return &legacyContextManager{al: al}
}
-
- // Split at a Turn boundary so no tool-call sequence is torn apart.
- // parseTurnBoundaries gives us the start of each Turn; we drop the
- // oldest half of Turns and keep the most recent ones.
- turns := parseTurnBoundaries(history)
- var mid int
- if len(turns) >= 2 {
- mid = turns[len(turns)/2]
- } else {
- // Fewer than 2 Turns — fall back to message-level midpoint
- // aligned to the nearest Turn boundary.
- mid = findSafeBoundary(history, len(history)/2)
+ cm, err := factory(al.cfg.Agents.Defaults.ContextManagerConfig, al)
+ if err != nil {
+ logger.WarnCF("agent", "Failed to create context manager, falling back to legacy", map[string]any{
+ "name": name,
+ "error": err.Error(),
+ })
+ return &legacyContextManager{al: al}
}
- var keptHistory []providers.Message
- if mid <= 0 {
- // No safe Turn boundary — the entire history is a single Turn
- // (e.g. one user message followed by a massive tool response).
- // Keeping everything would leave the agent stuck in a context-
- // exceeded loop, so fall back to keeping only the most recent
- // user message. This breaks Turn atomicity as a last resort.
- for i := len(history) - 1; i >= 0; i-- {
- if history[i].Role == "user" {
- keptHistory = []providers.Message{history[i]}
- break
- }
- }
- } else {
- keptHistory = history[mid:]
- }
-
- droppedCount := len(history) - len(keptHistory)
-
- // Record compression in the session summary so BuildMessages includes it
- // in the system prompt. We do not modify history messages themselves.
- existingSummary := agent.Sessions.GetSummary(sessionKey)
- compressionNote := fmt.Sprintf(
- "[Emergency compression dropped %d oldest messages due to context limit]",
- droppedCount,
- )
- if existingSummary != "" {
- compressionNote = existingSummary + "\n\n" + compressionNote
- }
- agent.Sessions.SetSummary(sessionKey, compressionNote)
-
- agent.Sessions.SetHistory(sessionKey, keptHistory)
- agent.Sessions.Save(sessionKey)
-
- logger.WarnCF("agent", "Forced compression executed", map[string]any{
- "session_key": sessionKey,
- "dropped_msgs": droppedCount,
- "new_count": len(keptHistory),
- })
-
- return compressionResult{
- DroppedMessages: droppedCount,
- RemainingMessages: len(keptHistory),
- }, true
+ return cm
}
// GetStartupInfo returns information about loaded tools and skills for logging.
@@ -3172,247 +3213,13 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
}
// summarizeSession summarizes the conversation history for a session.
-func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
- ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
- defer cancel()
-
- history := agent.Sessions.GetHistory(sessionKey)
- summary := agent.Sessions.GetSummary(sessionKey)
-
- // Keep the most recent Turns for continuity, aligned to a Turn boundary
- // so that no tool-call sequence is split.
- if len(history) <= 4 {
- return
- }
-
- safeCut := findSafeBoundary(history, len(history)-4)
- if safeCut <= 0 {
- return
- }
- keepCount := len(history) - safeCut
- toSummarize := history[:safeCut]
-
- // Oversized Message Guard
- maxMessageTokens := agent.ContextWindow / 2
- validMessages := make([]providers.Message, 0)
- omitted := false
-
- for _, m := range toSummarize {
- if m.Role != "user" && m.Role != "assistant" {
- continue
- }
- msgTokens := len(m.Content) / 2
- if msgTokens > maxMessageTokens {
- omitted = true
- continue
- }
- validMessages = append(validMessages, m)
- }
-
- if len(validMessages) == 0 {
- return
- }
-
- const (
- maxSummarizationMessages = 10
- llmMaxRetries = 3
- llmTemperature = 0.3
- fallbackMaxContentLength = 200
- )
-
- // Multi-Part Summarization
- var finalSummary string
- if len(validMessages) > maxSummarizationMessages {
- mid := len(validMessages) / 2
-
- mid = al.findNearestUserMessage(validMessages, mid)
-
- part1 := validMessages[:mid]
- part2 := validMessages[mid:]
-
- s1, _ := al.summarizeBatch(ctx, agent, part1, "")
- s2, _ := al.summarizeBatch(ctx, agent, part2, "")
-
- mergePrompt := fmt.Sprintf(
- "Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
- s1,
- s2,
- )
-
- resp, err := al.retryLLMCall(ctx, agent, mergePrompt, llmMaxRetries)
- if err == nil && resp.Content != "" {
- finalSummary = resp.Content
- } else {
- finalSummary = s1 + " " + s2
- }
- } else {
- finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary)
- }
-
- if omitted && finalSummary != "" {
- finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
- }
-
- if finalSummary != "" {
- agent.Sessions.SetSummary(sessionKey, finalSummary)
- agent.Sessions.TruncateHistory(sessionKey, keepCount)
- agent.Sessions.Save(sessionKey)
- al.emitEvent(
- EventKindSessionSummarize,
- turnScope.meta(0, "summarizeSession", "turn.session.summarize"),
- SessionSummarizePayload{
- SummarizedMessages: len(validMessages),
- KeptMessages: keepCount,
- SummaryLen: len(finalSummary),
- OmittedOversized: omitted,
- },
- )
- }
-}
-
// findNearestUserMessage finds the nearest user message to the given index.
// It searches backward first, then forward if no user message is found.
-func (al *AgentLoop) findNearestUserMessage(messages []providers.Message, mid int) int {
- originalMid := mid
-
- for mid > 0 && messages[mid].Role != "user" {
- mid--
- }
-
- if messages[mid].Role == "user" {
- return mid
- }
-
- mid = originalMid
- for mid < len(messages) && messages[mid].Role != "user" {
- mid++
- }
-
- if mid < len(messages) {
- return mid
- }
-
- return originalMid
-}
-
// retryLLMCall calls the LLM with retry logic.
-func (al *AgentLoop) retryLLMCall(
- ctx context.Context,
- agent *AgentInstance,
- prompt string,
- maxRetries int,
-) (*providers.LLMResponse, error) {
- const (
- llmTemperature = 0.3
- )
-
- var resp *providers.LLMResponse
- var err error
-
- for attempt := 0; attempt < maxRetries; attempt++ {
- al.activeRequests.Add(1)
- resp, err = func() (*providers.LLMResponse, error) {
- defer al.activeRequests.Done()
- return agent.Provider.Chat(
- ctx,
- []providers.Message{{Role: "user", Content: prompt}},
- nil,
- agent.Model,
- map[string]any{
- "max_tokens": agent.MaxTokens,
- "temperature": llmTemperature,
- "prompt_cache_key": agent.ID,
- },
- )
- }()
-
- if err == nil && resp != nil && resp.Content != "" {
- return resp, nil
- }
- if attempt < maxRetries-1 {
- time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
- }
- }
-
- return resp, err
-}
-
// summarizeBatch summarizes a batch of messages.
-func (al *AgentLoop) summarizeBatch(
- ctx context.Context,
- agent *AgentInstance,
- batch []providers.Message,
- existingSummary string,
-) (string, error) {
- const (
- llmMaxRetries = 3
- llmTemperature = 0.3
- fallbackMinContentLength = 200
- fallbackMaxContentPercent = 10
- )
-
- var sb strings.Builder
- sb.WriteString(
- "Provide a concise summary of this conversation segment, preserving core context and key points.\n",
- )
- if existingSummary != "" {
- sb.WriteString("Existing context: ")
- sb.WriteString(existingSummary)
- sb.WriteString("\n")
- }
- sb.WriteString("\nCONVERSATION:\n")
- for _, m := range batch {
- fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
- }
- prompt := sb.String()
-
- response, err := al.retryLLMCall(ctx, agent, prompt, llmMaxRetries)
- if err == nil && response.Content != "" {
- return strings.TrimSpace(response.Content), nil
- }
-
- var fallback strings.Builder
- fallback.WriteString("Conversation summary: ")
- for i, m := range batch {
- if i > 0 {
- fallback.WriteString(" | ")
- }
- content := strings.TrimSpace(m.Content)
- runes := []rune(content)
- if len(runes) == 0 {
- fallback.WriteString(fmt.Sprintf("%s: ", m.Role))
- continue
- }
-
- keepLength := len(runes) * fallbackMaxContentPercent / 100
- if keepLength < fallbackMinContentLength {
- keepLength = fallbackMinContentLength
- }
-
- if keepLength > len(runes) {
- keepLength = len(runes)
- }
-
- content = string(runes[:keepLength])
- if keepLength < len(runes) {
- content += "..."
- }
- fallback.WriteString(fmt.Sprintf("%s: %s", m.Role, content))
- }
- return fallback.String(), nil
-}
-
// estimateTokens estimates the number of tokens in a message list.
// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
// tool-heavy conversations are not systematically undercounted.
-func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
- total := 0
- for _, m := range messages {
- total += estimateMessageTokens(m)
- }
- return total
-}
-
func (al *AgentLoop) handleCommand(
ctx context.Context,
msg bus.InboundMessage,
@@ -3609,7 +3416,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
return "", fmt.Errorf("failed to initialize model %q: %w", value, err)
}
- nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, modelCfg.Model, agent.Fallbacks)
+ nextCandidates := resolveModelCandidates(cfg, cfg.Agents.Defaults.Provider, value, agent.Fallbacks)
if len(nextCandidates) == 0 {
return "", fmt.Errorf("model %q did not resolve to any provider candidates", value)
}
diff --git a/pkg/agent/loop_mcp.go b/pkg/agent/loop_mcp.go
index 97debbc33..b9c844d1a 100644
--- a/pkg/agent/loop_mcp.go
+++ b/pkg/agent/loop_mcp.go
@@ -126,6 +126,8 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
+ mcpTool.SetWorkspace(agent.Workspace)
+ mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
if registerAsHidden {
agent.Tools.RegisterHidden(mcpTool)
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index b544ffb4f..127ff64b3 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -2132,6 +2132,162 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
}
}
+// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
+// bug #2140. It verifies that when the primary model returns a rate-limit error
+// the fallback closure routes the retry to the fallback model's own provider
+// (its own api_base), not back to the primary provider's endpoint.
+func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
+ workspace := t.TempDir()
+
+ primaryCalls := 0
+ primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ primaryCalls++
+ // Return 429 so FallbackChain classifies this as retriable and moves on.
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusTooManyRequests)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": map[string]any{
+ "message": "rate limit exceeded",
+ "type": "rate_limit_error",
+ },
+ })
+ }))
+ defer primaryServer.Close()
+
+ fallbackCalls := 0
+ fallbackServer := newStrictChatCompletionTestServer(
+ t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
+ )
+ defer fallbackServer.Close()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: workspace,
+ ModelName: "mistral-primary",
+ ModelFallbacks: []string{"gemma-fallback"},
+ MaxTokens: 4096,
+ MaxToolIterations: 3,
+ },
+ },
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "mistral-primary",
+ Model: "openrouter/mistralai/mistral-small-3.1",
+ APIBase: primaryServer.URL,
+ APIKeys: config.SimpleSecureStrings("primary-key"),
+ Workspace: workspace,
+ },
+ {
+ ModelName: "gemma-fallback",
+ Model: "gemini/gemma-3-27b-it",
+ APIBase: fallbackServer.URL,
+ APIKeys: config.SimpleSecureStrings("fallback-key"),
+ Workspace: workspace,
+ },
+ },
+ }
+
+ provider, _, err := providers.CreateProvider(cfg)
+ if err != nil {
+ t.Fatalf("CreateProvider() error = %v", err)
+ }
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ helper := testHelper{al: al}
+
+ resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "hi",
+ })
+
+ if resp != "fallback reply" {
+ t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
+ }
+ if primaryCalls == 0 {
+ t.Fatal("primary server was never called; expected at least one attempt")
+ }
+ if fallbackCalls != 1 {
+ t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
+ }
+}
+
+// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
+// that when a candidate has no model_list entry it is absent from CandidateProviders
+// and the fallback closure falls back to activeProvider instead of panicking.
+func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
+ workspace := t.TempDir()
+
+ // Primary server: returns 429 on first call, succeeds on second.
+ // Both the primary and the unregistered fallback share this server
+ // (same api_base) so activeProvider routes both calls here.
+ callCount := 0
+ primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ callCount++
+ w.Header().Set("Content-Type", "application/json")
+ if callCount == 1 {
+ w.WriteHeader(http.StatusTooManyRequests)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
+ })
+ return
+ }
+ // Second call (fallback via activeProvider) succeeds.
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
+ },
+ })
+ }))
+ defer primaryServer.Close()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: workspace,
+ ModelName: "primary-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 3,
+ // No model_list entry for this alias — absent from CandidateProviders.
+ ModelFallbacks: []string{"openrouter/fallback-model"},
+ },
+ },
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "primary-model",
+ Model: "openrouter/primary-model",
+ APIBase: primaryServer.URL,
+ APIKeys: config.SimpleSecureStrings("primary-key"),
+ Workspace: workspace,
+ },
+ },
+ }
+
+ provider, _, err := providers.CreateProvider(cfg)
+ if err != nil {
+ t.Fatalf("CreateProvider() error = %v", err)
+ }
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ helper := testHelper{al: al}
+ resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "hi",
+ })
+
+ if resp != "active provider reply" {
+ t.Fatalf("response = %q, want %q", resp, "active provider reply")
+ }
+ if callCount < 2 {
+ t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
+ }
+}
+
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
diff --git a/pkg/agent/model_resolution.go b/pkg/agent/model_resolution.go
index 140cff718..7cbf3a8d6 100644
--- a/pkg/agent/model_resolution.go
+++ b/pkg/agent/model_resolution.go
@@ -8,44 +8,102 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
-func buildModelListResolver(cfg *config.Config) func(raw string) (string, bool) {
- ensureProtocol := func(model string) string {
- model = strings.TrimSpace(model)
- if model == "" {
- return ""
- }
- if strings.Contains(model, "/") {
- return model
- }
- return "openai/" + model
+func ensureProtocolModel(model string) string {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return ""
+ }
+ if strings.Contains(model, "/") {
+ return model
+ }
+ return "openai/" + model
+}
+
+func modelConfigIdentityKey(mc *config.ModelConfig) string {
+ if mc == nil {
+ return ""
+ }
+ if name := strings.TrimSpace(mc.ModelName); name != "" {
+ return "model_name:" + name
+ }
+ return ""
+}
+
+func candidateFromModelConfig(
+ defaultProvider string,
+ mc *config.ModelConfig,
+) (providers.FallbackCandidate, bool) {
+ if mc == nil {
+ return providers.FallbackCandidate{}, false
}
- return func(raw string) (string, bool) {
- raw = strings.TrimSpace(raw)
- if raw == "" || cfg == nil {
- return "", false
- }
-
- if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
- return ensureProtocol(mc.Model), true
- }
-
- for i := range cfg.ModelList {
- fullModel := strings.TrimSpace(cfg.ModelList[i].Model)
- if fullModel == "" {
- continue
- }
- if fullModel == raw {
- return ensureProtocol(fullModel), true
- }
- _, modelID := providers.ExtractProtocol(fullModel)
- if modelID == raw {
- return ensureProtocol(fullModel), true
- }
- }
-
- return "", false
+ ref := providers.ParseModelRef(ensureProtocolModel(mc.Model), defaultProvider)
+ if ref == nil {
+ return providers.FallbackCandidate{}, false
}
+
+ return providers.FallbackCandidate{
+ Provider: ref.Provider,
+ Model: ref.Model,
+ RPM: mc.RPM,
+ IdentityKey: modelConfigIdentityKey(mc),
+ }, true
+}
+
+func lookupModelConfigByRef(cfg *config.Config, raw string) *config.ModelConfig {
+ raw = strings.TrimSpace(raw)
+ if raw == "" || cfg == nil {
+ return nil
+ }
+
+ if mc, err := cfg.GetModelConfig(raw); err == nil && mc != nil && strings.TrimSpace(mc.Model) != "" {
+ return mc
+ }
+
+ for i := range cfg.ModelList {
+ mc := cfg.ModelList[i]
+ if mc == nil {
+ continue
+ }
+ fullModel := strings.TrimSpace(mc.Model)
+ if fullModel == "" {
+ continue
+ }
+ if fullModel == raw {
+ return mc
+ }
+ _, modelID := providers.ExtractProtocol(fullModel)
+ if modelID == raw {
+ return mc
+ }
+ }
+
+ return nil
+}
+
+func resolveModelCandidate(
+ cfg *config.Config,
+ defaultProvider string,
+ raw string,
+) (providers.FallbackCandidate, bool) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return providers.FallbackCandidate{}, false
+ }
+
+ if mc := lookupModelConfigByRef(cfg, raw); mc != nil {
+ return candidateFromModelConfig(defaultProvider, mc)
+ }
+
+ ref := providers.ParseModelRef(raw, defaultProvider)
+ if ref == nil {
+ return providers.FallbackCandidate{}, false
+ }
+
+ return providers.FallbackCandidate{
+ Provider: ref.Provider,
+ Model: ref.Model,
+ }, true
}
func resolveModelCandidates(
@@ -54,14 +112,29 @@ func resolveModelCandidates(
primary string,
fallbacks []string,
) []providers.FallbackCandidate {
- return providers.ResolveCandidatesWithLookup(
- providers.ModelConfig{
- Primary: primary,
- Fallbacks: fallbacks,
- },
- defaultProvider,
- buildModelListResolver(cfg),
- )
+ seen := make(map[string]bool)
+ candidates := make([]providers.FallbackCandidate, 0, 1+len(fallbacks))
+
+ addCandidate := func(raw string) {
+ candidate, ok := resolveModelCandidate(cfg, defaultProvider, raw)
+ if !ok {
+ return
+ }
+
+ key := candidate.StableKey()
+ if seen[key] {
+ return
+ }
+ seen[key] = true
+ candidates = append(candidates, candidate)
+ }
+
+ addCandidate(primary)
+ for _, fallback := range fallbacks {
+ addCandidate(fallback)
+ }
+
+ return candidates
}
func resolvedCandidateModel(candidates []providers.FallbackCandidate, fallback string) string {
diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go
index 56439885a..c5eeb3a49 100644
--- a/pkg/agent/subturn.go
+++ b/pkg/agent/subturn.go
@@ -432,6 +432,7 @@ func spawnSubTurn(
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
defer func() {
if r := recover(); r != nil {
+ logger.RecoverPanicNoExit(r)
err = fmt.Errorf("subturn panicked: %v", r)
result = nil
logger.ErrorCF("subturn", "SubTurn panicked", map[string]any{
@@ -515,6 +516,7 @@ func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, re
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
defer func() {
if r := recover(); r != nil {
+ logger.RecoverPanicNoExit(r)
logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{
"parent_id": parentTS.turnID,
"child_id": childID,
@@ -607,6 +609,7 @@ type ephemeralSessionStoreIface interface {
SetHistory(key string, history []providers.Message)
TruncateHistory(key string, keepLast int)
Save(key string) error
+ ListSessions() []string
Close() error
}
@@ -666,8 +669,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
e.history = e.history[len(e.history)-keepLast:]
}
-func (e *ephemeralSessionStore) Save(_ string) error { return nil }
-func (e *ephemeralSessionStore) Close() error { return nil }
+func (e *ephemeralSessionStore) Save(_ string) error { return nil }
+func (e *ephemeralSessionStore) Close() error { return nil }
+func (e *ephemeralSessionStore) ListSessions() []string { return nil }
func (e *ephemeralSessionStore) truncateLocked() {
if len(e.history) > maxEphemeralHistorySize {
diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go
index 41a57d942..b30fa186d 100644
--- a/pkg/agent/turn.go
+++ b/pkg/agent/turn.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -341,6 +342,23 @@ func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
ts.captureRestorePoint(history, summary)
}
+// ingestMessage calls the ContextManager's Ingest method for a persisted message.
+// Errors are logged but never block the turn.
+func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
+ if al.contextManager == nil {
+ return
+ }
+ if err := al.contextManager.Ingest(ctx, &IngestRequest{
+ SessionKey: ts.sessionKey,
+ Message: msg,
+ }); err != nil {
+ logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
+ "session_key": ts.sessionKey,
+ "error": err.Error(),
+ })
+ }
+}
+
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)
diff --git a/pkg/audio/asr/README.md b/pkg/audio/asr/README.md
new file mode 100644
index 000000000..0477276dd
--- /dev/null
+++ b/pkg/audio/asr/README.md
@@ -0,0 +1,166 @@
+# ASR (Automatic Speech Recognition)
+
+This package handles speech-to-text for PicoClaw voice input.
+
+If you are new to ASR setup, the simplest mental model is:
+
+1. Add one or more ASR-capable entries to `model_list`.
+2. Point `voice.model_name` at the one you want to use.
+3. Put the API key in `.security.yml`.
+
+## Quick Recommendation
+
+For most new users, start with one of these:
+
+| Provider | Example model | Why start here |
+| --- | --- | --- |
+| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Fast Whisper-style transcription and a straightforward OpenAI-compatible API. Groq currently advertises a free tier plan for 2000 reqs/day. |
+| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | Easy setup and strong speech-to-text quality. ElevenLabs currently advertises a free plan that includes speech-to-text usage. |
+
+Pricing and free-plan limits can change, so check the linked pricing pages before depending on them in production.
+
+## How ASR Configuration Works
+
+PicoClaw does not keep ASR API keys inside the `voice` section.
+
+Instead:
+
+- `voice.model_name` chooses a named entry from `model_list`.
+- The matching `model_list` entry describes the actual provider and model.
+- `.security.yml` stores the API key for that named model entry.
+
+This is the recommended pattern because it is explicit, reusable, and consistent with the rest of PicoClaw's model configuration.
+
+## Recommended Setup
+
+### Option A: Groq Whisper
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "groq-asr",
+ "echo_transcription": true
+ },
+ "model_list": [
+ {
+ "model_name": "groq-asr",
+ "model": "groq/whisper-large-v3-turbo"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ groq-asr:
+ api_keys:
+ - "gsk_your_groq_key"
+```
+
+Notes:
+
+- You can omit `api_base` and PicoClaw will use Groq's default API base automatically.
+- If you set `api_base` manually for Groq Whisper, both of these forms work:
+ - `https://api.groq.com/openai/v1`
+ - `https://api.groq.com/openai/v1/audio/transcriptions`
+- Any OpenAI-compatible Whisper model name containing `whisper` can use the Whisper transcription path, not only `whisper-large-v3-turbo`.
+
+### Option B: ElevenLabs
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "elevenlabs-asr",
+ "echo_transcription": true
+ },
+ "model_list": [
+ {
+ "model_name": "elevenlabs-asr",
+ "model": "elevenlabs/scribe_v1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ elevenlabs-asr:
+ api_keys:
+ - "sk-elevenlabs-your-key"
+```
+
+### Option C: OpenAI Whisper
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "openai-asr"
+ },
+ "model_list": [
+ {
+ "model_name": "openai-asr",
+ "model": "openai/whisper-1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ openai-asr:
+ api_keys:
+ - "sk-openai-your-key"
+```
+
+## Other ASR-Capable Model Types
+
+PicoClaw currently supports three main ASR routes:
+
+| Route | Example models | Behavior |
+| --- | --- | --- |
+| ElevenLabs ASR | `elevenlabs/scribe_v1` | Uses the ElevenLabs transcription API. |
+| Whisper endpoint models | `openai/whisper-1`, `groq/whisper-large-v3` | Uses an OpenAI-compatible `/audio/transcriptions` endpoint. |
+| Audio-capable chat models **(Under construction)** | `openai/gpt-4o-audio-preview`, `gemini/gemini-2.5-flash` | Sends audio to a multimodal chat model and asks it to transcribe. |
+
+If you are unsure which one to pick, choose Groq Whisper or ElevenLabs first.
+
+## How PicoClaw Chooses a Transcriber
+
+`DetectTranscriber` resolves ASR in this order:
+
+1. **Preferred path**: resolve `voice.model_name` against `model_list`.
+2. If that resolved model is:
+ - `elevenlabs/...`, PicoClaw uses the ElevenLabs transcriber.
+ - an OpenAI-compatible Whisper model, PicoClaw uses the Whisper transcriber.
+ - an audio-capable chat model, PicoClaw uses `AudioModelTranscriber`.
+3. **Fallback path**: if `voice.model_name` is not set, PicoClaw performs a compatibility scan through `model_list` for legacy auto-detected ASR entries.
+
+Fallback scanning exists for backward compatibility. New configurations should set `voice.model_name` explicitly.
+
+## Common Mistakes
+
+- Defining an ASR model in `model_list` but forgetting to set `voice.model_name`.
+- Putting the API key in `voice` instead of `.security.yml`.
+- Using a non-ASR model and expecting Whisper-style transcription behavior.
+- Setting a custom `api_base` that points to the wrong provider endpoint.
+
+## Minimal Checklist
+
+Before testing voice input, make sure:
+
+- `voice.model_name` matches a `model_list[].model_name`.
+- The matching `.security.yml` entry contains a valid API key.
+- The selected model is actually ASR-capable.
+- Voice input is enabled for the channel you are using.
diff --git a/pkg/audio/asr/README_zh.md b/pkg/audio/asr/README_zh.md
new file mode 100644
index 000000000..104116080
--- /dev/null
+++ b/pkg/audio/asr/README_zh.md
@@ -0,0 +1,166 @@
+# ASR(自动语音识别)
+
+这个目录负责 PicoClaw 的语音转文字能力。
+
+如果你是第一次配置 ASR,可以参考如下步骤:
+
+1. 在 `model_list` 里添加一个或多个支持 ASR 的模型条目。
+2. 用 `voice.model_name` 指向你想使用的那个条目。
+3. 在 `.security.yml` 里配置对应的 API Key。
+
+## 快速推荐
+
+对于大多数新用户,建议先从下面两种开始:
+
+| 提供商 | 示例模型 | 推荐理由 |
+| --- | --- | --- |
+| [Groq](https://console.groq.com/keys) | `groq/whisper-large-v3-turbo` | Whisper 风格转录速度快,并且提供 OpenAI 兼容接口,配置比较直接。Groq 目前官方提供2000请求每日的免费套餐。 |
+| [ElevenLabs](https://elevenlabs.io/pricing) | `elevenlabs/scribe_v1` | 上手简单,语音转文字质量也不错。ElevenLabs 目前官方免费套餐包含 STT 用量。 |
+
+价格和免费额度可能会变化,正式使用前请以官网定价页为准。
+
+## ASR 配置是如何工作的
+
+PicoClaw 不会把 ASR 的 API Key 放在 `voice` 配置里。
+
+推荐的方式是:
+
+- `voice.model_name` 用来选择 `model_list` 里的某个命名模型。
+- `model_list` 条目描述真实的提供商和模型。
+- `.security.yml` 负责保存该模型条目的 API Key。
+
+这种方式更明确、更安全,也和 PicoClaw 其他模型配置方式保持一致。
+
+## 推荐配置方式
+
+### 方案 A:Groq Whisper
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "groq-asr",
+ "echo_transcription": true
+ },
+ "model_list": [
+ {
+ "model_name": "groq-asr",
+ "model": "groq/whisper-large-v3-turbo"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ groq-asr:
+ api_keys:
+ - "gsk_your_groq_key"
+```
+
+说明:
+
+- 你可以不写 `api_base`,PicoClaw 会自动使用 Groq 默认接口地址。
+- 如果你手动设置 Groq Whisper 的 `api_base`,下面两种写法都可以:
+ - `https://api.groq.com/openai/v1`
+ - `https://api.groq.com/openai/v1/audio/transcriptions`
+- 只要是 OpenAI 兼容、并且模型名里包含 `whisper` 的模型,都可以走 Whisper 转录路径,不仅限于 `whisper-large-v3-turbo`。
+
+### 方案 B:ElevenLabs
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "elevenlabs-asr",
+ "echo_transcription": true
+ },
+ "model_list": [
+ {
+ "model_name": "elevenlabs-asr",
+ "model": "elevenlabs/scribe_v1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ elevenlabs-asr:
+ api_keys:
+ - "sk-elevenlabs-your-key"
+```
+
+### 方案 C:OpenAI Whisper
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "model_name": "openai-asr"
+ },
+ "model_list": [
+ {
+ "model_name": "openai-asr",
+ "model": "openai/whisper-1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ openai-asr:
+ api_keys:
+ - "sk-openai-your-key"
+```
+
+## 其他支持 ASR 的模型类型
+
+PicoClaw 目前主要支持三种 ASR 路径:
+
+| 路径 | 示例模型 | 行为说明 |
+| --- | --- | --- |
+| ElevenLabs ASR | `elevenlabs/scribe_v1` | 使用 ElevenLabs 的语音转录接口。 |
+| Whisper 接口模型 | `openai/whisper-1`、`groq/whisper-large-v3` | 使用 OpenAI 兼容的 `/audio/transcriptions` 接口。 |
+| 支持音频的聊天模型 **(重构中)** | `openai/gpt-4o-audio-preview`、`gemini/gemini-2.5-flash` | 把音频发给多模态聊天模型,并要求它返回转录结果。 |
+
+如果你不确定该选哪种,建议优先使用 Groq Whisper 或 ElevenLabs。
+
+## PicoClaw 如何选择转录器
+
+`DetectTranscriber` 会按下面顺序选择 ASR:
+
+1. **首选路径**:根据 `voice.model_name` 在 `model_list` 中找到对应模型。
+2. 如果找到的模型属于以下类型:
+ - `elevenlabs/...`,则使用 ElevenLabs transcriber。
+ - OpenAI 兼容的 Whisper 模型,则使用 Whisper transcriber。
+ - 支持音频输入的聊天模型,则使用 `AudioModelTranscriber`。
+3. **回退路径**:如果没有设置 `voice.model_name`,PicoClaw 会为了兼容旧配置,扫描 `model_list` 中可自动识别的 ASR 条目。
+
+回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.model_name`。
+
+## 常见错误
+
+- 在 `model_list` 里定义了 ASR 模型,但忘了设置 `voice.model_name`。
+- 把 API Key 写进了 `voice`,而不是 `.security.yml`。
+- 选择了不支持 ASR 的模型,却期望得到 Whisper 风格的转录结果。
+- 自定义了错误的 `api_base`,导致请求打到错误的接口地址。
+
+## 最小检查清单
+
+在测试语音输入前,请确认:
+
+- `voice.model_name` 能正确匹配某个 `model_list[].model_name`。
+- `.security.yml` 中对应条目已经配置了有效 API Key。
+- 你选择的模型确实支持 ASR。
+- 你当前使用的频道已经启用了语音输入能力。
diff --git a/pkg/audio/asr/agent.go b/pkg/audio/asr/agent.go
new file mode 100644
index 000000000..c483a0778
--- /dev/null
+++ b/pkg/audio/asr/agent.go
@@ -0,0 +1,253 @@
+package asr
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/pion/rtp"
+ "github.com/pion/webrtc/v3/pkg/media/oggwriter"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+type speechAccumulator struct {
+ writer *oggwriter.OggWriter
+ file string
+ lastAudioAt time.Time
+ mu sync.Mutex
+ closed bool
+ chatID string
+ speakerID string
+ sessionID string
+ channel string
+}
+
+func (a *speechAccumulator) Push(chunk bus.AudioChunk) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ if a.closed {
+ return
+ }
+
+ a.lastAudioAt = time.Now()
+
+ pkt := &rtp.Packet{
+ Header: rtp.Header{
+ SequenceNumber: uint16(chunk.Sequence),
+ Timestamp: chunk.Timestamp,
+ SSRC: 1, // Stable arbitrary dummy
+ },
+ Payload: chunk.Data,
+ }
+
+ if err := a.writer.WriteRTP(pkt); err != nil {
+ logger.ErrorCF("voice-agent", "Failed to write RTP", map[string]any{"error": err})
+ }
+}
+
+func (a *speechAccumulator) Close() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if !a.closed {
+ a.writer.Close()
+ a.closed = true
+ }
+}
+
+type Agent struct {
+ bus *bus.MessageBus
+ transcriber Transcriber
+
+ mu sync.Mutex
+ sessions map[string]*speechAccumulator // keyed by sessionID_speakerID
+}
+
+func NewAgent(mb *bus.MessageBus, t Transcriber) *Agent {
+ return &Agent{
+ bus: mb,
+ transcriber: t,
+ sessions: make(map[string]*speechAccumulator),
+ }
+}
+
+func (a *Agent) Start(ctx context.Context) {
+ logger.InfoCF("voice-agent", "Started Voice Agent orchestrator", nil)
+ go a.listenChunks(ctx)
+ go a.vadTick(ctx)
+
+ // Cleanup sessions on shutdown
+ go func() {
+ <-ctx.Done()
+ a.mu.Lock()
+ for key, acc := range a.sessions {
+ acc.Close()
+ os.Remove(acc.file)
+ delete(a.sessions, key)
+ }
+ a.mu.Unlock()
+ logger.InfoCF("voice-agent", "Cleaned up voice sessions on shutdown", nil)
+ }()
+}
+
+func (a *Agent) listenChunks(ctx context.Context) {
+ chunks := a.bus.AudioChunksChan()
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case chunk, ok := <-chunks:
+ if !ok {
+ return
+ }
+ a.handleChunk(chunk)
+ }
+ }
+}
+
+func (a *Agent) handleChunk(chunk bus.AudioChunk) {
+ // Only accept Opus-encoded audio
+ if chunk.Format != "opus" {
+ logger.DebugCF("voice-agent", "Ignoring unsupported audio format", map[string]any{"format": chunk.Format})
+ return
+ }
+
+ key := fmt.Sprintf("%s_%s", chunk.SessionID, chunk.SpeakerID)
+
+ a.mu.Lock()
+ acc, exists := a.sessions[key]
+ if !exists {
+ filename := filepath.Join(os.TempDir(), fmt.Sprintf("voice_%s_%d.ogg", key, time.Now().UnixNano()))
+ writer, err := oggwriter.New(filename, uint32(chunk.SampleRate), uint16(chunk.Channels))
+ if err != nil {
+ a.mu.Unlock()
+ logger.ErrorCF("voice-agent", "Failed to create OggWriter", map[string]any{"error": err})
+ return
+ }
+
+ acc = &speechAccumulator{
+ writer: writer,
+ file: filename,
+ lastAudioAt: time.Now(),
+ chatID: chunk.ChatID,
+ speakerID: chunk.SpeakerID,
+ sessionID: chunk.SessionID,
+ channel: chunk.Channel,
+ }
+ a.sessions[key] = acc
+ logger.DebugCF("voice-agent", "Started accumulating voice", map[string]any{"key": key, "file": filename})
+ }
+ a.mu.Unlock()
+
+ acc.Push(chunk)
+}
+
+func (a *Agent) vadTick(ctx context.Context) {
+ ticker := time.NewTicker(500 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ a.checkSilence(ctx)
+ }
+ }
+}
+
+func (a *Agent) checkSilence(ctx context.Context) {
+ a.mu.Lock()
+ now := time.Now()
+ var finished []*speechAccumulator
+
+ for key, acc := range a.sessions {
+ acc.mu.Lock()
+ last := acc.lastAudioAt
+ acc.mu.Unlock()
+
+ if now.Sub(last) > 1500*time.Millisecond {
+ acc.Close()
+ delete(a.sessions, key)
+ finished = append(finished, acc)
+ }
+ }
+ a.mu.Unlock()
+
+ for _, acc := range finished {
+ go a.processUtterance(ctx, acc)
+ }
+}
+
+func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) {
+ defer os.Remove(acc.file)
+
+ logger.InfoCF("voice-agent", "User finished speaking, transcribing...", map[string]any{"file": acc.file})
+
+ if a.transcriber == nil {
+ logger.ErrorCF("voice-agent", "No STT configured!", nil)
+ return
+ }
+
+ res, err := a.transcriber.Transcribe(ctx, acc.file)
+ if err != nil {
+ logger.ErrorCF("voice-agent", "Transcription failed", map[string]any{"error": err})
+ return
+ }
+
+ if res.Text == "" {
+ logger.DebugCF("voice-agent", "Ignored empty transcription", map[string]any{"file": acc.file})
+ return
+ }
+
+ logger.InfoCF("voice-agent", "Transcription result", map[string]any{"text": res.Text, "duration": res.Duration})
+
+ channelType := acc.channel
+ if channelType == "" {
+ channelType = "discord" // fallback for legacy chunks
+ }
+
+ text := strings.ToLower(strings.TrimSpace(res.Text))
+ if strings.Contains(text, "leave the voice channel") || strings.Contains(text, "leave voice") ||
+ strings.Contains(text, "disconnect voice") || strings.Contains(text, "leave the channel") ||
+ strings.Contains(text, "leave channel") {
+ logger.InfoCF("voice-agent", "Voice command triggered: leave", nil)
+ if err := a.bus.PublishVoiceControl(ctx, bus.VoiceControl{
+ SessionID: acc.sessionID,
+ Type: "command",
+ Action: "leave",
+ }); err != nil {
+ logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err})
+ }
+ if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{
+ Context: bus.NewOutboundContext(channelType, acc.chatID, ""),
+ Content: "Goodbye! Leaving the voice channel.",
+ }); err != nil {
+ logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err})
+ }
+ return
+ }
+
+ oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally."
+
+ if err := a.bus.PublishInbound(ctx, bus.InboundMessage{
+ Context: bus.InboundContext{
+ Channel: channelType,
+ ChatID: acc.chatID,
+ ChatType: "channel",
+ SenderID: acc.speakerID,
+ Raw: map[string]string{
+ "is_voice": "true",
+ },
+ },
+ Content: res.Text + oralPrompt,
+ }); err != nil {
+ logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err})
+ }
+}
diff --git a/pkg/audio/asr/agent_test.go b/pkg/audio/asr/agent_test.go
new file mode 100644
index 000000000..0f9bcb3b2
--- /dev/null
+++ b/pkg/audio/asr/agent_test.go
@@ -0,0 +1,196 @@
+package asr
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/pion/webrtc/v3/pkg/media/oggwriter"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+)
+
+type fakeTranscriber struct {
+ text string
+ err error
+ lastPath string
+}
+
+func (f *fakeTranscriber) Name() string { return "fake" }
+
+func (f *fakeTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
+ f.lastPath = audioFilePath
+ if f.err != nil {
+ return nil, f.err
+ }
+ return &TranscriptionResponse{Text: f.text}, nil
+}
+
+func waitForFileRemoval(t *testing.T, path string, timeout time.Duration) {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ if _, err := os.Stat(path); err == nil {
+ t.Fatalf("expected file to be removed: %s", path)
+ }
+}
+
+func TestAgentHandleChunkCreatesSession(t *testing.T) {
+ t.Parallel()
+
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ agent := NewAgent(mb, &fakeTranscriber{})
+
+ chunk := bus.AudioChunk{
+ SessionID: "sess",
+ SpeakerID: "speaker",
+ ChatID: "chat",
+ Channel: "discord",
+ Sequence: 1,
+ Timestamp: 1,
+ SampleRate: 48000,
+ Channels: 2,
+ Format: "opus",
+ Data: []byte{0xF8, 0xFF, 0xFE},
+ }
+
+ agent.handleChunk(chunk)
+
+ key := "sess_speaker"
+ agent.mu.Lock()
+ acc, ok := agent.sessions[key]
+ agent.mu.Unlock()
+ if !ok {
+ t.Fatal("expected session to be created")
+ }
+
+ acc.Close()
+ _ = os.Remove(acc.file)
+}
+
+func TestAgentHandleChunkIgnoresUnsupportedFormat(t *testing.T) {
+ t.Parallel()
+
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ agent := NewAgent(mb, &fakeTranscriber{})
+
+ chunk := bus.AudioChunk{Format: "pcm"}
+ agent.handleChunk(chunk)
+
+ agent.mu.Lock()
+ count := len(agent.sessions)
+ agent.mu.Unlock()
+ if count != 0 {
+ t.Fatalf("expected no sessions, got %d", count)
+ }
+}
+
+func TestAgentProcessUtteranceLeaveCommand(t *testing.T) {
+ t.Parallel()
+
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ tr := &fakeTranscriber{text: "please leave the voice channel now"}
+ agent := NewAgent(mb, tr)
+
+ tmpDir := t.TempDir()
+ filePath := filepath.Join(tmpDir, "voice.ogg")
+ if err := os.WriteFile(filePath, []byte("data"), 0o600); err != nil {
+ t.Fatalf("write temp file: %v", err)
+ }
+
+ acc := &speechAccumulator{
+ file: filePath,
+ chatID: "chat",
+ speakerID: "speaker",
+ sessionID: "sess",
+ channel: "discord",
+ }
+
+ agent.processUtterance(context.Background(), acc)
+
+ select {
+ case ctrl := <-mb.VoiceControlsChan():
+ if ctrl.Action != "leave" || ctrl.Type != "command" || ctrl.SessionID != "sess" {
+ t.Fatalf("unexpected voice control: %#v", ctrl)
+ }
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("expected voice control publish")
+ }
+
+ select {
+ case out := <-mb.OutboundChan():
+ if !strings.Contains(out.Content, "Leaving the voice channel") {
+ t.Fatalf("unexpected outbound content: %q", out.Content)
+ }
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("expected outbound publish")
+ }
+
+ if _, err := os.Stat(filePath); !os.IsNotExist(err) {
+ t.Fatalf("expected temp file to be removed")
+ }
+}
+
+func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) {
+ t.Parallel()
+
+ mb := bus.NewMessageBus()
+ defer mb.Close()
+
+ tr := &fakeTranscriber{text: "hello there"}
+ agent := NewAgent(mb, tr)
+
+ filePath := filepath.Join(t.TempDir(), "voice.ogg")
+ writer, err := oggwriter.New(filePath, 48000, 2)
+ if err != nil {
+ t.Fatalf("create ogg writer: %v", err)
+ }
+
+ acc := &speechAccumulator{
+ writer: writer,
+ file: filePath,
+ lastAudioAt: time.Now().Add(-2 * time.Second),
+ chatID: "chat",
+ speakerID: "speaker",
+ sessionID: "sess",
+ channel: "slack",
+ }
+
+ agent.mu.Lock()
+ agent.sessions["sess_speaker"] = acc
+ agent.mu.Unlock()
+
+ agent.checkSilence(context.Background())
+
+ select {
+ case msg := <-mb.InboundChan():
+ if msg.Channel != "slack" {
+ t.Fatalf("unexpected inbound channel: %q", msg.Channel)
+ }
+ if !strings.Contains(msg.Content, "hello there") {
+ t.Fatalf("unexpected inbound content: %q", msg.Content)
+ }
+ if msg.Context.Raw["is_voice"] != "true" {
+ t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw)
+ }
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("expected inbound publish")
+ }
+
+ waitForFileRemoval(t, filePath, 500*time.Millisecond)
+}
diff --git a/pkg/audio/asr/asr.go b/pkg/audio/asr/asr.go
new file mode 100644
index 000000000..d15dc3f09
--- /dev/null
+++ b/pkg/audio/asr/asr.go
@@ -0,0 +1,131 @@
+package asr
+
+import (
+ "context"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+type Transcriber interface {
+ Name() string
+ Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
+}
+
+type TranscriptionResponse struct {
+ Text string `json:"text"`
+ Language string `json:"language,omitempty"`
+ Duration float64 `json:"duration,omitempty"`
+}
+
+func supportsAudioTranscription(model string) bool {
+ protocol, _ := providers.ExtractProtocol(model)
+
+ switch protocol {
+ case "openai", "azure", "azure-openai",
+ "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
+ "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
+ "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
+ "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
+ "coding-plan", "alibaba-coding", "qwen-coding":
+ // These protocols all go through the OpenAI-compatible or Azure provider path in
+ // providers.CreateProviderFromConfig, so they are the only ones that can supply
+ // the audio media payload shape expected by NewAudioModelTranscriber.
+
+ // TODO: Further restrict this by modelID, since not every model under these
+ // protocols supports audio transcription.
+ return true
+ default:
+ return false
+ }
+}
+
+func supportsWhisperTranscription(model string) bool {
+ protocol, _ := providers.ExtractProtocol(model)
+
+ switch protocol {
+ case "openai", "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
+ "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
+ "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
+ "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
+ "coding-plan", "alibaba-coding", "qwen-coding", "mimo":
+ return true
+ default:
+ return false
+ }
+}
+
+func whisperModelID(modelCfg *config.ModelConfig) string {
+ if modelCfg == nil || modelCfg.APIKey() == "" {
+ return ""
+ }
+
+ if !supportsWhisperTranscription(modelCfg.Model) {
+ return ""
+ }
+
+ _, modelID := providers.ExtractProtocol(strings.TrimSpace(modelCfg.Model))
+ if strings.Contains(strings.ToLower(modelID), "whisper") {
+ return modelID
+ }
+ return ""
+}
+
+func transcriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
+ if modelCfg == nil {
+ return nil
+ }
+
+ protocol, _ := providers.ExtractProtocol(modelCfg.Model)
+ if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
+ return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
+ }
+ if modelID := whisperModelID(modelCfg); modelID != "" {
+ return NewWhisperTranscriber(modelCfg)
+ }
+ if supportsAudioTranscription(modelCfg.Model) {
+ return NewAudioModelTranscriber(modelCfg)
+ }
+ return nil
+}
+
+func fallbackTranscriberFromModelConfig(modelCfg *config.ModelConfig) Transcriber {
+ if modelCfg == nil {
+ return nil
+ }
+
+ protocol, _ := providers.ExtractProtocol(modelCfg.Model)
+ if protocol == "elevenlabs" && modelCfg.APIKey() != "" {
+ return NewElevenLabsTranscriber(modelCfg.APIKey(), modelCfg.APIBase)
+ }
+ if modelID := whisperModelID(modelCfg); modelID != "" {
+ return NewWhisperTranscriber(modelCfg)
+ }
+ return nil
+}
+
+// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
+// nil if no supported transcription provider is configured.
+func DetectTranscriber(cfg *config.Config) Transcriber {
+ if cfg == nil {
+ return nil
+ }
+
+ if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
+ modelCfg, err := cfg.GetModelConfig(modelName)
+ if err == nil {
+ if tr := transcriberFromModelConfig(modelCfg); tr != nil {
+ return tr
+ }
+ }
+ }
+
+ // Fall back to compatibility scanning for legacy auto-detected ASR providers.
+ for _, mc := range cfg.ModelList {
+ if tr := fallbackTranscriberFromModelConfig(mc); tr != nil {
+ return tr
+ }
+ }
+ return nil
+}
diff --git a/pkg/voice/transcriber_test.go b/pkg/audio/asr/asr_test.go
similarity index 67%
rename from pkg/voice/transcriber_test.go
rename to pkg/audio/asr/asr_test.go
index 3e71ff13a..0970d69f4 100644
--- a/pkg/voice/transcriber_test.go
+++ b/pkg/audio/asr/asr_test.go
@@ -1,4 +1,4 @@
-package voice
+package asr
import (
"testing"
@@ -33,26 +33,68 @@ func TestDetectTranscriber(t *testing.T) {
wantName: "audio-model",
},
{
- name: "groq via model list",
+ name: "voice model name alias selects elevenlabs transcriber",
+ cfg: &config.Config{
+ Voice: config.VoiceConfig{ModelName: "my-asr-model"},
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "my-asr-model",
+ Model: "elevenlabs/scribe_v1",
+ APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test"),
+ },
+ },
+ },
+ wantName: "elevenlabs",
+ },
+ {
+ name: "voice model name alias selects whisper transcriber for groq",
+ cfg: &config.Config{
+ Voice: config.VoiceConfig{ModelName: "my-asr-model"},
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "my-asr-model",
+ Model: "groq/whisper-large-v3",
+ APIKeys: config.SimpleSecureStrings("sk-groq-model"),
+ },
+ },
+ },
+ wantName: "whisper",
+ },
+ {
+ name: "openai whisper alias selects whisper transcriber",
+ cfg: &config.Config{
+ Voice: config.VoiceConfig{ModelName: "my-asr-model"},
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "my-asr-model",
+ Model: "openai/whisper-1",
+ APIKeys: config.SimpleSecureStrings("sk-openai-model"),
+ },
+ },
+ },
+ wantName: "whisper",
+ },
+ {
+ name: "whisper via model list fallback",
cfg: &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "openai", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("sk-openai")},
{
ModelName: "groq",
- Model: "groq/llama-3.3-70b",
+ Model: "groq/whisper-large-v3-turbo",
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
},
},
},
- wantName: "groq",
+ wantName: "whisper",
},
{
- name: "voice model name selects non-gemini audio model transcriber",
+ name: "voice model name alias selects non-gemini audio model transcriber",
cfg: &config.Config{
- Voice: config.VoiceConfig{ModelName: "voice-openai-audio"},
+ Voice: config.VoiceConfig{ModelName: "my-asr-model"},
ModelList: []*config.ModelConfig{
{
- ModelName: "voice-openai-audio",
+ ModelName: "my-asr-model",
Model: "openai/gpt-4o-audio-preview",
APIKeys: config.SimpleSecureStrings("sk-openai"),
},
@@ -92,7 +134,7 @@ func TestDetectTranscriber(t *testing.T) {
name: "groq model list entry without key is skipped",
cfg: &config.Config{
ModelList: []*config.ModelConfig{
- {Model: "groq/llama-3.3-70b"},
+ {Model: "groq/whisper-large-v3"},
},
},
wantNil: true,
@@ -103,12 +145,12 @@ func TestDetectTranscriber(t *testing.T) {
ModelList: []*config.ModelConfig{
{
ModelName: "groq",
- Model: "groq/llama-3.3-70b",
+ Model: "groq/whisper-large-v3",
APIKeys: config.SimpleSecureStrings("sk-groq-model"),
},
},
},
- wantName: "groq",
+ wantName: "whisper",
},
{
name: "missing voice model name config returns nil",
@@ -127,15 +169,17 @@ func TestDetectTranscriber(t *testing.T) {
{
name: "elevenlabs voice config key",
cfg: &config.Config{
- Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
+ ModelList: []*config.ModelConfig{
+ {Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
+ },
},
wantName: "elevenlabs",
},
{
name: "elevenlabs takes priority over groq model list",
cfg: &config.Config{
- Voice: config.VoiceConfig{ElevenLabsAPIKey: "sk_elevenlabs_test"},
ModelList: []*config.ModelConfig{
+ {Model: "elevenlabs/scribe_v1", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
{
ModelName: "groq",
Model: "groq/llama-3.3-70b",
@@ -149,10 +193,10 @@ func TestDetectTranscriber(t *testing.T) {
name: "voice model name takes priority over elevenlabs",
cfg: &config.Config{
Voice: config.VoiceConfig{
- ModelName: "voice-gemini",
- ElevenLabsAPIKey: "sk_elevenlabs_test",
+ ModelName: "voice-gemini",
},
ModelList: []*config.ModelConfig{
+ {Model: "elevenlabs", APIKeys: config.SimpleSecureStrings("sk_elevenlabs_test")},
{
ModelName: "voice-gemini",
Model: "gemini/gemini-2.5-flash",
diff --git a/pkg/voice/audio_model_transcriber.go b/pkg/audio/asr/audio_model_transcriber.go
similarity index 99%
rename from pkg/voice/audio_model_transcriber.go
rename to pkg/audio/asr/audio_model_transcriber.go
index f3ca81961..e8ded15dd 100644
--- a/pkg/voice/audio_model_transcriber.go
+++ b/pkg/audio/asr/audio_model_transcriber.go
@@ -1,4 +1,4 @@
-package voice
+package asr
import (
"context"
diff --git a/pkg/voice/audio_model_transcriber_test.go b/pkg/audio/asr/audio_model_transcriber_test.go
similarity index 99%
rename from pkg/voice/audio_model_transcriber_test.go
rename to pkg/audio/asr/audio_model_transcriber_test.go
index c33e3bf97..5aaa82061 100644
--- a/pkg/voice/audio_model_transcriber_test.go
+++ b/pkg/audio/asr/audio_model_transcriber_test.go
@@ -1,4 +1,4 @@
-package voice
+package asr
import (
"context"
diff --git a/pkg/voice/elevenlabs_transcriber.go b/pkg/audio/asr/elevenlabs_transcriber.go
similarity index 96%
rename from pkg/voice/elevenlabs_transcriber.go
rename to pkg/audio/asr/elevenlabs_transcriber.go
index 93db10f8d..452b9512d 100644
--- a/pkg/voice/elevenlabs_transcriber.go
+++ b/pkg/audio/asr/elevenlabs_transcriber.go
@@ -1,4 +1,4 @@
-package voice
+package asr
import (
"bytes"
@@ -23,12 +23,16 @@ type ElevenLabsTranscriber struct {
httpClient *http.Client
}
-func NewElevenLabsTranscriber(apiKey string) *ElevenLabsTranscriber {
+func NewElevenLabsTranscriber(apiKey, apiBase string) *ElevenLabsTranscriber {
logger.DebugCF("voice", "Creating ElevenLabs transcriber", map[string]any{"has_api_key": apiKey != ""})
+ if apiBase == "" {
+ apiBase = "https://api.elevenlabs.io"
+ }
+
return &ElevenLabsTranscriber{
apiKey: apiKey,
- apiBase: "https://api.elevenlabs.io",
+ apiBase: apiBase,
httpClient: &http.Client{
Timeout: 120 * time.Second,
},
diff --git a/pkg/voice/elevenlabs_transcriber_test.go b/pkg/audio/asr/elevenlabs_transcriber_test.go
similarity index 91%
rename from pkg/voice/elevenlabs_transcriber_test.go
rename to pkg/audio/asr/elevenlabs_transcriber_test.go
index 78be8958a..fa80110be 100644
--- a/pkg/voice/elevenlabs_transcriber_test.go
+++ b/pkg/audio/asr/elevenlabs_transcriber_test.go
@@ -1,4 +1,4 @@
-package voice
+package asr
import (
"context"
@@ -14,7 +14,7 @@ import (
var _ Transcriber = (*ElevenLabsTranscriber)(nil)
func TestElevenLabsTranscriberName(t *testing.T) {
- tr := NewElevenLabsTranscriber("sk_test")
+ tr := NewElevenLabsTranscriber("sk_test", "")
if got := tr.Name(); got != "elevenlabs" {
t.Errorf("Name() = %q, want %q", got, "elevenlabs")
}
@@ -43,7 +43,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
- tr := NewElevenLabsTranscriber("sk_test")
+ tr := NewElevenLabsTranscriber("sk_test", "")
tr.apiBase = srv.URL
resp, err := tr.Transcribe(context.Background(), audioPath)
@@ -64,7 +64,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
}))
defer srv.Close()
- tr := NewElevenLabsTranscriber("sk_bad")
+ tr := NewElevenLabsTranscriber("sk_bad", "")
tr.apiBase = srv.URL
_, err := tr.Transcribe(context.Background(), audioPath)
@@ -74,7 +74,7 @@ func TestElevenLabsTranscribe(t *testing.T) {
})
t.Run("missing file", func(t *testing.T) {
- tr := NewElevenLabsTranscriber("sk_test")
+ tr := NewElevenLabsTranscriber("sk_test", "")
_, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
if err == nil {
t.Fatal("expected error for missing file, got nil")
diff --git a/pkg/audio/asr/whisper_transcriber.go b/pkg/audio/asr/whisper_transcriber.go
new file mode 100644
index 000000000..406710a8a
--- /dev/null
+++ b/pkg/audio/asr/whisper_transcriber.go
@@ -0,0 +1,245 @@
+package asr
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+type WhisperTranscriber struct {
+ apiKey string
+ apiBase string
+ modelID string
+ providerName string
+ httpClient *http.Client
+}
+
+func NewWhisperTranscriber(modelCfg *config.ModelConfig) *WhisperTranscriber {
+ if modelCfg == nil {
+ return nil
+ }
+
+ protocol, modelID := providers.ExtractProtocol(modelCfg.Model)
+ if modelID == "" {
+ modelID = strings.TrimSpace(modelCfg.Model)
+ }
+
+ tr := newWhisperTranscriber(
+ modelCfg.APIKey(),
+ providers.ResolveAPIBase(modelCfg),
+ modelID,
+ protocol,
+ )
+ if tr == nil {
+ return nil
+ }
+
+ logger.DebugCF("voice", "Creating whisper transcriber", map[string]any{
+ "api_base": tr.apiBase,
+ "has_key": tr.apiKey != "",
+ "model": tr.modelID,
+ "provider": tr.providerName,
+ })
+ return tr
+}
+
+func NewGroqTranscriber(apiKey, modelID string) *WhisperTranscriber {
+ return newWhisperTranscriber(apiKey, "https://api.groq.com/openai/v1", modelID, "groq")
+}
+
+func newWhisperTranscriber(apiKey, apiBase, modelID, providerName string) *WhisperTranscriber {
+ if modelID == "" {
+ return nil
+ }
+ if providerName == "" {
+ providerName = "whisper"
+ }
+ return &WhisperTranscriber{
+ apiKey: apiKey,
+ apiBase: strings.TrimRight(apiBase, "/"),
+ modelID: modelID,
+ providerName: providerName,
+ httpClient: &http.Client{
+ Timeout: 60 * time.Second,
+ },
+ }
+}
+
+func (t *WhisperTranscriber) transcriptionURL() string {
+ base := strings.TrimRight(t.apiBase, "/")
+ if strings.HasSuffix(base, "/audio/transcriptions") {
+ return base
+ }
+ return base + "/audio/transcriptions"
+}
+
+func (t *WhisperTranscriber) TranscribeData(
+ ctx context.Context,
+ data []byte,
+ filename string,
+) (*TranscriptionResponse, error) {
+ logger.InfoCF("voice", "Starting whisper transcription from memory", map[string]any{
+ "bytes": len(data),
+ "filename": filename,
+ "model": t.modelID,
+ "provider": t.providerName,
+ })
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ part, err := writer.CreateFormFile("file", filename)
+ if err != nil {
+ logger.ErrorCF("voice", "Failed to create whisper form file", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to create form file: %w", err)
+ }
+
+ if _, copyErr := io.Copy(part, bytes.NewReader(data)); copyErr != nil {
+ logger.ErrorCF("voice", "Failed to copy whisper file content", map[string]any{"error": copyErr})
+ return nil, fmt.Errorf("failed to copy file content: %w", copyErr)
+ }
+
+ if err = writer.WriteField("model", t.modelID); err != nil {
+ logger.ErrorCF("voice", "Failed to write whisper model field", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to write model field: %w", err)
+ }
+
+ if err = writer.WriteField("response_format", "json"); err != nil {
+ logger.ErrorCF("voice", "Failed to write whisper response_format field", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to write response_format field: %w", err)
+ }
+
+ if err = writer.Close(); err != nil {
+ logger.ErrorCF("voice", "Failed to close whisper multipart writer", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to close multipart writer: %w", err)
+ }
+
+ return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), int64(len(data)))
+}
+
+func (t *WhisperTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
+ logger.InfoCF("voice", "Starting whisper transcription", map[string]any{
+ "audio_file": audioFilePath,
+ "model": t.modelID,
+ "provider": t.providerName,
+ })
+
+ audioFile, err := os.Open(audioFilePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open audio file %s: %w", audioFilePath, err)
+ }
+ defer audioFile.Close()
+
+ fileInfo, err := audioFile.Stat()
+ if err != nil {
+ return nil, fmt.Errorf("failed to stat audio file %s: %w", audioFilePath, err)
+ }
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create form file: %w", err)
+ }
+
+ if _, copyErr := io.Copy(part, audioFile); copyErr != nil {
+ return nil, fmt.Errorf("failed to copy audio data: %w", copyErr)
+ }
+
+ if err = writer.WriteField("model", t.modelID); err != nil {
+ return nil, fmt.Errorf("failed to write model field: %w", err)
+ }
+
+ if err = writer.WriteField("response_format", "json"); err != nil {
+ return nil, fmt.Errorf("failed to write response_format field: %w", err)
+ }
+
+ if err = writer.Close(); err != nil {
+ return nil, fmt.Errorf("failed to close multipart writer: %w", err)
+ }
+
+ return t.doRequest(ctx, &requestBody, writer.FormDataContentType(), fileInfo.Size())
+}
+
+func (t *WhisperTranscriber) doRequest(
+ ctx context.Context,
+ requestBody *bytes.Buffer,
+ contentType string,
+ fileSize int64,
+) (*TranscriptionResponse, error) {
+ url := t.transcriptionURL()
+ req, err := http.NewRequestWithContext(ctx, "POST", url, requestBody)
+ if err != nil {
+ logger.ErrorCF("voice", "Failed to create whisper request", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", contentType)
+ if t.apiKey != "" {
+ req.Header.Set("Authorization", "Bearer "+t.apiKey)
+ }
+
+ logger.DebugCF("voice", "Sending whisper transcription request", map[string]any{
+ "file_size_bytes": fileSize,
+ "model": t.modelID,
+ "provider": t.providerName,
+ "request_size_bytes": requestBody.Len(),
+ "url": url,
+ })
+
+ resp, err := t.httpClient.Do(req)
+ if err != nil {
+ logger.ErrorCF("voice", "Failed to send whisper request", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ logger.ErrorCF("voice", "Failed to read whisper response", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ logger.ErrorCF("voice", "Whisper API error", map[string]any{
+ "provider": t.providerName,
+ "response": string(body),
+ "status_code": resp.StatusCode,
+ })
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ var result TranscriptionResponse
+ if err := json.Unmarshal(body, &result); err != nil {
+ logger.ErrorCF("voice", "Failed to unmarshal whisper response", map[string]any{"error": err})
+ return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+ }
+
+ logger.InfoCF("voice", "Whisper transcription completed successfully", map[string]any{
+ "duration_seconds": result.Duration,
+ "language": result.Language,
+ "provider": t.providerName,
+ "text_length": len(result.Text),
+ "transcription_preview": utils.Truncate(result.Text, 50),
+ })
+
+ return &result, nil
+}
+
+func (t *WhisperTranscriber) Name() string {
+ return "whisper"
+}
diff --git a/pkg/audio/asr/whisper_transcriber_test.go b/pkg/audio/asr/whisper_transcriber_test.go
new file mode 100644
index 000000000..a2a5178d1
--- /dev/null
+++ b/pkg/audio/asr/whisper_transcriber_test.go
@@ -0,0 +1,102 @@
+package asr
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestWhisperTranscriberTranscribeDataUsesConfiguredModel(t *testing.T) {
+ var gotModel string
+ var gotPath string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ if got := r.Header.Get("Authorization"); got != "Bearer sk-openai-test" {
+ t.Errorf("Authorization = %q, want %q", got, "Bearer sk-openai-test")
+ }
+
+ reader, err := r.MultipartReader()
+ if err != nil {
+ t.Fatalf("MultipartReader() error: %v", err)
+ }
+
+ for {
+ part, err := reader.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("NextPart() error: %v", err)
+ }
+
+ data, err := io.ReadAll(part)
+ if err != nil {
+ t.Fatalf("ReadAll() error: %v", err)
+ }
+
+ if part.FormName() == "model" {
+ gotModel = string(data)
+ }
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "hello from whisper"}); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+ }))
+ defer server.Close()
+
+ tr := NewWhisperTranscriber(&config.ModelConfig{
+ Model: "openai/whisper-1",
+ APIBase: server.URL,
+ APIKeys: config.SimpleSecureStrings("sk-openai-test"),
+ })
+ tr.httpClient = server.Client()
+
+ resp, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg")
+ if err != nil {
+ t.Fatalf("TranscribeData() error: %v", err)
+ }
+ if resp.Text != "hello from whisper" {
+ t.Errorf("Text = %q, want %q", resp.Text, "hello from whisper")
+ }
+ if gotModel != "whisper-1" {
+ t.Errorf("model field = %q, want %q", gotModel, "whisper-1")
+ }
+ if gotPath != "/audio/transcriptions" {
+ t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
+ }
+}
+
+func TestWhisperTranscriberUsesEndpointAPIBaseWithoutDoubleAppend(t *testing.T) {
+ var gotPath string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(TranscriptionResponse{Text: "ok"}); err != nil {
+ t.Fatalf("Encode() error: %v", err)
+ }
+ }))
+ defer server.Close()
+
+ tr := NewWhisperTranscriber(&config.ModelConfig{
+ Model: "groq/whisper-large-v3",
+ APIBase: server.URL + "/audio/transcriptions",
+ APIKeys: config.SimpleSecureStrings("sk-groq-test"),
+ })
+ tr.httpClient = server.Client()
+
+ if _, err := tr.TranscribeData(context.Background(), []byte("audio"), "clip.ogg"); err != nil {
+ t.Fatalf("TranscribeData() error: %v", err)
+ }
+ if gotPath != "/audio/transcriptions" {
+ t.Errorf("path = %q, want %q", gotPath, "/audio/transcriptions")
+ }
+}
diff --git a/pkg/audio/ogg.go b/pkg/audio/ogg.go
new file mode 100644
index 000000000..f0055a574
--- /dev/null
+++ b/pkg/audio/ogg.go
@@ -0,0 +1,57 @@
+package audio
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+)
+
+// DecodeOggOpus reads an Ogg format stream and extracts individual Opus payloads.
+// It calls onFrame for every complete Opus frame found in the stream.
+func DecodeOggOpus(r io.Reader, onFrame func([]byte) error) error {
+ var packet bytes.Buffer
+ header := make([]byte, 27)
+ segment := make([]byte, 255)
+
+ for {
+ if _, err := io.ReadFull(r, header); err != nil {
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ return nil
+ }
+ return fmt.Errorf("failed to read ogg header: %w", err)
+ }
+ if string(header[:4]) != "OggS" {
+ return fmt.Errorf("invalid ogg magic string")
+ }
+
+ pageSegments := int(header[26])
+ segmentTable := make([]byte, pageSegments)
+ if _, err := io.ReadFull(r, segmentTable); err != nil {
+ return fmt.Errorf("failed to read segment table: %w", err)
+ }
+
+ for _, lacing := range segmentTable {
+ if _, err := io.ReadFull(r, segment[:lacing]); err != nil {
+ return fmt.Errorf("failed to read segment data: %w", err)
+ }
+
+ packet.Write(segment[:lacing])
+
+ // If lacing is less than 255, the packet is complete
+ if lacing < 255 {
+ if packet.Len() > 0 {
+ packetBytes := packet.Bytes()
+ // Ignore Ogg Opus headers
+ if !bytes.HasPrefix(packetBytes, []byte("OpusHead")) &&
+ !bytes.HasPrefix(packetBytes, []byte("OpusTags")) {
+ if err := onFrame(packetBytes); err != nil {
+ return err
+ }
+ }
+ // Start new packet
+ packet.Reset()
+ }
+ }
+ }
+ }
+}
diff --git a/pkg/audio/ogg_test.go b/pkg/audio/ogg_test.go
new file mode 100644
index 000000000..8d5e5ac2a
--- /dev/null
+++ b/pkg/audio/ogg_test.go
@@ -0,0 +1,146 @@
+package audio
+
+import (
+ "bytes"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// buildOggPage helper creates an Ogg page for testing.
+// lacingVals specifies the segment table, and data is the payload.
+func buildOggPage(lacingVals []byte, data []byte) []byte {
+ var buf bytes.Buffer
+ // 27-byte Ogg header
+ header := make([]byte, 27)
+ copy(header[:4], "OggS")
+ header[5] = 0 // type flag
+ // For testing, we only care about OggS magic and page_segments (byte 26)
+ header[26] = byte(len(lacingVals))
+ buf.Write(header)
+ buf.Write(lacingVals)
+ buf.Write(data)
+ return buf.Bytes()
+}
+
+func TestDecodeOggOpus_ValidParsing(t *testing.T) {
+ var b bytes.Buffer
+
+ // Packet 1: Single segment, length 50
+ pkt1 := bytes.Repeat([]byte{1}, 50)
+ // Packet 2: Multi-segment (255 + 10 = 265 bytes)
+ pkt2Part1 := bytes.Repeat([]byte{2}, 255)
+ pkt2Part2 := bytes.Repeat([]byte{2}, 10)
+ // Packet 3: Continued across pages. Page 1 gets 255, Page 2 gets 20. Total 275 bytes.
+ pkt3Part1 := bytes.Repeat([]byte{3}, 255)
+ pkt3Part2 := bytes.Repeat([]byte{3}, 20)
+
+ // Page 1: OpusHead (skip), OpusTags (skip), pkt1, pkt2, pkt3Part1
+ page1Lacing := []byte{8, 8, 50, 255, 10, 255}
+ page1Data := bytes.Join([][]byte{
+ []byte("OpusHead"),
+ []byte("OpusTags"),
+ pkt1,
+ pkt2Part1, pkt2Part2,
+ pkt3Part1,
+ }, nil)
+
+ // Page 2: pkt3Part2, pkt4 (length 10)
+ pkt4 := bytes.Repeat([]byte{4}, 10)
+ page2Lacing := []byte{20, 10}
+ page2Data := bytes.Join([][]byte{
+ pkt3Part2,
+ pkt4,
+ }, nil)
+
+ b.Write(buildOggPage(page1Lacing, page1Data))
+ b.Write(buildOggPage(page2Lacing, page2Data))
+
+ var frames [][]byte
+ err := DecodeOggOpus(&b, func(frame []byte) error {
+ // making a copy to store as DecodeOggOpus might reuse backing array
+ cpy := make([]byte, len(frame))
+ copy(cpy, frame)
+ frames = append(frames, cpy)
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expectedFrames := [][]byte{
+ pkt1,
+ append(pkt2Part1, pkt2Part2...),
+ append(pkt3Part1, pkt3Part2...),
+ pkt4,
+ }
+
+ if len(frames) != len(expectedFrames) {
+ t.Fatalf("expected %d frames, got %d", len(expectedFrames), len(frames))
+ }
+
+ for i, expected := range expectedFrames {
+ if !reflect.DeepEqual(frames[i], expected) {
+ t.Errorf("frame %d mismatch:\nexp: %v\ngot: %v", i, expected, frames[i])
+ }
+ }
+}
+
+func TestDecodeOggOpus_Errors(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ errContains string
+ }{
+ {
+ name: "invalid magic string",
+ data: []byte(
+ "OggX\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ ),
+ errContains: "invalid ogg magic string",
+ },
+ {
+ name: "short header",
+ data: []byte("Ogg"),
+ errContains: "failed to read ogg header",
+ },
+ {
+ name: "eof in segment table",
+ data: func() []byte {
+ h := make([]byte, 27)
+ copy(h, "OggS")
+ h[26] = 5 // expects 5 bytes of segment table, but none provided
+ return h
+ }(),
+ errContains: "failed to read segment table",
+ },
+ {
+ name: "eof in segment data",
+ data: func() []byte {
+ h := make([]byte, 27, 28)
+ copy(h, "OggS")
+ h[26] = 1
+ return append(h, 100) // expects 100 bytes of data, but none provided
+ }(),
+ errContains: "failed to read segment data",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := DecodeOggOpus(bytes.NewReader(tt.data), func(b []byte) error { return nil })
+ if tt.name == "short header" {
+ if err != nil {
+ t.Errorf("expected no error (io.EOF/ErrUnexpectedEOF swallowed), got %v", err)
+ }
+ return
+ }
+ if err == nil {
+ t.Fatalf("expected error containing %q, got nil", tt.errContains)
+ }
+ if !strings.Contains(err.Error(), tt.errContains) {
+ t.Errorf("expected error to contain %q, got: %q", tt.errContains, err.Error())
+ }
+ })
+ }
+}
diff --git a/pkg/audio/sentence.go b/pkg/audio/sentence.go
new file mode 100644
index 000000000..89b9ac03e
--- /dev/null
+++ b/pkg/audio/sentence.go
@@ -0,0 +1,96 @@
+package audio
+
+import (
+ "strings"
+ "unicode"
+)
+
+// SplitSentences splits text into sentence-sized chunks suitable for TTS synthesis.
+// It splits on sentence-ending punctuation (.!?\n, as well as CJK 。, !, ?) while avoiding false splits
+// on decimal numbers. Very short fragments are merged with
+// the next sentence to prevent choppy playback.
+func SplitSentences(text string) []string {
+ if text == "" {
+ return nil
+ }
+
+ var sentences []string
+ var current strings.Builder
+ runes := []rune(text)
+
+ for i := 0; i < len(runes); i++ {
+ r := runes[i]
+ if r == '\n' {
+ s := strings.TrimSpace(current.String())
+ if s != "" {
+ sentences = append(sentences, s)
+ }
+ current.Reset()
+ continue
+ }
+
+ current.WriteRune(r)
+
+ if r == '.' || r == '!' || r == '?' || r == '。' || r == '!' || r == '?' {
+ // Avoid splitting on decimal numbers like "3.14"
+ if r == '.' && i > 0 && unicode.IsDigit(runes[i-1]) &&
+ i+1 < len(runes) && unicode.IsDigit(runes[i+1]) {
+ continue
+ }
+
+ // Consume contiguous punctuation clusters (e.g., "..." or "?!").
+ for i+1 < len(runes) && (runes[i+1] == '.' || runes[i+1] == '!' || runes[i+1] == '?' || runes[i+1] == '。' || runes[i+1] == '!' || runes[i+1] == '?') {
+ i++
+ current.WriteRune(runes[i])
+ }
+
+ s := strings.TrimSpace(current.String())
+ if s != "" {
+ sentences = append(sentences, s)
+ }
+ current.Reset()
+ }
+ }
+
+ // Flush remaining text
+ if s := strings.TrimSpace(current.String()); s != "" {
+ sentences = append(sentences, s)
+ }
+
+ // Merge very short fragments with the next sentence
+ return mergeShorties(sentences, 15)
+}
+
+// mergeShorties merges sentences shorter than minLen characters with the following sentence.
+func mergeShorties(sentences []string, minLen int) []string {
+ if len(sentences) <= 1 {
+ return sentences
+ }
+
+ var merged []string
+ var buf string
+
+ for _, s := range sentences {
+ if buf != "" {
+ buf += " " + s
+ if len([]rune(buf)) >= minLen {
+ merged = append(merged, buf)
+ buf = ""
+ }
+ } else if len([]rune(s)) < minLen {
+ buf = s
+ } else {
+ merged = append(merged, s)
+ }
+ }
+
+ if buf != "" {
+ if len(merged) > 0 {
+ merged[len(merged)-1] += " " + buf
+ } else {
+ merged = append(merged, buf)
+ }
+ }
+
+ return merged
+}
diff --git a/pkg/audio/sentence_test.go b/pkg/audio/sentence_test.go
new file mode 100644
index 000000000..54d69e4a6
--- /dev/null
+++ b/pkg/audio/sentence_test.go
@@ -0,0 +1,69 @@
+package audio
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestSplitSentences(t *testing.T) {
+ tests := []struct {
+ name string
+ in string
+ want []string
+ }{
+ {
+ name: "empty input",
+ in: "",
+ want: nil,
+ },
+ {
+ name: "single sentence",
+ in: "Hello world.",
+ want: []string{"Hello world."},
+ },
+ {
+ name: "decimal numbers do not split",
+ in: "The value is 3.14 today. Keep watching closely.",
+ want: []string{"The value is 3.14 today.", "Keep watching closely."},
+ },
+ {
+ name: "newline boundary",
+ in: "This is line number one\nThis is line number two",
+ want: []string{"This is line number one", "This is line number two"},
+ },
+ {
+ name: "newline with surrounding spaces",
+ in: " This is the first line \n This is the second line ",
+ want: []string{"This is the first line", "This is the second line"},
+ },
+ {
+ name: "trailing punctuation consumed",
+ in: "Please wait a moment... What on earth?! That is perfectly fine.",
+ want: []string{"Please wait a moment...", "What on earth?!", "That is perfectly fine."},
+ },
+ {
+ name: "short leading fragment merges with next",
+ in: "Hi. This is a longer sentence.",
+ want: []string{"Hi. This is a longer sentence."},
+ },
+ {
+ name: "consecutive short fragments keep merging",
+ in: "A. B. C. This is the real sentence.",
+ want: []string{"A. B. C. This is the real sentence."},
+ },
+ {
+ name: "short trailing fragment merges back",
+ in: "This sentence is long enough. End.",
+ want: []string{"This sentence is long enough. End."},
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := SplitSentences(tc.in)
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Fatalf("SplitSentences(%q) = %#v, want %#v", tc.in, got, tc.want)
+ }
+ })
+ }
+}
diff --git a/pkg/audio/tts/README.md b/pkg/audio/tts/README.md
new file mode 100644
index 000000000..ab8491da6
--- /dev/null
+++ b/pkg/audio/tts/README.md
@@ -0,0 +1,137 @@
+# TTS (Text-to-Speech)
+
+This package handles speech synthesis for PicoClaw.
+
+If you are new to TTS setup, the simplest workflow is:
+
+1. Add a TTS-capable entry to `model_list`.
+2. Point `voice.tts_model_name` at that entry.
+3. Put the API key in `.security.yml`.
+
+## Quick Recommendation
+
+For most users, these are the best starting points:
+
+| Provider | Why start here |
+| --- | --- |
+| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | Best-supported path in PicoClaw today. The current TTS implementation is built around the OpenAI-compatible `/audio/speech` API shape, and OpenAI is the safest default. |
+| [Xiaomi MiMo](https://platform.xiaomimimo.com) | A good second option if you want an OpenAI-compatible provider endpoint and are already using MiMo models in the rest of your stack. |
+
+## How TTS Configuration Works
+
+PicoClaw does not keep TTS API keys inside `voice`.
+
+Instead:
+
+- `voice.tts_model_name` selects a named entry from `model_list`.
+- That `model_list` entry provides the provider, model ID, API base, and proxy settings.
+- `.security.yml` stores the API key for the same named model entry.
+
+This is the recommended and supported configuration pattern.
+
+## Recommended Setup
+
+### Option A: OpenAI
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "tts_model_name": "openai-tts"
+ },
+ "model_list": [
+ {
+ "model_name": "openai-tts",
+ "model": "openai/tts-1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ openai-tts:
+ api_keys:
+ - "sk-openai-your-key"
+```
+
+### Option B: Xiaomi MiMo
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "tts_model_name": "mimo-tts"
+ },
+ "model_list": [
+ {
+ "model_name": "mimo-tts",
+ "model": "mimo/mimo-v2-tts"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ mimo-tts:
+ api_keys:
+ - "your-mimo-key"
+```
+
+If you use a custom MiMo endpoint, you can also set `api_base` explicitly. Otherwise PicoClaw will use the provider default.
+
+## What PicoClaw Sends Today
+
+The current TTS runtime uses an OpenAI-compatible speech request with these defaults:
+
+- Endpoint: `/audio/speech`
+- Response format: `opus`
+- Voice: `alloy`
+- Model: taken from the selected `model_list` entry
+
+That means:
+
+- `openai/tts-1` works naturally.
+- Other OpenAI-compatible providers can work if they accept the same request format.
+- PicoClaw currently does not expose a user-facing config field for changing the TTS voice from `alloy`.
+
+## How PicoClaw Chooses a TTS Provider
+
+`DetectTTS` resolves TTS in this order:
+
+1. **Preferred path**: resolve `voice.tts_model_name` against `model_list`.
+2. If a matching model entry exists and has an API key, PicoClaw creates an OpenAI-compatible TTS provider using that model's settings.
+3. **Fallback path**: if `voice.tts_model_name` is not set or cannot be resolved, PicoClaw scans `model_list` for the first entry whose model string contains `tts` and has an API key.
+
+Fallback scanning exists for compatibility. New configs should set `voice.tts_model_name` explicitly.
+
+## Notes About API Base Handling
+
+PicoClaw normalizes the configured base URL for TTS:
+
+- For OpenAI, a base like `https://api.openai.com` or `https://api.openai.com/v1` becomes `https://api.openai.com/v1/audio/speech`.
+- For other OpenAI-compatible providers, PicoClaw preserves the configured base path and ensures it ends with `/audio/speech`.
+- If `api_base` is omitted, PicoClaw uses the provider default base when the model prefix is known.
+
+## Common Mistakes
+
+- Setting `voice.tts_model_name` to a name that does not exist in `model_list`.
+- Adding a TTS model but forgetting to put its API key in `.security.yml`.
+- Assuming PicoClaw will automatically use provider-specific custom voices.
+- Using a provider endpoint that is not compatible with the OpenAI `/audio/speech` request format.
+
+## Minimal Checklist
+
+Before testing `send_tts`, make sure:
+
+- `voice.tts_model_name` matches a `model_list[].model_name`.
+- The matching `.security.yml` entry contains a valid API key.
+- The chosen provider supports an OpenAI-compatible speech synthesis endpoint.
+- Your selected model is actually a TTS-capable model.
diff --git a/pkg/audio/tts/README_zh.md b/pkg/audio/tts/README_zh.md
new file mode 100644
index 000000000..a48b612a9
--- /dev/null
+++ b/pkg/audio/tts/README_zh.md
@@ -0,0 +1,137 @@
+# TTS(文本转语音)
+
+这个目录负责 PicoClaw 的语音合成能力。
+
+如果你是第一次配置 TTS,可以参照下面这个流程:
+
+1. 在 `model_list` 里添加一个支持 TTS 的模型。
+2. 用 `voice.tts_model_name` 指向这个模型。
+3. 在 `.security.yml` 里配置对应的 API Key。
+
+## 快速推荐
+
+对于大多数用户,建议优先从下面两种开始:
+
+| 提供商 | 推荐理由 |
+| --- | --- |
+| [OpenAI](https://platform.openai.com/docs/guides/text-to-speech) | 这是 PicoClaw 当前最稳定、最直接的 TTS 路径。当前实现就是围绕 OpenAI 兼容的 `/audio/speech` 接口格式构建的,所以 OpenAI 是最稳妥的默认选择。 |
+| [Xiaomi MiMo](https://platform.xiaomimimo.com) | 由于响应速度和语音音色对于中国用户更友好,MiMo 是一个不错的第二选择。 |
+
+## TTS 配置是如何工作的
+
+PicoClaw 不会把 TTS 的 API Key 放在 `voice` 配置里。
+
+推荐方式是:
+
+- `voice.tts_model_name` 用来选择 `model_list` 里的某个命名模型。
+- 对应的 `model_list` 条目提供真实的 provider、model ID、`api_base` 和代理配置。
+- `.security.yml` 负责保存该模型条目的 API Key。
+
+这是当前推荐且受支持的配置方式。
+
+## 推荐配置方式
+
+### 方案 A:OpenAI
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "tts_model_name": "openai-tts"
+ },
+ "model_list": [
+ {
+ "model_name": "openai-tts",
+ "model": "openai/tts-1"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ openai-tts:
+ api_keys:
+ - "sk-openai-your-key"
+```
+
+### 方案 B:Xiaomi MiMo
+
+`config.json`
+
+```json
+{
+ "voice": {
+ "tts_model_name": "mimo-tts"
+ },
+ "model_list": [
+ {
+ "model_name": "mimo-tts",
+ "model": "mimo/mimo-v2-tts"
+ }
+ ]
+}
+```
+
+`.security.yml`
+
+```yaml
+model_list:
+ mimo-tts:
+ api_keys:
+ - "your-mimo-key"
+```
+
+如果你使用自定义的 MiMo 接口地址,也可以显式设置 `api_base`。如果不设置,PicoClaw 会自动使用该 provider 的默认地址。
+
+## PicoClaw 当前实际发送的 TTS 请求
+
+当前 TTS 运行时使用的是 OpenAI 兼容的语音合成请求,并带有以下默认值:
+
+- Endpoint:`/audio/speech`
+- 返回格式:`opus`
+- Voice:`alloy`
+- Model:来自你所选中的 `model_list` 条目
+
+这意味着:
+
+- `openai/tts-1` 可以自然工作。
+- 其他 OpenAI 兼容 provider 也可能可用,前提是它们接受相同的请求格式。
+- PicoClaw 目前还没有对用户暴露一个配置项来修改 TTS voice,当前固定为 `alloy`。
+
+## PicoClaw 如何选择 TTS Provider
+
+`DetectTTS` 会按下面顺序选择 TTS:
+
+1. **首选路径**:根据 `voice.tts_model_name` 在 `model_list` 中找到对应模型。
+2. 如果找到了匹配条目,并且它有 API Key,PicoClaw 就会使用这个模型条目的配置创建一个 OpenAI 兼容的 TTS provider。
+3. **回退路径**:如果没有设置 `voice.tts_model_name`,或者该名字无法解析,PicoClaw 会扫描 `model_list`,选中第一个模型字符串里包含 `tts` 且带有 API Key 的条目。
+
+回退扫描只是为了兼容旧行为。新配置建议始终显式设置 `voice.tts_model_name`。
+
+## 关于 API Base 的处理方式
+
+PicoClaw 会对 TTS 的 `api_base` 做规范化处理:
+
+- 对 OpenAI 来说,像 `https://api.openai.com` 或 `https://api.openai.com/v1` 这样的地址,会自动变成 `https://api.openai.com/v1/audio/speech`。
+- 对其他 OpenAI 兼容 provider,PicoClaw 会尽量保留你提供的基础路径,只确保它最终以 `/audio/speech` 结尾。
+- 如果没有设置 `api_base`,并且模型前缀是已知 provider,PicoClaw 会自动使用该 provider 的默认地址。
+
+## 常见错误
+
+- `voice.tts_model_name` 指向了一个不存在的 `model_list` 名称。
+- 在 `model_list` 里定义了 TTS 模型,但忘了在 `.security.yml` 中配置对应 API Key。
+- 误以为 PicoClaw 会自动支持 provider 自定义 voice 参数。
+- 使用了不兼容 OpenAI `/audio/speech` 请求格式的接口地址。
+
+## 最小检查清单
+
+在测试 `send_tts` 之前,请确认:
+
+- `voice.tts_model_name` 能正确匹配某个 `model_list[].model_name`。
+- `.security.yml` 中对应条目已经配置了有效 API Key。
+- 你所选的 provider 支持 OpenAI 兼容的语音合成接口。
+- 你选择的模型本身确实支持 TTS。
diff --git a/pkg/audio/tts/mimo_tts.go b/pkg/audio/tts/mimo_tts.go
new file mode 100644
index 000000000..a8aee6b8c
--- /dev/null
+++ b/pkg/audio/tts/mimo_tts.go
@@ -0,0 +1,162 @@
+package tts
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+type MimoTTSProvider struct {
+ apiKey string
+ apiBase string
+ voice string
+ format string
+ model string
+ httpClient *http.Client
+}
+
+func NewMimoTTSProvider(apiKey string, apiBase string, model string, proxyURL string) *MimoTTSProvider {
+ if apiBase == "" {
+ apiBase = "https://api.xiaomimimo.com/v1/chat/completions"
+ } else {
+ if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" {
+ path := u.Path
+ if u.Host == "api.xiaomimimo.com" {
+ if path == "" || path == "/" || path == "/v1" || path == "/v1/" {
+ path = "/v1/chat/completions"
+ } else {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ if !strings.HasPrefix(path, "/v1/") {
+ path = "/v1" + strings.TrimSuffix(path, "/")
+ }
+ if !strings.HasSuffix(path, "/chat/completions") {
+ path = strings.TrimSuffix(path, "/") + "/chat/completions"
+ }
+ }
+ } else {
+ if !strings.HasSuffix(path, "/chat/completions") {
+ path = strings.TrimSuffix(path, "/") + "/chat/completions"
+ }
+ }
+ u.Path = path
+ apiBase = u.String()
+ } else {
+ if apiBase == "https://api.xiaomimimo.com/v1" {
+ apiBase = "https://api.xiaomimimo.com/v1/chat/completions"
+ } else if !strings.HasSuffix(apiBase, "/chat/completions") {
+ apiBase = strings.TrimSuffix(apiBase, "/") + "/chat/completions"
+ }
+ }
+ }
+
+ model = strings.TrimSpace(model)
+ if model == "" {
+ model = "mimo-v2-tts"
+ }
+
+ client := &http.Client{Timeout: 60 * time.Second}
+ if proxyURL != "" {
+ if pURL, err := url.Parse(proxyURL); err == nil {
+ client.Transport = &http.Transport{Proxy: http.ProxyURL(pURL)}
+ } else {
+ logger.WarnF(
+ "NewMimoTTSProvider: invalid proxy URL; proceeding without proxy",
+ map[string]any{"proxyURL": proxyURL, "error": err},
+ )
+ }
+ }
+
+ return &MimoTTSProvider{
+ apiKey: apiKey,
+ apiBase: apiBase,
+ voice: "default_zh", // mimo_default now seems to be an alias for default_en, which is not working for Chinese TTS. default_zh seems to work fine with both English and Chinese, and is likely the intended default for TTS.
+ format: "mp3",
+ model: model,
+ httpClient: client,
+ }
+}
+
+func (t *MimoTTSProvider) Name() string {
+ return "mimo-tts"
+}
+
+func (t *MimoTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
+ logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text), "provider": t.Name()})
+
+ reqBody := map[string]any{
+ "model": t.model,
+ "messages": []map[string]string{
+ {"role": "assistant", "content": text},
+ },
+ "audio": map[string]string{
+ "format": t.format,
+ "voice": t.voice,
+ },
+ "stream": false,
+ }
+
+ jsonData, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Api-Key", t.apiKey)
+
+ resp, err := t.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ var payload struct {
+ Choices []struct {
+ Message struct {
+ Audio struct {
+ Data string `json:"data"`
+ } `json:"audio"`
+ } `json:"message"`
+ } `json:"choices"`
+ }
+
+ err = json.Unmarshal(body, &payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ if len(payload.Choices) == 0 || payload.Choices[0].Message.Audio.Data == "" {
+ return nil, fmt.Errorf("invalid TTS response: missing audio data")
+ }
+
+ audioBytes, err := base64.StdEncoding.DecodeString(payload.Choices[0].Message.Audio.Data)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode audio data: %w", err)
+ }
+
+ return io.NopCloser(bytes.NewReader(audioBytes)), nil
+}
diff --git a/pkg/audio/tts/openai_tts.go b/pkg/audio/tts/openai_tts.go
new file mode 100644
index 000000000..786414873
--- /dev/null
+++ b/pkg/audio/tts/openai_tts.go
@@ -0,0 +1,126 @@
+package tts
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers/common"
+)
+
+type OpenAITTSProvider struct {
+ apiKey string
+ apiBase string
+ voice string
+ model string
+ httpClient *http.Client
+}
+
+func NewOpenAITTSProvider(apiKey string, apiBase string, proxyURL string, model string) *OpenAITTSProvider {
+ // Normalize apiBase to avoid malformed endpoints like
+ // "https://api.openai.com/audio/speech" when "/v1" is required.
+ if apiBase == "" {
+ apiBase = "https://api.openai.com/v1/audio/speech"
+ } else {
+ if u, err := url.Parse(apiBase); err == nil && u.Scheme != "" && u.Host != "" {
+ path := u.Path
+ if u.Host == "api.openai.com" {
+ // For the official OpenAI host, ensure exactly one /v1 prefix and
+ // that the path ends with /audio/speech.
+ if path == "" || path == "/" || path == "/v1" {
+ path = "/v1/audio/speech"
+ } else {
+ if !strings.HasPrefix(path, "/") {
+ path = "/" + path
+ }
+ if !strings.HasPrefix(path, "/v1/") {
+ path = "/v1" + strings.TrimSuffix(path, "/")
+ }
+ if !strings.HasSuffix(path, "/audio/speech") {
+ path = strings.TrimSuffix(path, "/") + "/audio/speech"
+ }
+ }
+ } else {
+ // For non-OpenAI hosts (e.g., proxies), preserve the existing base
+ // path and only ensure it ends with /audio/speech.
+ if !strings.HasSuffix(path, "/audio/speech") {
+ path = strings.TrimSuffix(path, "/") + "/audio/speech"
+ }
+ }
+ u.Path = path
+ apiBase = u.String()
+ } else {
+ // Fallback to the previous string-based behavior if parsing fails.
+ if apiBase == "https://api.openai.com/v1" {
+ apiBase = "https://api.openai.com/v1/audio/speech"
+ } else if !strings.HasSuffix(apiBase, "/audio/speech") {
+ // Just in case they provide openrouter base or standard base
+ apiBase = strings.TrimSuffix(apiBase, "/") + "/audio/speech"
+ }
+ }
+ }
+
+ client := common.NewHTTPClient(proxyURL)
+ client.Timeout = 60 * time.Second
+
+ model = strings.TrimSpace(model)
+ if model == "" {
+ model = "tts-1"
+ }
+
+ return &OpenAITTSProvider{
+ apiKey: apiKey,
+ apiBase: apiBase,
+ voice: "alloy",
+ model: model,
+ httpClient: client,
+ }
+}
+
+func (t *OpenAITTSProvider) Name() string {
+ return "openai-tts"
+}
+
+func (t *OpenAITTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
+ logger.DebugCF("voice-tts", "Starting TTS synthesis", map[string]any{"text_len": len(text)})
+
+ reqBody := map[string]any{
+ "model": t.model,
+ "input": text,
+ "voice": t.voice,
+ "response_format": "opus",
+ }
+
+ jsonData, err := json.Marshal(reqBody)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", t.apiBase, bytes.NewReader(jsonData))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+t.apiKey)
+
+ resp, err := t.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ body, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ return resp.Body, nil
+}
diff --git a/pkg/audio/tts/tts.go b/pkg/audio/tts/tts.go
new file mode 100644
index 000000000..99a9ef203
--- /dev/null
+++ b/pkg/audio/tts/tts.go
@@ -0,0 +1,151 @@
+package tts
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+type TTSProvider interface {
+ Name() string
+ Synthesize(ctx context.Context, text string) (io.ReadCloser, error)
+}
+
+func providerFromModelConfig(mc *config.ModelConfig) TTSProvider {
+ if mc == nil || mc.APIKey() == "" {
+ return nil
+ }
+
+ protocol, modelID := providers.ExtractProtocol(mc.Model)
+ if modelID == "" {
+ modelID = strings.TrimSpace(mc.Model)
+ }
+
+ switch protocol {
+ case "mimo":
+ return NewMimoTTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), modelID, mc.Proxy)
+ default:
+ return NewOpenAITTSProvider(mc.APIKey(), providers.ResolveAPIBase(mc), mc.Proxy, modelID)
+ }
+}
+
+func DetectTTS(cfg *config.Config) TTSProvider {
+ if cfg == nil {
+ return nil
+ }
+
+ if modelName := strings.TrimSpace(cfg.Voice.TTSModelName); modelName != "" {
+ if mc, err := cfg.GetModelConfig(modelName); err == nil {
+ if provider := providerFromModelConfig(mc); provider != nil {
+ return provider
+ }
+ }
+ }
+
+ for _, mc := range cfg.ModelList {
+ if strings.Contains(strings.ToLower(mc.Model), "tts") && mc.APIKey() != "" {
+ if provider := providerFromModelConfig(mc); provider != nil {
+ return provider
+ }
+ }
+ }
+ return nil
+}
+
+// SynthesizeAndStore synthesizes text to speech and registers it in the media store, returning the media reference.
+func SynthesizeAndStore(
+ ctx context.Context,
+ provider TTSProvider,
+ store media.MediaStore,
+ text string,
+ filename string,
+ channel string,
+ chatID string,
+) (string, error) {
+ if provider == nil {
+ return "", fmt.Errorf("tts provider is not configured")
+ }
+ if store == nil {
+ return "", fmt.Errorf("media store not configured")
+ }
+ if channel == "" || chatID == "" {
+ return "", fmt.Errorf("no target channel/chat available")
+ }
+ if strings.TrimSpace(text) == "" {
+ return "", fmt.Errorf("text is required")
+ }
+
+ stream, err := provider.Synthesize(ctx, text)
+ if err != nil {
+ return "", fmt.Errorf("tts synthesize failed: %w", err)
+ }
+ defer stream.Close()
+
+ err = os.MkdirAll(media.TempDir(), 0o700)
+ if err != nil {
+ return "", fmt.Errorf("failed to create media temp dir: %w", err)
+ }
+
+ fileExt := ".ogg"
+ contentType := "audio/ogg"
+ if provider.Name() == "mimo-tts" {
+ fileExt = ".mp3"
+ contentType = "audio/mpeg"
+ }
+
+ file, err := os.CreateTemp(media.TempDir(), "tts-*"+fileExt)
+ if err != nil {
+ return "", fmt.Errorf("failed to create temp file: %w", err)
+ }
+
+ removeTemp := true
+ defer func() {
+ if removeTemp {
+ _ = os.Remove(file.Name())
+ }
+ }()
+
+ _, err = io.Copy(file, stream)
+ if err != nil {
+ file.Close()
+ return "", fmt.Errorf("failed to write tts audio: %w", err)
+ }
+
+ err = file.Close()
+ if err != nil {
+ return "", fmt.Errorf("failed to close tts audio file: %w", err)
+ }
+
+ filename = strings.TrimSpace(filename)
+ if filename == "" {
+ filename = fmt.Sprintf("tts-%d%s", time.Now().Unix(), fileExt)
+ }
+
+ ext := strings.ToLower(filepath.Ext(filename))
+ if ext == "" {
+ filename += fileExt
+ } else if ext != fileExt {
+ filename = strings.TrimSuffix(filename, filepath.Ext(filename)) + fileExt
+ }
+
+ scope := fmt.Sprintf("tool:send_tts:%s:%s:%d", channel, chatID, time.Now().UnixNano())
+ ref, err := store.Store(file.Name(), media.MediaMeta{
+ Filename: filename,
+ ContentType: contentType,
+ Source: "tool:send_tts",
+ }, scope)
+ if err != nil {
+ return "", fmt.Errorf("failed to register audio: %w", err)
+ }
+ removeTemp = false
+
+ return ref, nil
+}
diff --git a/pkg/audio/tts/tts_test.go b/pkg/audio/tts/tts_test.go
new file mode 100644
index 000000000..053aa7220
--- /dev/null
+++ b/pkg/audio/tts/tts_test.go
@@ -0,0 +1,247 @@
+package tts
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
+)
+
+func TestNewOpenAITTSProvider_APIBaseNormalization(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ name string
+ input string
+ expect string
+ }{
+ {
+ name: "empty base",
+ input: "",
+ expect: "https://api.openai.com/v1/audio/speech",
+ },
+ {
+ name: "official host no path",
+ input: "https://api.openai.com",
+ expect: "https://api.openai.com/v1/audio/speech",
+ },
+ {
+ name: "official host v1",
+ input: "https://api.openai.com/v1",
+ expect: "https://api.openai.com/v1/audio/speech",
+ },
+ {
+ name: "official host v1 slash",
+ input: "https://api.openai.com/v1/",
+ expect: "https://api.openai.com/v1/audio/speech",
+ },
+ {
+ name: "non-openai host preserves base path",
+ input: "https://proxy.example.com/base",
+ expect: "https://proxy.example.com/base/audio/speech",
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ provider := NewOpenAITTSProvider("key", tc.input, "", "")
+ if provider.apiBase != tc.expect {
+ t.Fatalf("apiBase mismatch: got %q, want %q", provider.apiBase, tc.expect)
+ }
+ })
+ }
+}
+
+func TestOpenAITTSProvider_SynthesizeSuccess(t *testing.T) {
+ t.Parallel()
+
+ var gotPath string
+ var gotAuth string
+ var gotContentType string
+ var gotBody map[string]any
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotAuth = r.Header.Get("Authorization")
+ gotContentType = r.Header.Get("Content-Type")
+
+ bodyBytes, _ := io.ReadAll(r.Body)
+ _ = r.Body.Close()
+ _ = json.Unmarshal(bodyBytes, &gotBody)
+
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("audio-bytes"))
+ }))
+ defer server.Close()
+
+ provider := NewOpenAITTSProvider("k123", server.URL, "", "")
+ stream, err := provider.Synthesize(context.Background(), "hello")
+ if err != nil {
+ t.Fatalf("Synthesize failed: %v", err)
+ }
+ defer stream.Close()
+
+ data, err := io.ReadAll(stream)
+ if err != nil {
+ t.Fatalf("read stream failed: %v", err)
+ }
+
+ if gotPath != "/audio/speech" {
+ t.Fatalf("request path mismatch: got %q", gotPath)
+ }
+ if gotAuth != "Bearer k123" {
+ t.Fatalf("authorization mismatch: got %q", gotAuth)
+ }
+ if gotContentType != "application/json" {
+ t.Fatalf("content-type mismatch: got %q", gotContentType)
+ }
+ if gotBody["model"] != "tts-1" || gotBody["voice"] != "alloy" || gotBody["response_format"] != "opus" ||
+ gotBody["input"] != "hello" {
+ bodyJSON, _ := json.Marshal(gotBody)
+ t.Fatalf("request body mismatch: %s", string(bodyJSON))
+ }
+ if string(data) != "audio-bytes" {
+ t.Fatalf("response body mismatch: got %q", string(data))
+ }
+}
+
+func TestOpenAITTSProvider_SynthesizeNon200(t *testing.T) {
+ t.Parallel()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("nope"))
+ }))
+ defer server.Close()
+
+ provider := NewOpenAITTSProvider("k123", server.URL, "", "")
+ _, err := provider.Synthesize(context.Background(), "hello")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "API error (status 500): nope") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestNewOpenAITTSProvider_UsesConfiguredModel(t *testing.T) {
+ t.Parallel()
+
+ provider := NewOpenAITTSProvider("key", "https://api.xiaomimimo.com/v1", "", "mimo-v2-tts")
+ if provider.model != "mimo-v2-tts" {
+ t.Fatalf("model mismatch: got %q, want %q", provider.model, "mimo-v2-tts")
+ }
+ if provider.apiBase != "https://api.xiaomimimo.com/v1/audio/speech" {
+ t.Fatalf("apiBase mismatch: got %q", provider.apiBase)
+ }
+}
+
+func TestDetectTTS_UsesMimoProviderForMimoModels(t *testing.T) {
+ t.Parallel()
+
+ provider := DetectTTS(&config.Config{
+ Voice: config.VoiceConfig{TTSModelName: "mimo-tts"},
+ ModelList: []*config.ModelConfig{
+ {
+ ModelName: "mimo-tts",
+ Model: "mimo/mimo-v2-tts",
+ APIKeys: config.SimpleSecureStrings("sk-mimo"),
+ },
+ },
+ })
+
+ ttsProvider, ok := provider.(*MimoTTSProvider)
+ if !ok {
+ t.Fatalf("DetectTTS() type = %T, want *MimoTTSProvider", provider)
+ }
+ if ttsProvider.model != "mimo-v2-tts" {
+ t.Fatalf("model mismatch: got %q, want %q", ttsProvider.model, "mimo-v2-tts")
+ }
+ if ttsProvider.apiBase != "https://api.xiaomimimo.com/v1/chat/completions" {
+ t.Fatalf("apiBase mismatch: got %q", ttsProvider.apiBase)
+ }
+}
+
+type stubTTSProvider struct {
+ name string
+}
+
+func (s stubTTSProvider) Name() string {
+ return s.name
+}
+
+func (s stubTTSProvider) Synthesize(ctx context.Context, text string) (io.ReadCloser, error) {
+ return io.NopCloser(strings.NewReader("audio")), nil
+}
+
+func TestSynthesizeAndStore_UsesOggMetadataByDefault(t *testing.T) {
+ t.Parallel()
+
+ store := media.NewFileMediaStore()
+ ref, err := SynthesizeAndStore(
+ context.Background(),
+ stubTTSProvider{name: "openai-tts"},
+ store,
+ "hello",
+ "",
+ "discord",
+ "chat123",
+ )
+ if err != nil {
+ t.Fatalf("SynthesizeAndStore failed: %v", err)
+ }
+
+ path, meta, err := store.ResolveWithMeta(ref)
+ if err != nil {
+ t.Fatalf("ResolveWithMeta failed: %v", err)
+ }
+ if meta.ContentType != "audio/ogg" {
+ t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/ogg")
+ }
+ if filepath.Ext(path) != ".ogg" {
+ t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".ogg")
+ }
+ if filepath.Ext(meta.Filename) != ".ogg" {
+ t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".ogg")
+ }
+}
+
+func TestSynthesizeAndStore_UsesMp3MetadataForMimo(t *testing.T) {
+ t.Parallel()
+
+ store := media.NewFileMediaStore()
+ ref, err := SynthesizeAndStore(
+ context.Background(),
+ stubTTSProvider{name: "mimo-tts"},
+ store,
+ "hello",
+ "",
+ "discord",
+ "chat123",
+ )
+ if err != nil {
+ t.Fatalf("SynthesizeAndStore failed: %v", err)
+ }
+
+ path, meta, err := store.ResolveWithMeta(ref)
+ if err != nil {
+ t.Fatalf("ResolveWithMeta failed: %v", err)
+ }
+ if meta.ContentType != "audio/mpeg" {
+ t.Fatalf("ContentType = %q, want %q", meta.ContentType, "audio/mpeg")
+ }
+ if filepath.Ext(path) != ".mp3" {
+ t.Fatalf("stored file extension = %q, want %q", filepath.Ext(path), ".mp3")
+ }
+ if filepath.Ext(meta.Filename) != ".mp3" {
+ t.Fatalf("filename extension = %q, want %q", filepath.Ext(meta.Filename), ".mp3")
+ }
+}
diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go
index 427d20779..7c93c2d18 100644
--- a/pkg/channels/discord/discord.go
+++ b/pkg/channels/discord/discord.go
@@ -3,6 +3,7 @@ package discord
import (
"context"
"fmt"
+ "io"
"net/http"
"net/url"
"os"
@@ -14,6 +15,8 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/gorilla/websocket"
+ "github.com/sipeed/picoclaw/pkg/audio"
+ "github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -42,6 +45,15 @@ type DiscordChannel struct {
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
botUserID string // stored for mention checking
+ bus *bus.MessageBus
+ tts tts.TTSProvider
+ voiceMu sync.RWMutex
+ voiceSSRC map[string]map[uint32]string // guildID -> ssrc -> userID
+
+ // TTS interruption: cancel active playback when user speaks
+ ttsMu sync.Mutex
+ cancelTTS context.CancelFunc
+ ttsPlayID uint64
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -73,6 +85,8 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
config: cfg,
ctx: context.Background(),
typingStop: make(map[string]chan struct{}),
+ bus: bus,
+ voiceSSRC: make(map[string]map[uint32]string),
}, nil
}
@@ -90,6 +104,8 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
c.session.AddHandler(c.handleMessage)
+ go c.listenVoiceControl(c.ctx)
+
if err := c.session.Open(); err != nil {
return fmt.Errorf("failed to open discord session: %w", err)
}
@@ -142,6 +158,25 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]s
return nil, nil
}
+ if c.tts != nil {
+ if ch, err := c.session.State.Channel(channelID); err == nil && ch.GuildID != "" {
+ if vc, ok := c.session.VoiceConnections[ch.GuildID]; ok && vc != nil {
+ // Cancel any previous TTS playback
+ c.ttsMu.Lock()
+ if c.cancelTTS != nil {
+ c.cancelTTS()
+ }
+ ttsCtx, ttsCancel := context.WithCancel(c.ctx)
+ c.ttsPlayID++
+ playID := c.ttsPlayID
+ c.cancelTTS = ttsCancel
+ c.ttsMu.Unlock()
+
+ go c.playTTS(ttsCtx, vc, msg.Content, playID)
+ }
+ }
+ }
+
msgID, err := c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID)
if err != nil {
return nil, err
@@ -359,6 +394,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
+ if c.handleVoiceCommand(s, m) {
+ return
+ }
+
content := m.Content
// In guild (group) channels, apply unified group trigger filtering
@@ -642,3 +681,134 @@ func (c *DiscordChannel) stripBotMention(text string) string {
text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
return strings.TrimSpace(text)
}
+
+func (c *DiscordChannel) listenVoiceControl(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case ctrl, ok := <-c.bus.VoiceControlsChan():
+ if !ok {
+ return
+ }
+ if ctrl.Type == "command" && ctrl.Action == "leave" {
+ if strings.HasPrefix(ctrl.SessionID, "discord_vc_") {
+ guildID := strings.TrimPrefix(ctrl.SessionID, "discord_vc_")
+ vc, exists := c.session.VoiceConnections[guildID]
+ if exists && vc != nil {
+ vc.Disconnect(ctx)
+ }
+ }
+ }
+ }
+ }
+}
+
+func (c *DiscordChannel) playTTS(ctx context.Context, vc *discordgo.VoiceConnection, text string, playID uint64) {
+ // Capture the cancel func associated with this playback (if any).
+ // Clear cancelTTS when playback finishes (normal or interrupted),
+ // but only if it still refers to this playback's cancel func.
+ defer func() {
+ c.ttsMu.Lock()
+ if c.ttsPlayID == playID {
+ c.cancelTTS = nil
+ }
+ c.ttsMu.Unlock()
+ }()
+
+ sentences := audio.SplitSentences(text)
+ if len(sentences) == 0 {
+ return
+ }
+
+ logger.InfoCF("discord", "Starting streamed TTS", map[string]any{"sentences": len(sentences)})
+
+ // Pipeline: prefetch next sentence's audio while playing current
+ type ttResult struct {
+ stream io.ReadCloser
+ err error
+ }
+
+ var prefetch chan ttResult
+
+ // Ensure any in-flight prefetch is drained on exit to prevent stream leaks,
+ // but avoid blocking indefinitely if the prefetch goroutine is stuck or never sends.
+ defer func() {
+ if prefetch != nil {
+ select {
+ case result := <-prefetch:
+ if result.stream != nil {
+ result.stream.Close()
+ }
+ case <-time.After(100 * time.Millisecond):
+ // Timed out waiting for a prefetched result; avoid blocking on exit.
+ }
+ }
+ }()
+
+ for i, sentence := range sentences {
+ // Check for cancellation (interruption)
+ select {
+ case <-ctx.Done():
+ logger.InfoCF("discord", "TTS interrupted", map[string]any{"at_sentence": i})
+ return
+ default:
+ }
+
+ // Start prefetching the NEXT sentence while we process the current one
+ var nextPrefetch chan ttResult
+ if i+1 < len(sentences) {
+ nextPrefetch = make(chan ttResult, 1)
+ nextSentence := sentences[i+1]
+ go func() {
+ s, e := c.tts.Synthesize(ctx, nextSentence)
+ nextPrefetch <- ttResult{s, e}
+ }()
+ }
+
+ // Get the current sentence's audio
+ var stream io.ReadCloser
+ var err error
+
+ if prefetch != nil {
+ // Use prefetched result from previous iteration, but be responsive to cancellation.
+ var result ttResult
+ select {
+ case result = <-prefetch:
+ stream, err = result.stream, result.err
+ case <-ctx.Done():
+ // Context canceled while waiting for prefetched audio; abort playback.
+ logger.InfoCF(
+ "discord",
+ "TTS interrupted while waiting for prefetched audio",
+ map[string]any{"at_sentence": i},
+ )
+ return
+ }
+ } else {
+ // First sentence: synthesize directly
+ stream, err = c.tts.Synthesize(ctx, sentence)
+ }
+
+ if err != nil {
+ if stream != nil {
+ stream.Close()
+ }
+ logger.ErrorCF("discord", "TTS synthesize failed", map[string]any{"error": err.Error(), "sentence": i})
+ prefetch = nextPrefetch
+ continue
+ }
+
+ if err := streamOggOpusToDiscord(ctx, vc, stream); err != nil {
+ logger.ErrorCF("discord", "TTS playback failed", map[string]any{"error": err.Error(), "sentence": i})
+ }
+ stream.Close()
+
+ prefetch = nextPrefetch
+ }
+}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *DiscordChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go
index 15a539804..8381dc9e9 100644
--- a/pkg/channels/discord/init.go
+++ b/pkg/channels/discord/init.go
@@ -1,6 +1,7 @@
package discord
import (
+ "github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -8,6 +9,10 @@ import (
func init() {
channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
- return NewDiscordChannel(cfg.Channels.Discord, b)
+ ch, err := NewDiscordChannel(cfg.Channels.Discord, b)
+ if err == nil {
+ ch.tts = tts.DetectTTS(cfg)
+ }
+ return ch, err
})
}
diff --git a/pkg/channels/discord/voice.go b/pkg/channels/discord/voice.go
new file mode 100644
index 000000000..554b8ae71
--- /dev/null
+++ b/pkg/channels/discord/voice.go
@@ -0,0 +1,314 @@
+package discord
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/bwmarrin/discordgo"
+
+ "github.com/sipeed/picoclaw/pkg/audio"
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+func (c *DiscordChannel) setVoiceUserID(guildID string, ssrc uint32, userID string) {
+ if userID == "" {
+ return
+ }
+
+ c.voiceMu.Lock()
+ defer c.voiceMu.Unlock()
+
+ ssrcMap, ok := c.voiceSSRC[guildID]
+ if !ok {
+ ssrcMap = make(map[uint32]string)
+ c.voiceSSRC[guildID] = ssrcMap
+ }
+ ssrcMap[ssrc] = userID
+}
+
+func (c *DiscordChannel) voiceUserID(guildID string, ssrc uint32) string {
+ c.voiceMu.RLock()
+ defer c.voiceMu.RUnlock()
+
+ ssrcMap, ok := c.voiceSSRC[guildID]
+ if !ok {
+ return ""
+ }
+ return ssrcMap[ssrc]
+}
+
+func (c *DiscordChannel) handleVoiceCommand(s *discordgo.Session, m *discordgo.MessageCreate) bool {
+ if m.Content == "!vc join" {
+ vs, err := s.State.VoiceState(m.GuildID, m.Author.ID)
+ if err != nil || vs == nil {
+ if _, sendErr := s.ChannelMessageSend(
+ m.ChannelID,
+ "You need to be in a voice channel first!",
+ ); sendErr != nil {
+ logger.InfoCF("discord", "Failed to send voice channel requirement message", map[string]any{
+ "channel": m.ChannelID,
+ "error": sendErr,
+ })
+ }
+ return true
+ }
+
+ logger.InfoCF("discord", "Joining voice channel", map[string]any{"channel": vs.ChannelID})
+ vc, err := s.ChannelVoiceJoin(c.ctx, m.GuildID, vs.ChannelID, false, false)
+ if err != nil {
+ if _, sendErr := s.ChannelMessageSend(
+ m.ChannelID,
+ fmt.Sprintf("Failed to join voice channel: %v", err),
+ ); sendErr != nil {
+ logger.InfoCF("discord", "Failed to send voice join error message", map[string]any{
+ "channel": m.ChannelID,
+ "error": sendErr,
+ })
+ }
+ return true
+ }
+
+ go c.receiveVoice(vc, m.GuildID, m.ChannelID)
+ if _, sendErr := s.ChannelMessageSend(
+ m.ChannelID,
+ "Joined Voice Channel! Listening for audio...",
+ ); sendErr != nil {
+ logger.InfoCF("discord", "Failed to send voice join success message", map[string]any{
+ "channel": m.ChannelID,
+ "error": sendErr,
+ })
+ }
+ return true
+ } else if m.Content == "!vc leave" {
+ vc, exists := s.VoiceConnections[m.GuildID]
+ if exists && vc != nil {
+ if err := vc.Disconnect(c.ctx); err != nil {
+ logger.InfoCF("discord", "Failed to disconnect from voice channel", map[string]any{
+ "guild": m.GuildID,
+ "error": err,
+ })
+ }
+ if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Left Voice Channel."); sendErr != nil {
+ logger.InfoCF("discord", "Failed to send voice leave success message", map[string]any{
+ "channel": m.ChannelID,
+ "error": sendErr,
+ })
+ }
+ } else {
+ if _, sendErr := s.ChannelMessageSend(m.ChannelID, "Not in a voice channel."); sendErr != nil {
+ logger.InfoCF("discord", "Failed to send voice not-in-channel message", map[string]any{
+ "channel": m.ChannelID,
+ "error": sendErr,
+ })
+ }
+ }
+ return true
+ }
+ return false
+}
+
+func VoiceReceiveActive(vc *discordgo.VoiceConnection) bool {
+ return vc != nil && vc.OpusRecv != nil
+}
+
+func streamOggOpusToDiscord(ctx context.Context, vc *discordgo.VoiceConnection, r io.Reader) (retErr error) {
+ // Recover from panic if vc.OpusSend is closed mid-send (e.g. on disconnect)
+ defer func() {
+ if rec := recover(); rec != nil {
+ retErr = fmt.Errorf("voice connection closed during playback")
+ logger.RecoverPanicNoExit(rec)
+ }
+ }()
+
+ // Wait for the speaking transition to register
+ vc.Speaking(true)
+ defer vc.Speaking(false)
+
+ return audio.DecodeOggOpus(r, func(frame []byte) error {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case vc.OpusSend <- frame:
+ return nil
+ }
+ })
+}
+
+func (c *DiscordChannel) receiveVoice(vc *discordgo.VoiceConnection, guildID string, chatID string) {
+ logger.InfoCF("discord", "Started listening for voice", map[string]any{"guild": guildID})
+
+ vc.AddHandler(func(_ *discordgo.VoiceConnection, vs *discordgo.VoiceSpeakingUpdate) {
+ if vs == nil {
+ return
+ }
+ c.setVoiceUserID(guildID, uint32(vs.SSRC), vs.UserID)
+ })
+
+ defer func() {
+ c.voiceMu.Lock()
+ delete(c.voiceSSRC, guildID)
+ c.voiceMu.Unlock()
+ }()
+
+ go func(ctx context.Context, vc *discordgo.VoiceConnection) {
+ // Recover from potential panics if OpusSend is closed mid-send.
+ defer func() {
+ if rec := recover(); rec != nil {
+ logger.WarnCF("discord", "Recovered from panic while sending wake-up frames", map[string]any{
+ "error": rec,
+ "guild": guildID,
+ })
+ }
+ }()
+
+ // If the voice connection or OpusSend are not available, nothing to do.
+ if vc == nil || vc.OpusSend == nil {
+ return
+ }
+
+ time.Sleep(250 * time.Millisecond) // Wait a bit for connection to settle
+
+ // Abort if the context has already been canceled.
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ vc.Speaking(true)
+ defer vc.Speaking(false)
+
+ silenceFrame := []byte{0xF8, 0xFF, 0xFE}
+ for i := 0; i < 5; i++ {
+ select {
+ case <-ctx.Done():
+ return
+ case vc.OpusSend <- silenceFrame:
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+
+ logger.DebugCF("discord", "Sent wake-up silence frames", map[string]any{"guild": guildID})
+ }(c.ctx, vc)
+ sessionID := fmt.Sprintf("discord_vc_%s", guildID)
+
+ c.bus.PublishVoiceControl(c.ctx, bus.VoiceControl{
+ SessionID: sessionID,
+ Type: "state",
+ Action: "listening",
+ })
+
+ var sequence uint64 = 0
+ var interruptCount int
+ var lastInterruptAt time.Time
+
+ for {
+ select {
+ case <-c.ctx.Done():
+ return
+ case p, ok := <-vc.OpusRecv:
+ if !ok {
+ logger.InfoCF("discord", "Voice channel closed", map[string]any{"guild": guildID})
+ // Cancel any TTS that may still be playing
+ c.ttsMu.Lock()
+ if c.cancelTTS != nil {
+ c.cancelTTS()
+ c.cancelTTS = nil
+ }
+ c.ttsMu.Unlock()
+ return
+ }
+
+ if p == nil {
+ logger.DebugCF("discord", "Received nil Opus packet", nil)
+ continue
+ }
+
+ if len(p.Opus) == 0 {
+ logger.DebugCF("discord", "Received empty Opus packet", map[string]any{
+ "seq": p.Sequence,
+ "ssrc": p.SSRC,
+ })
+ continue
+ }
+
+ logger.DebugCF("discord", "Received Opus packet", map[string]any{
+ "seq": p.Sequence,
+ "len": len(p.Opus),
+ "ssrc": p.SSRC,
+ })
+ // Interruption detection: if user sends voice while TTS is playing,
+ // cancel TTS after a short debounce (3 packets in 200ms)
+ now := time.Now()
+ if now.Sub(lastInterruptAt) > 500*time.Millisecond {
+ interruptCount = 0
+ }
+ interruptCount++
+ lastInterruptAt = now
+
+ if interruptCount >= 3 {
+ c.ttsMu.Lock()
+ if c.cancelTTS != nil {
+ c.cancelTTS()
+ c.cancelTTS = nil
+ logger.InfoCF("discord", "TTS interrupted by user voice", nil)
+ }
+ c.ttsMu.Unlock()
+ interruptCount = 0
+ }
+
+ userID := c.voiceUserID(guildID, p.SSRC)
+ if userID == "" {
+ logger.DebugCF("discord", "Dropping voice packet without user mapping", map[string]any{
+ "ssrc": p.SSRC,
+ "guild": guildID,
+ })
+ continue
+ }
+
+ sender := bus.SenderInfo{
+ Platform: "discord",
+ PlatformID: userID,
+ CanonicalID: identity.BuildCanonicalID("discord", userID),
+ }
+ if !c.IsAllowedSender(sender) {
+ logger.DebugCF("discord", "Voice packet rejected by allowlist", map[string]any{
+ "user_id": userID,
+ "guild": guildID,
+ })
+ continue
+ }
+
+ sequence++
+
+ chunk := bus.AudioChunk{
+ SessionID: sessionID,
+ SpeakerID: userID,
+ ChatID: chatID,
+ Channel: "discord",
+ Sequence: sequence,
+ Timestamp: p.Timestamp,
+ SampleRate: 48000,
+ Channels: 2,
+ Format: "opus",
+ Data: p.Opus,
+ }
+
+ ctx, cancel := context.WithTimeout(c.ctx, 100*time.Millisecond)
+ err := c.bus.PublishAudioChunk(ctx, chunk)
+ cancel()
+ if err != nil {
+ logger.ErrorCF("discord", "Failed to publish audio chunk", map[string]any{
+ "guild": guildID,
+ "sessionID": sessionID,
+ "sequence": sequence,
+ "error": err.Error(),
+ })
+ }
+ }
+ }
+}
diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go
index 4952394b7..81238460a 100644
--- a/pkg/channels/feishu/common.go
+++ b/pkg/channels/feishu/common.go
@@ -6,6 +6,8 @@ import (
"strings"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
+
+ "github.com/sipeed/picoclaw/pkg/channels"
)
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
@@ -145,3 +147,8 @@ func extractImageKeysRecursive(v any, feishuKeys, externalURLs *[]string) {
}
}
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *FeishuChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go
index b0853fb8b..cd2abec90 100644
--- a/pkg/channels/line/line.go
+++ b/pkg/channels/line/line.go
@@ -696,3 +696,8 @@ func (c *LINEChannel) downloadContent(messageID, filename string) string {
},
})
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *LINEChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index 7cd93c266..7c4013676 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -444,6 +444,23 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
m.initChannel("irc", "IRC")
}
+ if channels.VK.Enabled && channels.VK.Token.String() != "" && channels.VK.GroupID != 0 {
+ m.initChannel("vk", "VK")
+ }
+
+ if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
+ hasValidTarget := false
+ for _, target := range channels.TeamsWebhook.Webhooks {
+ if target.WebhookURL.String() != "" {
+ hasValidTarget = true
+ break
+ }
+ }
+ if hasValidTarget {
+ m.initChannel("teams_webhook", "Teams Webhook")
+ }
+ }
+
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
"enabled_channels": len(m.channels),
})
diff --git a/pkg/channels/manager_channel.go b/pkg/channels/manager_channel.go
index b1c8c25e0..b54facda4 100644
--- a/pkg/channels/manager_channel.go
+++ b/pkg/channels/manager_channel.go
@@ -62,6 +62,13 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
value["app_secret"] = ch.Feishu.AppSecret.String()
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
value["verification_token"] = ch.Feishu.VerificationToken.String()
+ case "teams_webhook":
+ // Expose webhook URLs for hash computation (they contain secrets)
+ webhooks := make(map[string]string)
+ for name, target := range ch.TeamsWebhook.Webhooks {
+ webhooks[name] = target.WebhookURL.String()
+ }
+ value["webhooks"] = webhooks
}
}
@@ -166,4 +173,13 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
}
+ if newcfg.TeamsWebhook.Enabled {
+ // Copy SecureString webhook URLs from old config
+ for name, oldTarget := range old.TeamsWebhook.Webhooks {
+ if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
+ newTarget.WebhookURL = oldTarget.WebhookURL
+ newcfg.TeamsWebhook.Webhooks[name] = newTarget
+ }
+ }
+ }
}
diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go
index 29219679d..9819ac3e9 100644
--- a/pkg/channels/manager_test.go
+++ b/pkg/channels/manager_test.go
@@ -19,6 +19,8 @@ import (
type mockChannel struct {
BaseChannel
sendFn func(ctx context.Context, msg bus.OutboundMessage) error
+ startFn func(ctx context.Context) error
+ stopFn func(ctx context.Context) error
sentMessages []bus.OutboundMessage
placeholdersSent int
editedMessages int
@@ -33,8 +35,19 @@ func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri
return nil, m.sendFn(ctx, msg)
}
-func (m *mockChannel) Start(ctx context.Context) error { return nil }
-func (m *mockChannel) Stop(ctx context.Context) error { return nil }
+func (m *mockChannel) Start(ctx context.Context) error {
+ if m.startFn != nil {
+ return m.startFn(ctx)
+ }
+ return nil
+}
+
+func (m *mockChannel) Stop(ctx context.Context) error {
+ if m.stopFn != nil {
+ return m.stopFn(ctx)
+ }
+ return nil
+}
func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
m.placeholdersSent++
@@ -86,6 +99,101 @@ func newTestManager() *Manager {
return &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
+ bus: bus.NewMessageBus(),
+ }
+}
+
+func TestStartAll_AllChannelsFail_ReturnsJoinedError(t *testing.T) {
+ m := newTestManager()
+ errA := errors.New("channel-a start failed")
+ errB := errors.New("channel-b start failed")
+
+ m.channels["a"] = &mockChannel{
+ startFn: func(_ context.Context) error { return errA },
+ }
+ m.channels["b"] = &mockChannel{
+ startFn: func(_ context.Context) error { return errB },
+ }
+
+ err := m.StartAll(t.Context())
+ if err == nil {
+ t.Fatal("expected StartAll to fail when all channels fail")
+ }
+ if !strings.Contains(err.Error(), "failed to start any enabled channels") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !errors.Is(err, errA) {
+ t.Fatalf("expected error to wrap errA, got: %v", err)
+ }
+ if !errors.Is(err, errB) {
+ t.Fatalf("expected error to wrap errB, got: %v", err)
+ }
+ if len(m.workers) != 0 {
+ t.Fatalf("expected no workers on full startup failure, got %d", len(m.workers))
+ }
+ if m.dispatchTask != nil {
+ t.Fatal("expected dispatch task to be cleared on full startup failure")
+ }
+}
+
+func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
+ m := newTestManager()
+ errBad := errors.New("bad channel start failed")
+ processed := make(chan struct{}, 1)
+
+ m.channels["good"] = &mockChannel{
+ sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
+ if msg.Channel == "good" {
+ select {
+ case processed <- struct{}{}:
+ default:
+ }
+ }
+ return nil
+ },
+ }
+ m.channels["bad"] = &mockChannel{
+ startFn: func(_ context.Context) error { return errBad },
+ }
+
+ err := m.StartAll(t.Context())
+ if err != nil {
+ t.Fatalf("expected StartAll to succeed with partial channel failures, got: %v", err)
+ }
+ if len(m.workers) != 1 {
+ t.Fatalf("expected exactly 1 active worker, got %d", len(m.workers))
+ }
+ if _, ok := m.workers["good"]; !ok {
+ t.Fatal("expected worker for successful channel 'good'")
+ }
+ if _, ok := m.workers["bad"]; ok {
+ t.Fatal("did not expect worker for failed channel 'bad'")
+ }
+ if m.dispatchTask == nil {
+ t.Fatal("expected dispatch task to run when at least one channel starts")
+ }
+
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer pubCancel()
+ if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
+ Channel: "good",
+ ChatID: "chat-1",
+ Content: "hello",
+ }); err != nil {
+ t.Fatalf("PublishOutbound() error = %v", err)
+ }
+
+ select {
+ case <-processed:
+ // worker processed outbound message as expected
+ case <-time.After(2 * time.Second):
+ t.Fatal("expected successful channel worker to process outbound message")
+ }
+
+ stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer stopCancel()
+ if err := m.StopAll(stopCtx); err != nil {
+ t.Fatalf("StopAll() error = %v", err)
}
}
diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go
index 431fc5dc8..406a9e3cc 100644
--- a/pkg/channels/matrix/matrix.go
+++ b/pkg/channels/matrix/matrix.go
@@ -1300,3 +1300,8 @@ func stripUserMentionWithRegexp(text string, userID id.UserID, mentionR *regexp.
cleaned = strings.TrimLeft(cleaned, ",:; ")
return strings.TrimSpace(cleaned)
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *MatrixChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go
index 4f8dff234..c3a6f119f 100644
--- a/pkg/channels/onebot/onebot.go
+++ b/pkg/channels/onebot/onebot.go
@@ -1117,3 +1117,8 @@ func truncate(s string, n int) string {
}
return string(runes[:n]) + "..."
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *OneBotChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/pico/client_test.go b/pkg/channels/pico/client_test.go
index 7c5a62801..b40606647 100644
--- a/pkg/channels/pico/client_test.go
+++ b/pkg/channels/pico/client_test.go
@@ -262,3 +262,57 @@ func TestSend_ClosedConnection(t *testing.T) {
ch.Stop(ctx)
}
+
+func TestParseInlineImageMedia_Valid(t *testing.T) {
+ media, err := parseInlineImageMedia(map[string]any{
+ "media": []any{
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
+ },
+ })
+ if err != nil {
+ t.Fatalf("parseInlineImageMedia() error = %v", err)
+ }
+ if len(media) != 1 {
+ t.Fatalf("len(media) = %d, want 1", len(media))
+ }
+}
+
+func TestPicoChannel_HandleMessageSend_AllowsMediaOnly(t *testing.T) {
+ mb := bus.NewMessageBus()
+ ch, err := NewPicoChannel(config.PicoConfig{
+ Token: *config.NewSecureString("test-token"),
+ }, mb)
+ if err != nil {
+ t.Fatalf("NewPicoChannel() error = %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ if err := ch.Start(ctx); err != nil {
+ t.Fatalf("Start() error = %v", err)
+ }
+ defer ch.Stop(ctx)
+
+ pc := &picoConn{id: "conn-1", sessionID: "sess-1"}
+ ch.handleMessageSend(pc, PicoMessage{
+ ID: "msg-1",
+ Payload: map[string]any{
+ "media": []any{
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X2ioAAAAASUVORK5CYII=",
+ },
+ },
+ })
+
+ select {
+ case msg := <-mb.InboundChan():
+ if msg.Content != "" {
+ t.Fatalf("msg.Content = %q, want empty", msg.Content)
+ }
+ if len(msg.Media) != 1 || !strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
+ t.Fatalf("msg.Media = %#v, want inline image payload", msg.Media)
+ }
+ case <-ctx.Done():
+ t.Fatal("timed out waiting for inbound media message")
+ }
+}
diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go
index 17fb12d2b..3f8ba8643 100644
--- a/pkg/channels/pico/protocol.go
+++ b/pkg/channels/pico/protocol.go
@@ -39,19 +39,18 @@ func newMessage(msgType string, payload map[string]any) PicoMessage {
}
}
-// newError creates an error PicoMessage.
-func newError(code, message string) PicoMessage {
- return newMessage(TypeError, map[string]any{
+func newErrorWithPayload(code, message string, extra map[string]any) PicoMessage {
+ payload := map[string]any{
"code": code,
"message": message,
- })
-}
-
-func newErrorWithPayload(code, message string, payload map[string]any) PicoMessage {
- if payload == nil {
- payload = map[string]any{}
}
- payload["code"] = code
- payload["message"] = message
+ for key, value := range extra {
+ payload[key] = value
+ }
return newMessage(TypeError, payload)
}
+
+// newError creates an error PicoMessage.
+func newError(code, message string) PicoMessage {
+ return newErrorWithPayload(code, message, nil)
+}
diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go
index aa78d8e85..b274ea544 100644
--- a/pkg/channels/qq/qq.go
+++ b/pkg/channels/qq/qq.go
@@ -1003,3 +1003,8 @@ func sanitizeURLs(text string) string {
return scheme + domain + path
})
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *QQChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/teams_webhook/init.go b/pkg/channels/teams_webhook/init.go
new file mode 100644
index 000000000..fca960039
--- /dev/null
+++ b/pkg/channels/teams_webhook/init.go
@@ -0,0 +1,13 @@
+package teamswebhook
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
+ })
+}
diff --git a/pkg/channels/teams_webhook/teams_webhook.go b/pkg/channels/teams_webhook/teams_webhook.go
new file mode 100644
index 000000000..fa7762a3e
--- /dev/null
+++ b/pkg/channels/teams_webhook/teams_webhook.go
@@ -0,0 +1,422 @@
+package teamswebhook
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
+
+ goteamsnotify "github.com/atc0005/go-teams-notify/v2"
+ "github.com/atc0005/go-teams-notify/v2/adaptivecard"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// statusCodeRe extracts HTTP status codes from error messages like "401 Unauthorized".
+var statusCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
+
+// markdownTableRe matches a markdown table block (header + separator + rows).
+// It captures the entire table including all rows.
+var markdownTableRe = regexp.MustCompile(`(?m)^(\|[^\n]+\|)\n(\|[-:\|\s]+\|)\n((?:\|[^\n]+\|\n?)+)`)
+
+// teamsMessageSender abstracts the Teams client for testability.
+type teamsMessageSender interface {
+ SendWithContext(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
+}
+
+// classifyTeamsError extracts HTTP status code from error message and classifies it.
+// The go-teams-notify library returns errors like "error on notification: 401 Unauthorized, ...".
+// This allows proper retry behavior: 4xx errors are permanent, 5xx are temporary.
+func classifyTeamsError(err error) error {
+ if err == nil {
+ return nil
+ }
+ errMsg := err.Error()
+ if matches := statusCodeRe.FindStringSubmatch(errMsg); len(matches) > 1 {
+ if statusCode, parseErr := strconv.Atoi(matches[1]); parseErr == nil {
+ return channels.ClassifySendError(statusCode, err)
+ }
+ }
+ // Fallback: treat as temporary network error (retryable)
+ return channels.ClassifyNetError(err)
+}
+
+// TeamsWebhookChannel is an output-only channel that sends messages
+// to Microsoft Teams via Power Automate workflow webhooks.
+// Multiple webhook targets can be configured and selected via ChatID.
+type TeamsWebhookChannel struct {
+ *channels.BaseChannel
+ config config.TeamsWebhookConfig
+ client teamsMessageSender
+}
+
+// NewTeamsWebhookChannel creates a new Teams webhook channel.
+func NewTeamsWebhookChannel(
+ cfg config.TeamsWebhookConfig,
+ bus *bus.MessageBus,
+) (*TeamsWebhookChannel, error) {
+ if len(cfg.Webhooks) == 0 {
+ return nil, fmt.Errorf("teams_webhook: at least one webhook target is required")
+ }
+
+ // Require "default" webhook target
+ if _, hasDefault := cfg.Webhooks["default"]; !hasDefault {
+ return nil, fmt.Errorf("teams_webhook: a 'default' webhook target is required")
+ }
+
+ // Validate all webhook targets have valid HTTPS URLs
+ for name, target := range cfg.Webhooks {
+ webhookURL := target.WebhookURL.String()
+ if webhookURL == "" {
+ return nil, fmt.Errorf("teams_webhook: webhook %q has empty webhook_url", name)
+ }
+ parsed, err := url.Parse(webhookURL)
+ if err != nil {
+ return nil, fmt.Errorf("teams_webhook: webhook %q has invalid URL: %w", name, err)
+ }
+ if !strings.EqualFold(parsed.Scheme, "https") {
+ return nil, fmt.Errorf("teams_webhook: webhook %q must use HTTPS (got %q)", name, parsed.Scheme)
+ }
+ }
+
+ base := channels.NewBaseChannel(
+ "teams_webhook",
+ cfg,
+ bus,
+ []string{
+ "*",
+ }, // Output-only channel; "*" suppresses misleading "allows EVERYONE" audit warning
+ channels.WithMaxMessageLength(24000), // Power Automate webhook payload limit is 28KB
+ )
+
+ client := goteamsnotify.NewTeamsClient()
+
+ return &TeamsWebhookChannel{
+ BaseChannel: base,
+ config: cfg,
+ client: client,
+ }, nil
+}
+
+// Start initializes the channel. For output-only channels, this is a no-op.
+func (c *TeamsWebhookChannel) Start(ctx context.Context) error {
+ targets := make([]string, 0, len(c.config.Webhooks))
+ for name := range c.config.Webhooks {
+ targets = append(targets, name)
+ }
+ sort.Strings(targets)
+ logger.InfoCF("teams_webhook", "Starting Teams webhook channel (output-only)", map[string]any{
+ "targets": targets,
+ })
+ c.SetRunning(true)
+ return nil
+}
+
+// Stop shuts down the channel.
+func (c *TeamsWebhookChannel) Stop(ctx context.Context) error {
+ logger.InfoC("teams_webhook", "Stopping Teams webhook channel")
+ c.SetRunning(false)
+ return nil
+}
+
+// Send delivers a message to the specified Teams webhook target.
+// The target is selected by msg.ChatID which must match a key in the webhooks map.
+func (c *TeamsWebhookChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
+ if !c.IsRunning() {
+ return nil, channels.ErrNotRunning
+ }
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
+ // Look up webhook target by ChatID, fall back to "default" if empty or unknown
+ targetName := msg.ChatID
+ if targetName == "" {
+ targetName = "default"
+ }
+
+ target, ok := c.config.Webhooks[targetName]
+ if !ok {
+ // Log warning and fall back to default target
+ logger.WarnCF("teams_webhook", "Unknown target, falling back to default", map[string]any{
+ "requested": msg.ChatID,
+ "using": "default",
+ })
+ target = c.config.Webhooks["default"]
+ }
+
+ // Build an Adaptive Card for rich formatting
+ card, err := c.buildAdaptiveCard(msg, target)
+ if err != nil {
+ return nil, fmt.Errorf("teams_webhook: failed to build card: %w", err)
+ }
+
+ // Create the message with the card
+ teamsMsg, err := adaptivecard.NewMessageFromCard(card)
+ if err != nil {
+ return nil, fmt.Errorf("teams_webhook: failed to create message: %w", err)
+ }
+
+ // Send to Teams
+ if err := c.client.SendWithContext(ctx, target.WebhookURL.String(), teamsMsg); err != nil {
+ // Log without raw error to avoid leaking webhook URL (embedded in net/http errors)
+ logger.ErrorCF("teams_webhook", "Failed to send message to Teams webhook", map[string]any{
+ "target": msg.ChatID,
+ })
+ // Classify error based on status code extracted from error message.
+ // The go-teams-notify library includes status in errors like "401 Unauthorized".
+ // Use ClassifySendError for proper retry behavior (4xx = permanent, 5xx = temporary).
+ classifiedErr := classifyTeamsError(err)
+ return nil, fmt.Errorf("teams_webhook: send failed: %w", classifiedErr)
+ }
+
+ logger.DebugCF("teams_webhook", "Message sent successfully", map[string]any{
+ "target": msg.ChatID,
+ })
+
+ return nil, nil
+}
+
+// buildAdaptiveCard creates a formatted Adaptive Card from the outbound message.
+// It detects markdown tables and converts them to native Adaptive Card Table elements,
+// since TextBlocks only support a limited markdown subset (no tables).
+func (c *TeamsWebhookChannel) buildAdaptiveCard(
+ msg bus.OutboundMessage,
+ target config.TeamsWebhookTarget,
+) (adaptivecard.Card, error) {
+ card := adaptivecard.NewCard()
+ card.Type = adaptivecard.TypeAdaptiveCard
+
+ // Set full width for Teams rendering
+ card.MSTeams.Width = "Full"
+
+ // Add title if configured on the target
+ title := target.Title
+ if title == "" {
+ title = "PicoClaw Notification"
+ }
+
+ titleBlock := adaptivecard.NewTextBlock(title, true)
+ titleBlock.Size = adaptivecard.SizeLarge
+ titleBlock.Weight = adaptivecard.WeightBolder
+ titleBlock.Style = adaptivecard.TextBlockStyleHeading
+
+ if err := card.AddElement(false, titleBlock); err != nil {
+ return card, err
+ }
+
+ content := msg.Content
+ if content == "" {
+ content = "(empty message)"
+ }
+
+ // Split content into text segments and tables
+ // TextBlocks support: bold, italic, bullet/numbered lists, links
+ // TextBlocks do NOT support: headers, tables, images
+ segments := splitContentWithTables(content)
+
+ for _, seg := range segments {
+ if seg.isTable {
+ // Convert markdown table to Adaptive Card Table element
+ tableElement, err := parseMarkdownTable(seg.content)
+ if err != nil {
+ // Fallback: render as preformatted text if parsing fails
+ logger.WarnCF("teams_webhook", "Failed to parse markdown table, using fallback", map[string]any{
+ "error": err.Error(),
+ })
+ block := adaptivecard.NewTextBlock("```\n"+seg.content+"\n```", true)
+ block.Wrap = true
+ if err := card.AddElement(false, block); err != nil {
+ return card, err
+ }
+ continue
+ }
+ if err := card.AddElement(false, tableElement); err != nil {
+ return card, err
+ }
+ } else {
+ // Regular text content
+ text := strings.TrimSpace(seg.content)
+ if text == "" {
+ continue
+ }
+ block := adaptivecard.NewTextBlock(text, true)
+ block.Wrap = true
+ if err := card.AddElement(false, block); err != nil {
+ return card, err
+ }
+ }
+ }
+
+ return card, nil
+}
+
+// contentSegment represents either a text block or a table in the message content.
+type contentSegment struct {
+ content string
+ isTable bool
+}
+
+// splitContentWithTables splits content into alternating text and table segments.
+func splitContentWithTables(content string) []contentSegment {
+ var segments []contentSegment
+
+ matches := markdownTableRe.FindAllStringSubmatchIndex(content, -1)
+ if len(matches) == 0 {
+ // No tables found, return entire content as text
+ return []contentSegment{{content: content, isTable: false}}
+ }
+
+ lastEnd := 0
+ for _, match := range matches {
+ // Text before this table
+ if match[0] > lastEnd {
+ segments = append(segments, contentSegment{
+ content: content[lastEnd:match[0]],
+ isTable: false,
+ })
+ }
+ // The table itself
+ segments = append(segments, contentSegment{
+ content: content[match[0]:match[1]],
+ isTable: true,
+ })
+ lastEnd = match[1]
+ }
+
+ // Text after the last table
+ if lastEnd < len(content) {
+ segments = append(segments, contentSegment{
+ content: content[lastEnd:],
+ isTable: false,
+ })
+ }
+
+ return segments
+}
+
+// parseMarkdownTable converts a markdown table string to an Adaptive Card Table element.
+func parseMarkdownTable(tableStr string) (adaptivecard.Element, error) {
+ lines := strings.Split(strings.TrimSpace(tableStr), "\n")
+ if len(lines) < 2 {
+ return adaptivecard.Element{}, fmt.Errorf("table must have at least header and separator rows")
+ }
+
+ // Track header content length per column for width calculation
+ var headerLengths []int
+
+ // Parse all rows (header + data rows, skip separator)
+ var allRows [][]adaptivecard.TableCell
+ for i, line := range lines {
+ // Skip separator row (contains only |, -, :, and spaces)
+ if i == 1 && isSeparatorRow(line) {
+ continue
+ }
+
+ cells := parseTableRow(line)
+ if len(cells) == 0 {
+ continue
+ }
+
+ var tableCells []adaptivecard.TableCell
+ for _, cellText := range cells {
+ trimmedText := strings.TrimSpace(cellText)
+
+ // Use header row (first row) to determine column widths
+ if i == 0 {
+ headerLengths = append(headerLengths, len(trimmedText))
+ }
+
+ textBlock := adaptivecard.Element{
+ Type: adaptivecard.TypeElementTextBlock,
+ Text: trimmedText,
+ Wrap: true,
+ }
+ cell := adaptivecard.TableCell{
+ Type: adaptivecard.TypeTableCell,
+ Items: []*adaptivecard.Element{&textBlock},
+ }
+ tableCells = append(tableCells, cell)
+ }
+ allRows = append(allRows, tableCells)
+ }
+
+ if len(allRows) == 0 {
+ return adaptivecard.Element{}, fmt.Errorf("no valid rows found in table")
+ }
+
+ // Create table with first row as headers
+ firstRowAsHeaders := true
+ showGridLines := true
+
+ table, err := adaptivecard.NewTableFromTableCells(allRows, 0, firstRowAsHeaders, showGridLines)
+ if err != nil {
+ return adaptivecard.Element{}, fmt.Errorf("failed to create table: %w", err)
+ }
+
+ // Set column widths based on header content length
+ table.Columns = calculateColumnWidths(headerLengths)
+
+ return table, nil
+}
+
+// calculateColumnWidths creates TableColumnDefinition entries with widths
+// proportional to the max content length of each column.
+func calculateColumnWidths(maxLengths []int) []adaptivecard.Column {
+ if len(maxLengths) == 0 {
+ return nil
+ }
+
+ // Use content length as relative weight, with a minimum of 1
+ columns := make([]adaptivecard.Column, len(maxLengths))
+ for i, length := range maxLengths {
+ weight := length
+ if weight < 1 {
+ weight = 1
+ }
+ columns[i] = adaptivecard.Column{
+ Type: "TableColumnDefinition",
+ Width: weight,
+ }
+ }
+
+ return columns
+}
+
+// isSeparatorRow checks if a line is a markdown table separator (e.g., |---|---|).
+func isSeparatorRow(line string) bool {
+ // Remove pipes and spaces, check if only dashes and colons remain
+ cleaned := strings.ReplaceAll(line, "|", "")
+ cleaned = strings.ReplaceAll(cleaned, " ", "")
+ cleaned = strings.ReplaceAll(cleaned, "-", "")
+ cleaned = strings.ReplaceAll(cleaned, ":", "")
+ return cleaned == ""
+}
+
+// parseTableRow extracts cell values from a markdown table row.
+func parseTableRow(line string) []string {
+ // Trim leading/trailing pipes and split by |
+ line = strings.TrimSpace(line)
+ line = strings.TrimPrefix(line, "|")
+ line = strings.TrimSuffix(line, "|")
+
+ if line == "" {
+ return nil
+ }
+
+ parts := strings.Split(line, "|")
+ var cells []string
+ for _, p := range parts {
+ cells = append(cells, strings.TrimSpace(p))
+ }
+ return cells
+}
diff --git a/pkg/channels/teams_webhook/teams_webhook_test.go b/pkg/channels/teams_webhook/teams_webhook_test.go
new file mode 100644
index 000000000..451ba9d18
--- /dev/null
+++ b/pkg/channels/teams_webhook/teams_webhook_test.go
@@ -0,0 +1,583 @@
+package teamswebhook
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ goteamsnotify "github.com/atc0005/go-teams-notify/v2"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// mockTeamsClient implements teamsMessageSender for testing.
+type mockTeamsClient struct {
+ sendFunc func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
+}
+
+func (m *mockTeamsClient) SendWithContext(
+ ctx context.Context,
+ webhookURL string,
+ message goteamsnotify.TeamsMessage,
+) error {
+ if m.sendFunc != nil {
+ return m.sendFunc(ctx, webhookURL, message)
+ }
+ return nil
+}
+
+func TestNewTeamsWebhookChannel(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ // Test missing webhooks
+ _, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: nil,
+ }, msgBus)
+ if err == nil {
+ t.Error("expected error for missing webhooks")
+ }
+
+ // Test missing "default" webhook
+ _, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook"),
+ Title: "Alerts",
+ },
+ },
+ }, msgBus)
+ if err == nil {
+ t.Error("expected error for missing 'default' webhook")
+ }
+
+ // Test empty webhook URL
+ _, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {Title: "Default"},
+ },
+ }, msgBus)
+ if err == nil {
+ t.Error("expected error for empty webhook_url")
+ }
+
+ // Test HTTP URL (should fail, must be HTTPS)
+ _, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("http://example.com/webhook"),
+ Title: "Default",
+ },
+ },
+ }, msgBus)
+ if err == nil {
+ t.Error("expected error for HTTP webhook URL (must be HTTPS)")
+ }
+
+ // Test valid config with HTTPS (must include "default")
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
+ Title: "Default",
+ },
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
+ Title: "Alerts",
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if ch.Name() != "teams_webhook" {
+ t.Errorf("expected name 'teams_webhook', got %q", ch.Name())
+ }
+}
+
+func TestTeamsWebhookChannel_StartStop(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook"),
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ ctx := context.Background()
+
+ if ch.IsRunning() {
+ t.Error("channel should not be running before Start")
+ }
+
+ if err := ch.Start(ctx); err != nil {
+ t.Fatalf("Start failed: %v", err)
+ }
+
+ if !ch.IsRunning() {
+ t.Error("channel should be running after Start")
+ }
+
+ if err := ch.Stop(ctx); err != nil {
+ t.Fatalf("Stop failed: %v", err)
+ }
+
+ if ch.IsRunning() {
+ t.Error("channel should not be running after Stop")
+ }
+}
+
+func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
+ Title: "Default",
+ },
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook"),
+ Title: "Custom Title",
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ target := ch.config.Webhooks["alerts"]
+ msg := bus.OutboundMessage{
+ Content: "Test message content",
+ ChatID: "alerts",
+ }
+
+ card, err := ch.buildAdaptiveCard(msg, target)
+ if err != nil {
+ t.Fatalf("buildAdaptiveCard failed: %v", err)
+ }
+
+ if card.Type != "AdaptiveCard" {
+ t.Errorf("expected card type 'AdaptiveCard', got %q", card.Type)
+ }
+}
+
+func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook"),
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ ctx := context.Background()
+ msg := bus.OutboundMessage{Content: "test", ChatID: "default"}
+
+ _, err = ch.Send(ctx, msg)
+ if err == nil {
+ t.Error("expected error when sending while not running")
+ }
+}
+
+func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
+ tests := []struct {
+ name string
+ chatID string
+ }{
+ {"unknown target falls back to default", "unknown"},
+ {"empty ChatID uses default", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
+ },
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ var sentURL string
+ ch.client = &mockTeamsClient{
+ sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
+ sentURL = webhookURL
+ return nil
+ },
+ }
+
+ ctx := context.Background()
+ _ = ch.Start(ctx)
+ defer ch.Stop(ctx)
+
+ msg := bus.OutboundMessage{Content: "test", ChatID: tt.chatID}
+ _, err = ch.Send(ctx, msg)
+ if err != nil {
+ t.Fatalf("expected success, got error: %v", err)
+ }
+
+ if sentURL != "https://example.com/webhook-default" {
+ t.Errorf("expected default webhook URL, got %q", sentURL)
+ }
+ })
+ }
+}
+
+func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
+ Title: "Default",
+ },
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
+ Title: "Test Alerts",
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Inject mock client
+ var sentURL string
+ ch.client = &mockTeamsClient{
+ sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
+ sentURL = webhookURL
+ return nil
+ },
+ }
+
+ ctx := context.Background()
+ _ = ch.Start(ctx)
+ defer ch.Stop(ctx)
+
+ msg := bus.OutboundMessage{Content: "Hello Teams!", ChatID: "alerts"}
+
+ _, err = ch.Send(ctx, msg)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if sentURL != "https://example.com/webhook-alerts" {
+ t.Errorf("expected webhook URL 'https://example.com/webhook-alerts', got %q", sentURL)
+ }
+}
+
+func TestTeamsWebhookChannel_SendError(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
+ Enabled: true,
+ Webhooks: map[string]config.TeamsWebhookTarget{
+ "default": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
+ },
+ "alerts": {
+ WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
+ },
+ },
+ }, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Inject mock client that returns an error
+ ch.client = &mockTeamsClient{
+ sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
+ return errors.New("error on notification: 401 Unauthorized, forbidden")
+ },
+ }
+
+ ctx := context.Background()
+ _ = ch.Start(ctx)
+ defer ch.Stop(ctx)
+
+ msg := bus.OutboundMessage{Content: "test", ChatID: "alerts"}
+
+ _, err = ch.Send(ctx, msg)
+ if err == nil {
+ t.Error("expected error from failed send")
+ }
+}
+
+func TestSplitContentWithTables(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ wantSegs int
+ wantTbl int // number of table segments
+ }{
+ {
+ name: "no tables",
+ content: "Just some text\nwith multiple lines",
+ wantSegs: 1,
+ wantTbl: 0,
+ },
+ {
+ name: "single table",
+ content: `| Col1 | Col2 |
+|------|------|
+| A | B |
+| C | D |`,
+ wantSegs: 1,
+ wantTbl: 1,
+ },
+ {
+ name: "text before table",
+ content: `Here is some text.
+
+| Col1 | Col2 |
+|------|------|
+| A | B |`,
+ wantSegs: 2,
+ wantTbl: 1,
+ },
+ {
+ name: "text before and after table",
+ content: `Before table.
+
+| Col1 | Col2 |
+|------|------|
+| A | B |
+
+After table.`,
+ wantSegs: 3,
+ wantTbl: 1,
+ },
+ {
+ name: "multiple tables",
+ content: `First table:
+
+| A | B |
+|---|---|
+| 1 | 2 |
+
+Second table:
+
+| X | Y |
+|---|---|
+| 3 | 4 |`,
+ wantSegs: 4,
+ wantTbl: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ segs := splitContentWithTables(tt.content)
+ if len(segs) != tt.wantSegs {
+ t.Errorf("got %d segments, want %d", len(segs), tt.wantSegs)
+ }
+ tableCount := 0
+ for _, s := range segs {
+ if s.isTable {
+ tableCount++
+ }
+ }
+ if tableCount != tt.wantTbl {
+ t.Errorf("got %d tables, want %d", tableCount, tt.wantTbl)
+ }
+ })
+ }
+}
+
+func TestParseMarkdownTable(t *testing.T) {
+ tableStr := `| Name | Value |
+|------|-------|
+| foo | 123 |
+| bar | 456 |`
+
+ elem, err := parseMarkdownTable(tableStr)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if elem.Type != "Table" {
+ t.Errorf("expected type 'Table', got %q", elem.Type)
+ }
+
+ // Should have 3 rows (header + 2 data rows)
+ if len(elem.Rows) != 3 {
+ t.Errorf("expected 3 rows, got %d", len(elem.Rows))
+ }
+
+ // Should have 2 columns with widths based on content length
+ if len(elem.Columns) != 2 {
+ t.Errorf("expected 2 columns, got %d", len(elem.Columns))
+ }
+}
+
+func TestParseMarkdownTableColumnWidths(t *testing.T) {
+ // Column widths are based on HEADER row only:
+ // Col1: "Description" (11 chars)
+ // Col2: "X" (1 char)
+ // Col3: "Amount" (6 chars)
+ tableStr := `| Description | X | Amount |
+|-------------|---|--------|
+| Short | Y | 100 |
+| Longer text | Z | 50 |`
+
+ elem, err := parseMarkdownTable(tableStr)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(elem.Columns) != 3 {
+ t.Fatalf("expected 3 columns, got %d", len(elem.Columns))
+ }
+
+ // Verify column widths are based on header content length
+ w1, ok1 := elem.Columns[0].Width.(int)
+ w2, ok2 := elem.Columns[1].Width.(int)
+ w3, ok3 := elem.Columns[2].Width.(int)
+
+ if !ok1 || !ok2 || !ok3 {
+ t.Fatalf("expected int widths, got types: %T, %T, %T",
+ elem.Columns[0].Width, elem.Columns[1].Width, elem.Columns[2].Width)
+ }
+
+ // Header lengths: "Description" = 11, "X" = 1, "Amount" = 6
+ if w1 != 11 {
+ t.Errorf("expected col1 width 11 (from 'Description'), got %d", w1)
+ }
+ if w2 != 1 {
+ t.Errorf("expected col2 width 1 (from 'X'), got %d", w2)
+ }
+ if w3 != 6 {
+ t.Errorf("expected col3 width 6 (from 'Amount'), got %d", w3)
+ }
+}
+
+func TestCalculateColumnWidths(t *testing.T) {
+ tests := []struct {
+ name string
+ maxLengths []int
+ wantWidths []int
+ }{
+ {
+ name: "equal lengths",
+ maxLengths: []int{10, 10, 10},
+ wantWidths: []int{10, 10, 10},
+ },
+ {
+ name: "varying lengths",
+ maxLengths: []int{5, 20, 10},
+ wantWidths: []int{5, 20, 10},
+ },
+ {
+ name: "zero length gets minimum of 1",
+ maxLengths: []int{0, 5, 0},
+ wantWidths: []int{1, 5, 1},
+ },
+ {
+ name: "empty input",
+ maxLengths: []int{},
+ wantWidths: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cols := calculateColumnWidths(tt.maxLengths)
+
+ if tt.wantWidths == nil {
+ if cols != nil {
+ t.Errorf("expected nil, got %v", cols)
+ }
+ return
+ }
+
+ if len(cols) != len(tt.wantWidths) {
+ t.Fatalf("expected %d columns, got %d", len(tt.wantWidths), len(cols))
+ }
+
+ for i, col := range cols {
+ width, ok := col.Width.(int)
+ if !ok {
+ t.Errorf("column %d: expected int width, got %T", i, col.Width)
+ continue
+ }
+ if width != tt.wantWidths[i] {
+ t.Errorf("column %d: expected width %d, got %d", i, tt.wantWidths[i], width)
+ }
+ if col.Type != "TableColumnDefinition" {
+ t.Errorf("column %d: expected type 'TableColumnDefinition', got %q", i, col.Type)
+ }
+ }
+ })
+ }
+}
+
+func TestParseTableRow(t *testing.T) {
+ tests := []struct {
+ line string
+ want []string
+ }{
+ {"| A | B | C |", []string{"A", "B", "C"}},
+ {"|A|B|C|", []string{"A", "B", "C"}},
+ {"| foo | bar |", []string{"foo", "bar"}},
+ {"", nil},
+ }
+
+ for _, tt := range tests {
+ got := parseTableRow(tt.line)
+ if len(got) != len(tt.want) {
+ t.Errorf("parseTableRow(%q): got %v, want %v", tt.line, got, tt.want)
+ continue
+ }
+ for i := range got {
+ if got[i] != tt.want[i] {
+ t.Errorf("parseTableRow(%q)[%d]: got %q, want %q", tt.line, i, got[i], tt.want[i])
+ }
+ }
+ }
+}
+
+func TestIsSeparatorRow(t *testing.T) {
+ tests := []struct {
+ line string
+ want bool
+ }{
+ {"|---|---|", true},
+ {"| --- | --- |", true},
+ {"|:---|---:|", true},
+ {"| :---: | :---: |", true},
+ {"| A | B |", false},
+ {"| foo | bar |", false},
+ }
+
+ for _, tt := range tests {
+ got := isSeparatorRow(tt.line)
+ if got != tt.want {
+ t.Errorf("isSeparatorRow(%q): got %v, want %v", tt.line, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go
index 464551351..20a659266 100644
--- a/pkg/channels/telegram/telegram.go
+++ b/pkg/channels/telegram/telegram.go
@@ -1190,3 +1190,8 @@ func isPostConnectError(err error) bool {
strings.Contains(msg, "connection closed by foreign host") ||
strings.Contains(msg, "broken pipe")
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *TelegramChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/vk/init.go b/pkg/channels/vk/init.go
new file mode 100644
index 000000000..6a5927a32
--- /dev/null
+++ b/pkg/channels/vk/init.go
@@ -0,0 +1,13 @@
+package vk
+
+import (
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ channels.RegisterFactory("vk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
+ return NewVKChannel(cfg, b)
+ })
+}
diff --git a/pkg/channels/vk/vk.go b/pkg/channels/vk/vk.go
new file mode 100644
index 000000000..bb36da139
--- /dev/null
+++ b/pkg/channels/vk/vk.go
@@ -0,0 +1,282 @@
+package vk
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/SevereCloud/vksdk/v3/api"
+ "github.com/SevereCloud/vksdk/v3/api/params"
+ "github.com/SevereCloud/vksdk/v3/events"
+ "github.com/SevereCloud/vksdk/v3/longpoll-bot"
+ "github.com/SevereCloud/vksdk/v3/object"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/identity"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+type VKChannel struct {
+ *channels.BaseChannel
+ vk *api.VK
+ lp *longpoll.LongPoll
+ config *config.Config
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+func NewVKChannel(cfg *config.Config, bus *bus.MessageBus) (*VKChannel, error) {
+ vkCfg := cfg.Channels.VK
+
+ vk := api.NewVK(vkCfg.Token.String())
+
+ base := channels.NewBaseChannel(
+ "vk",
+ vkCfg,
+ bus,
+ vkCfg.AllowFrom,
+ channels.WithMaxMessageLength(4000),
+ channels.WithGroupTrigger(vkCfg.GroupTrigger),
+ channels.WithReasoningChannelID(vkCfg.ReasoningChannelID),
+ )
+
+ return &VKChannel{
+ BaseChannel: base,
+ vk: vk,
+ config: cfg,
+ }, nil
+}
+
+func (c *VKChannel) Start(ctx context.Context) error {
+ logger.InfoC("vk", "Starting VK bot (Long Poll mode)...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ groupID := c.config.Channels.VK.GroupID
+ if groupID == 0 {
+ c.cancel()
+ return fmt.Errorf("group_id is required for VK bot")
+ }
+
+ lp, err := longpoll.NewLongPoll(c.vk, groupID)
+ if err != nil {
+ c.cancel()
+ return fmt.Errorf("failed to create long poll: %w", err)
+ }
+ c.lp = lp
+
+ lp.MessageNew(func(_ context.Context, obj events.MessageNewObject) {
+ c.handleMessage(obj.Message)
+ })
+
+ c.SetRunning(true)
+
+ logger.InfoCF("vk", "VK bot connected", map[string]any{
+ "group_id": groupID,
+ })
+
+ go func() {
+ if err := lp.Run(); err != nil {
+ logger.ErrorCF("vk", "Long poll failed", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ return nil
+}
+
+func (c *VKChannel) Stop(ctx context.Context) error {
+ logger.InfoC("vk", "Stopping VK bot...")
+ c.SetRunning(false)
+
+ if c.lp != nil {
+ c.lp.Shutdown()
+ }
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ return nil
+}
+
+func (c *VKChannel) handleMessage(msg object.MessagesMessage) {
+ if msg.Action.Type != "" {
+ return
+ }
+
+ if bool(msg.Out) {
+ return
+ }
+
+ peerID := msg.PeerID
+ chatID := strconv.Itoa(peerID)
+
+ fromID := msg.FromID
+ userID := strconv.Itoa(fromID)
+
+ platformID := userID
+ sender := bus.SenderInfo{
+ Platform: "vk",
+ PlatformID: platformID,
+ CanonicalID: identity.BuildCanonicalID("vk", platformID),
+ DisplayName: c.getUserName(fromID),
+ }
+
+ if !c.IsAllowedSender(sender) {
+ logger.DebugCF("vk", "Message from unauthorized user", map[string]any{
+ "peer_id": peerID,
+ })
+ return
+ }
+
+ text := msg.Text
+ if text == "" && len(msg.Attachments) > 0 {
+ text = c.processAttachments(msg.Attachments)
+ }
+
+ if text == "" {
+ return
+ }
+
+ groupTrigger := c.config.Channels.VK.GroupTrigger
+ isGroupChat := peerID != fromID
+
+ if isGroupChat {
+ isMentioned := c.isMentioned(msg)
+ if isMentioned {
+ text = c.stripBotMention(text)
+ }
+ respond, cleaned := c.ShouldRespondInGroup(isMentioned, text)
+ if !respond {
+ return
+ }
+ text = cleaned
+ _ = groupTrigger
+ }
+
+ chatType := "direct"
+ if isGroupChat {
+ chatType = "group"
+ }
+
+ messageID := strconv.Itoa(msg.ConversationMessageID)
+
+ metadata := map[string]string{
+ "user_id": userID,
+ "is_group": fmt.Sprintf("%t", isGroupChat),
+ }
+
+ c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{
+ Channel: "vk",
+ ChatID: chatID,
+ ChatType: chatType,
+ SenderID: userID,
+ MessageID: messageID,
+ Mentioned: isGroupChat && c.isMentioned(msg),
+ Raw: metadata,
+ }, sender)
+}
+
+func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
+ if !c.IsRunning() {
+ return nil, channels.ErrNotRunning
+ }
+
+ peerID, err := strconv.Atoi(msg.ChatID)
+ if err != nil {
+ return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed)
+ }
+
+ if msg.Content == "" {
+ return nil, nil
+ }
+
+ var messageIDs []string
+ chunks := channels.SplitMessage(msg.Content, 4000)
+
+ for _, chunk := range chunks {
+ if chunk == "" {
+ continue
+ }
+
+ b := params.NewMessagesSendBuilder()
+ b.Message(chunk)
+ b.RandomID(0)
+ b.PeerID(peerID)
+
+ if msg.ReplyToMessageID != "" {
+ if replyID, err := strconv.Atoi(msg.ReplyToMessageID); err == nil {
+ b.ReplyTo(replyID)
+ }
+ }
+
+ resp, err := c.vk.MessagesSend(b.Params)
+ if err != nil {
+ logger.ErrorCF("vk", "Failed to send message", map[string]any{
+ "error": err.Error(),
+ "peer_id": peerID,
+ })
+ return messageIDs, fmt.Errorf("failed to send message: %w", err)
+ }
+
+ messageIDs = append(messageIDs, strconv.Itoa(resp))
+ }
+
+ return messageIDs, nil
+}
+
+func (c *VKChannel) isMentioned(msg object.MessagesMessage) bool {
+ return false
+}
+
+func (c *VKChannel) stripBotMention(text string) string {
+ return strings.TrimSpace(text)
+}
+
+func (c *VKChannel) getUserName(userID int) string {
+ users, err := c.vk.UsersGet(api.Params{
+ "user_ids": userID,
+ })
+ if err != nil || len(users) == 0 {
+ return strconv.Itoa(userID)
+ }
+
+ user := users[0]
+ return fmt.Sprintf("%s %s", user.FirstName, user.LastName)
+}
+
+func (c *VKChannel) processAttachments(attachments []object.MessagesMessageAttachment) string {
+ var parts []string
+
+ for _, att := range attachments {
+ switch att.Type {
+ case "photo":
+ parts = append(parts, "[photo]")
+ case "video":
+ parts = append(parts, "[video]")
+ case "audio":
+ parts = append(parts, "[audio]")
+ case "doc":
+ if att.Doc.Title != "" {
+ parts = append(parts, fmt.Sprintf("[document: %s]", att.Doc.Title))
+ } else {
+ parts = append(parts, "[document]")
+ }
+ case "audio_message":
+ parts = append(parts, "[voice]")
+ case "sticker":
+ parts = append(parts, "[sticker]")
+ }
+ }
+
+ return strings.Join(parts, " ")
+}
+
+func (c *VKChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/channels/vk/vk_test.go b/pkg/channels/vk/vk_test.go
new file mode 100644
index 000000000..c7e62ab31
--- /dev/null
+++ b/pkg/channels/vk/vk_test.go
@@ -0,0 +1,260 @@
+package vk
+
+import (
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/channels"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestNewVKChannel(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("missing group_id", func(t *testing.T) {
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error during creation: %v", err)
+ }
+ if ch.Name() != "vk" {
+ t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
+ }
+ if ch.IsRunning() {
+ t.Error("new channel should not be running")
+ }
+ })
+
+ t.Run("valid config with group_id", func(t *testing.T) {
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ GroupID: 123456789,
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if ch.Name() != "vk" {
+ t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
+ }
+ if ch.IsRunning() {
+ t.Error("new channel should not be running")
+ }
+ })
+
+ t.Run("with allow_from", func(t *testing.T) {
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ GroupID: 123456789,
+ AllowFrom: []string{"123456789"},
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !ch.IsAllowedSender(bus.SenderInfo{PlatformID: "123456789"}) {
+ t.Error("user 123456789 should be allowed")
+ }
+ if ch.IsAllowedSender(bus.SenderInfo{PlatformID: "999999999"}) {
+ t.Error("user 999999999 should not be allowed")
+ }
+ })
+
+ t.Run("with group_trigger", func(t *testing.T) {
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ GroupID: 123456789,
+ GroupTrigger: config.GroupTriggerConfig{
+ MentionOnly: false,
+ Prefixes: []string{"/bot", "!bot"},
+ },
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if ch.Name() != "vk" {
+ t.Errorf("Name() = %q, want %q", ch.Name(), "vk")
+ }
+ })
+}
+
+func TestVKChannel_MaxMessageLength(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ GroupID: 123456789,
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ maxLen := ch.MaxMessageLength()
+ if maxLen != 4000 {
+ t.Errorf("MaxMessageLength() = %d, want 4000", maxLen)
+ }
+}
+
+func TestVKChannel_SplitMessage(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ maxLen int
+ want int
+ }{
+ {
+ name: "short message",
+ content: "hello",
+ maxLen: 4000,
+ want: 1,
+ },
+ {
+ name: "exact length",
+ content: string(make([]byte, 4000)),
+ maxLen: 4000,
+ want: 1,
+ },
+ {
+ name: "needs split",
+ content: string(make([]byte, 5000)),
+ maxLen: 4000,
+ want: 2,
+ },
+ {
+ name: "empty message",
+ content: "",
+ maxLen: 4000,
+ want: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := channels.SplitMessage(tt.content, tt.maxLen)
+ if len(got) != tt.want {
+ t.Errorf("SplitMessage() got %d parts, want %d parts", len(got), tt.want)
+ }
+ })
+ }
+}
+
+func TestVKChannel_ProcessAttachments(t *testing.T) {
+ tests := []struct {
+ name string
+ attachments []string
+ want string
+ }{
+ {
+ name: "empty attachments",
+ attachments: []string{},
+ want: "",
+ },
+ {
+ name: "photo attachment",
+ attachments: []string{"photo"},
+ want: "[photo]",
+ },
+ {
+ name: "video attachment",
+ attachments: []string{"video"},
+ want: "[video]",
+ },
+ {
+ name: "audio attachment",
+ attachments: []string{"audio"},
+ want: "[audio]",
+ },
+ {
+ name: "document attachment",
+ attachments: []string{"doc"},
+ want: "[doc]",
+ },
+ {
+ name: "sticker attachment",
+ attachments: []string{"sticker"},
+ want: "[sticker]",
+ },
+ {
+ name: "audio_message attachment",
+ attachments: []string{"audio_message"},
+ want: "[voice]",
+ },
+ {
+ name: "multiple attachments",
+ attachments: []string{"photo", "video", "audio"},
+ want: "[photo] [video] [audio]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var result string
+ for i, att := range tt.attachments {
+ if i > 0 {
+ result += " "
+ }
+ if att == "audio_message" {
+ result += "[voice]"
+ } else {
+ result += "[" + att + "]"
+ }
+ }
+ if result != tt.want {
+ t.Errorf("processAttachments() = %q, want %q", result, tt.want)
+ }
+ })
+ }
+}
+
+func TestVKChannel_VoiceCapabilities(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := &config.Config{
+ Channels: config.ChannelsConfig{
+ VK: config.VKConfig{
+ Enabled: true,
+ Token: *config.NewSecureString("test_token"),
+ GroupID: 123456789,
+ },
+ },
+ }
+ ch, err := NewVKChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ caps := ch.VoiceCapabilities()
+ if !caps.ASR {
+ t.Error("VoiceCapabilities().ASR should be true")
+ }
+ if !caps.TTS {
+ t.Error("VoiceCapabilities().TTS should be true")
+ }
+}
diff --git a/pkg/channels/voice_capabilities.go b/pkg/channels/voice_capabilities.go
new file mode 100644
index 000000000..34fd24269
--- /dev/null
+++ b/pkg/channels/voice_capabilities.go
@@ -0,0 +1,58 @@
+package channels
+
+// VoiceCapabilities describes whether ASR (speech-to-text) and TTS (text-to-speech)
+// are available for a channel under the current configuration.
+type VoiceCapabilities struct {
+ ASR bool
+ TTS bool
+}
+
+// VoiceCapabilityProvider is an optional interface for channels that want to
+// explicitly declare their ASR/TTS support.
+type VoiceCapabilityProvider interface {
+ VoiceCapabilities() VoiceCapabilities
+}
+
+// Deprecated: Channels should implement VoiceCapabilityProvider instead.
+// To be removed once all existing capable channels conform to the interface.
+var asrCapableChannels = map[string]bool{
+ "discord": true,
+ "telegram": true,
+ "matrix": true,
+ "qq": true,
+ "weixin": true,
+ "line": true,
+ "feishu": true,
+ "onebot": true,
+}
+
+// DetectVoiceCapabilities returns ASR/TTS availability for a channel, gated by
+// whether providers are configured.
+func DetectVoiceCapabilities(channelName string, ch Channel, asrAvailable bool, ttsAvailable bool) VoiceCapabilities {
+ if ch == nil {
+ return VoiceCapabilities{}
+ }
+
+ if vcp, ok := ch.(VoiceCapabilityProvider); ok {
+ caps := vcp.VoiceCapabilities()
+ if !asrAvailable {
+ caps.ASR = false
+ }
+ if !ttsAvailable {
+ caps.TTS = false
+ }
+ return caps
+ }
+
+ caps := VoiceCapabilities{}
+ if asrAvailable {
+ caps.ASR = asrCapableChannels[channelName]
+ }
+ if ttsAvailable {
+ if _, ok := ch.(MediaSender); ok {
+ caps.TTS = true
+ }
+ }
+
+ return caps
+}
diff --git a/pkg/channels/weixin/weixin.go b/pkg/channels/weixin/weixin.go
index 5e62a8a3b..58d9369b2 100644
--- a/pkg/channels/weixin/weixin.go
+++ b/pkg/channels/weixin/weixin.go
@@ -414,3 +414,8 @@ func (c *WeixinChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]st
return nil, nil
}
+
+// VoiceCapabilities returns the voice capabilities of the channel.
+func (c *WeixinChannel) VoiceCapabilities() channels.VoiceCapabilities {
+ return channels.VoiceCapabilities{ASR: true, TTS: true}
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 814ed9c4d..4767fcfec 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -802,13 +802,13 @@ type WebToolsConfig struct {
// the client-side web_search tool is hidden to avoid duplicate search surfaces,
// and the provider's built-in search is used instead. Falls back to client-side
// search when the provider does not support native search.
- PreferNative bool `json:"prefer_native" yaml:"-" env:"PICOCLAW_TOOLS_WEB_PREFER_NATIVE"`
+ PreferNative bool `yaml:"-" json:"prefer_native" env:"PICOCLAW_TOOLS_WEB_PREFER_NATIVE"`
// Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
// For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
- Proxy string `json:"proxy,omitempty" yaml:"-" env:"PICOCLAW_TOOLS_WEB_PROXY"`
- FetchLimitBytes int64 `json:"fetch_limit_bytes,omitempty" yaml:"-" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
- Format string `json:"format,omitempty" yaml:"-" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
- PrivateHostWhitelist FlexibleStringSlice `json:"private_host_whitelist,omitempty" yaml:"-" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
+ Proxy string `yaml:"-" json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
+ FetchLimitBytes int64 `yaml:"-" json:"fetch_limit_bytes,omitempty" env:"PICOCLAW_TOOLS_WEB_FETCH_LIMIT_BYTES"`
+ Format string `yaml:"-" json:"format,omitempty" env:"PICOCLAW_TOOLS_WEB_FORMAT"`
+ PrivateHostWhitelist FlexibleStringSlice `yaml:"-" json:"private_host_whitelist,omitempty" env:"PICOCLAW_TOOLS_WEB_PRIVATE_HOST_WHITELIST"`
}
type CronToolsConfig struct {
@@ -987,7 +987,10 @@ func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
- logger.WarnF("config file not found, using default config", map[string]any{"path": path})
+ logger.WarnF(
+ "config file not found, using default config",
+ map[string]any{"path": path},
+ )
return DefaultConfig(), nil
}
logger.Errorf("failed to read config file: %v", err)
@@ -1010,7 +1013,10 @@ func LoadConfig(path string) (*Config, error) {
var cfg *Config
switch versionInfo.Version {
case 0:
- logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.InfoF(
+ "config migrate start",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
// Legacy config (no version field)
v, e := loadConfigV0(data)
if e != nil {
@@ -1018,10 +1024,16 @@ func LoadConfig(path string) (*Config, error) {
}
cfg, e = v.Migrate()
if e != nil {
- logger.ErrorF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.ErrorF(
+ "config migrate fail",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
return nil, e
}
- logger.InfoF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.InfoF(
+ "config migrate success",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
err = makeBackup(path)
if err != nil {
return nil, err
@@ -1029,7 +1041,10 @@ func LoadConfig(path string) (*Config, error) {
// Load existing security config and merge with migrated one to prevent data loss
secErr := loadSecurityConfig(cfg, securityPath(path))
if secErr != nil && !os.IsNotExist(secErr) {
- logger.WarnF("failed to load existing security config during migration", map[string]any{"error": secErr})
+ logger.WarnF(
+ "failed to load existing security config during migration",
+ map[string]any{"error": secErr},
+ )
return nil, fmt.Errorf("failed to load existing security config: %w", secErr)
}
defer func(cfg *Config) {
@@ -1037,7 +1052,10 @@ func LoadConfig(path string) (*Config, error) {
}(cfg)
case 1:
// V1→V2 migration: infer Enabled and migrate channel config fields
- logger.InfoF("config migrate start", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.InfoF(
+ "config migrate start",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
cfg, err = loadConfig(data)
if err != nil {
return nil, err
@@ -1051,7 +1069,10 @@ func LoadConfig(path string) (*Config, error) {
oldCfg := &configV1{Config: *cfg}
cfg, err = oldCfg.Migrate()
if err != nil {
- logger.ErrorF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.ErrorF(
+ "config migrate fail",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
return nil, err
}
@@ -1063,7 +1084,10 @@ func LoadConfig(path string) (*Config, error) {
defer func(cfg *Config) {
_ = SaveConfig(path, cfg)
}(cfg)
- logger.InfoF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion})
+ logger.InfoF(
+ "config migrate success",
+ map[string]any{"from": versionInfo.Version, "to": CurrentVersion},
+ )
case CurrentVersion:
// Current version
cfg, err = loadConfig(data)
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 4b23a10ff..bb90fb2c4 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -177,6 +177,41 @@ func TestAgentConfig_FullParse(t *testing.T) {
}
}
+func TestDefaultConfig_MCPMaxInlineTextChars(t *testing.T) {
+ cfg := DefaultConfig()
+ if cfg.Tools.MCP.GetMaxInlineTextChars() != DefaultMCPMaxInlineTextChars {
+ t.Fatalf(
+ "DefaultConfig().Tools.MCP.GetMaxInlineTextChars() = %d, want %d",
+ cfg.Tools.MCP.GetMaxInlineTextChars(),
+ DefaultMCPMaxInlineTextChars,
+ )
+ }
+}
+
+func TestLoadConfig_MCPMaxInlineTextChars(t *testing.T) {
+ dir := t.TempDir()
+ configPath := filepath.Join(dir, "config.json")
+ raw := `{
+ "tools": {
+ "mcp": {
+ "enabled": true,
+ "max_inline_text_chars": 2048
+ }
+ }
+ }`
+ if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
+ t.Fatalf("WriteFile(configPath): %v", err)
+ }
+
+ cfg, err := LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error: %v", err)
+ }
+ if got := cfg.Tools.MCP.GetMaxInlineTextChars(); got != 2048 {
+ t.Fatalf("cfg.Tools.MCP.GetMaxInlineTextChars() = %d, want 2048", got)
+ }
+}
+
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
jsonData := `{
"agents": {
@@ -253,41 +288,6 @@ func TestAgentConfig_ParsesDispatchRules(t *testing.T) {
}
}
-func TestDefaultConfig_MCPMaxInlineTextChars(t *testing.T) {
- cfg := DefaultConfig()
- if cfg.Tools.MCP.GetMaxInlineTextChars() != DefaultMCPMaxInlineTextChars {
- t.Fatalf(
- "DefaultConfig().Tools.MCP.GetMaxInlineTextChars() = %d, want %d",
- cfg.Tools.MCP.GetMaxInlineTextChars(),
- DefaultMCPMaxInlineTextChars,
- )
- }
-}
-
-func TestLoadConfig_MCPMaxInlineTextChars(t *testing.T) {
- dir := t.TempDir()
- configPath := filepath.Join(dir, "config.json")
- raw := `{
- "tools": {
- "mcp": {
- "enabled": true,
- "max_inline_text_chars": 2048
- }
- }
- }`
- if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
- t.Fatalf("WriteFile(configPath): %v", err)
- }
-
- cfg, err := LoadConfig(configPath)
- if err != nil {
- t.Fatalf("LoadConfig() error: %v", err)
- }
- if got := cfg.Tools.MCP.GetMaxInlineTextChars(); got != 2048 {
- t.Fatalf("cfg.Tools.MCP.GetMaxInlineTextChars() = %d, want 2048", got)
- }
-}
-
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
@@ -366,13 +366,6 @@ func TestDefaultConfig_Channels(t *testing.T) {
}
}
-func TestDefaultConfig_ReadFileMode(t *testing.T) {
- cfg := DefaultConfig()
- if cfg.Tools.ReadFile.EffectiveMode() != ReadFileModeBytes {
- t.Fatalf("expected default read_file mode %q, got %q", ReadFileModeBytes, cfg.Tools.ReadFile.EffectiveMode())
- }
-}
-
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
@@ -1557,6 +1550,42 @@ func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
}
}
+func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
+ dir := t.TempDir()
+ cfgPath := filepath.Join(dir, "config.json")
+
+ cfg := &Config{
+ Version: CurrentVersion,
+ ModelList: []*ModelConfig{
+ {
+ ModelName: "test-model",
+ Model: "openai/test",
+ APIKeys: SimpleSecureStrings("sk-test"),
+ CustomHeaders: map[string]string{"X-Source": "coding-plan", "X-Agent": "openclaw"},
+ },
+ },
+ }
+
+ if err := SaveConfig(cfgPath, cfg); err != nil {
+ t.Fatalf("SaveConfig error: %v", err)
+ }
+
+ loaded, err := LoadConfig(cfgPath)
+ if err != nil {
+ t.Fatalf("LoadConfig error: %v", err)
+ }
+
+ if loaded.ModelList[0].CustomHeaders == nil {
+ t.Fatal("CustomHeaders should not be nil after round-trip")
+ }
+ if got := loaded.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
+ t.Errorf("CustomHeaders[X-Source] = %q, want coding-plan", got)
+ }
+ if got := loaded.ModelList[0].CustomHeaders["X-Agent"]; got != "openclaw" {
+ t.Errorf("CustomHeaders[X-Agent] = %q, want openclaw", got)
+ }
+}
+
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
cfg := DefaultConfig()
diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go
index 64aed5e8c..8be84bdf6 100644
--- a/pkg/gateway/gateway.go
+++ b/pkg/gateway/gateway.go
@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
+ "sort"
"strings"
"sync"
"sync/atomic"
@@ -13,6 +14,8 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/audio/asr"
+ "github.com/sipeed/picoclaw/pkg/audio/tts"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
_ "github.com/sipeed/picoclaw/pkg/channels/dingtalk"
@@ -25,7 +28,9 @@ import (
"github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
+ _ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook"
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
+ _ "github.com/sipeed/picoclaw/pkg/channels/vk"
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
_ "github.com/sipeed/picoclaw/pkg/channels/weixin"
_ "github.com/sipeed/picoclaw/pkg/channels/whatsapp"
@@ -41,7 +46,6 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
- "github.com/sipeed/picoclaw/pkg/voice"
)
const (
@@ -61,6 +65,7 @@ type services struct {
ChannelManager *channels.Manager
DeviceService *devices.Service
HealthServer *health.Server
+ VoiceAgentCancel context.CancelFunc
manualReloadChan chan struct{}
reloading atomic.Bool
authToken string
@@ -70,6 +75,27 @@ type startupBlockedProvider struct {
reason string
}
+func logChannelVoiceCapabilities(cm *channels.Manager, asrAvailable bool, ttsAvailable bool) {
+ if cm == nil {
+ return
+ }
+
+ names := cm.GetEnabledChannels()
+ sort.Strings(names)
+ for _, name := range names {
+ ch, ok := cm.GetChannel(name)
+ if !ok {
+ continue
+ }
+ caps := channels.DetectVoiceCapabilities(name, ch, asrAvailable, ttsAvailable)
+ logger.InfoCF("voice", "Channel voice capabilities", map[string]any{
+ "channel": name,
+ "asr": caps.ASR,
+ "tts": caps.TTS,
+ })
+ }
+}
+
func (p *startupBlockedProvider) Chat(
_ context.Context,
_ []providers.Message,
@@ -125,6 +151,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
// Enforce singleton: write PID file with generated token.
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
if err != nil {
+ logger.Warnf("write pid file failed: %v", err)
return fmt.Errorf("singleton check failed: %w", err)
}
defer pid.RemovePidFile(homePath)
@@ -337,11 +364,14 @@ func setupAndStartServices(
agentLoop.SetChannelManager(runningServices.ChannelManager)
agentLoop.SetMediaStore(runningServices.MediaStore)
- if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
+ transcriber := asr.DetectTranscriber(cfg)
+ if transcriber != nil {
agentLoop.SetTranscriber(transcriber)
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
+ ttsAvailable := tts.DetectTTS(cfg) != nil
+
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
@@ -358,6 +388,16 @@ func setupAndStartServices(
return nil, fmt.Errorf("error starting channels: %w", err)
}
+ logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
+
+ if transcriber != nil {
+ // Start Voice Agent Orchestrator after channels are ready.
+ vaCtx, vaCancel := context.WithCancel(context.Background())
+ runningServices.VoiceAgentCancel = vaCancel
+ voiceAgent := asr.NewAgent(msgBus, transcriber)
+ voiceAgent.Start(vaCtx)
+ }
+
fmt.Printf(
"✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n",
cfg.Gateway.Host,
@@ -387,6 +427,9 @@ func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Dura
if !isReload && runningServices.ChannelManager != nil {
runningServices.ChannelManager.StopAll(shutdownCtx)
}
+ if runningServices.VoiceAgentCancel != nil {
+ runningServices.VoiceAgentCancel()
+ }
if runningServices.DeviceService != nil {
runningServices.DeviceService.Stop()
}
@@ -563,14 +606,22 @@ func restartServices(
fmt.Println(" ✓ Device event service restarted")
}
- transcriber := voice.DetectTranscriber(cfg)
+ transcriber := asr.DetectTranscriber(cfg)
al.SetTranscriber(transcriber)
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
+
+ // Start Voice Agent Orchestrator on reload
+ vaCtx, vaCancel := context.WithCancel(context.Background())
+ runningServices.VoiceAgentCancel = vaCancel
+ voiceAgent := asr.NewAgent(msgBus, transcriber)
+ voiceAgent.Start(vaCtx)
} else {
logger.InfoCF("voice", "Transcription disabled", nil)
}
+ ttsAvailable := tts.DetectTTS(cfg) != nil
+ logChannelVoiceCapabilities(runningServices.ChannelManager, transcriber != nil, ttsAvailable)
// NOTE: PID file is written once at startup and not updated on reload.
// Changing the gateway listen address requires a full restart.
diff --git a/pkg/logger/panic.go b/pkg/logger/panic.go
index e53e4351a..0a9125dda 100644
--- a/pkg/logger/panic.go
+++ b/pkg/logger/panic.go
@@ -2,12 +2,15 @@ package logger
import (
"fmt"
+ "io"
"os"
"path/filepath"
"runtime/debug"
"time"
)
+var panicWriter io.WriteCloser
+
func InitPanic(filePath string) (func(), error) {
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
@@ -16,21 +19,36 @@ func InitPanic(filePath string) (func(), error) {
if writer == nil {
return nil, fmt.Errorf("failed to create log file: %s", filePath)
}
+ if panicWriter != nil {
+ _ = panicWriter.Close()
+ }
+ panicWriter = writer
return func() {
- defer writer.Close()
+ defer func() {
+ writer.Close()
+ panicWriter = nil
+ }()
if err := recover(); err != nil {
- now := time.Now().Format("2006-01-02 15:04:05")
- stack := debug.Stack()
- logMsg := "\n\n====================\n[" + now + "] PANIC OCCURRED: " + fmt.Sprintf(
- "%v",
- err,
- ) + "\n" + string(
- stack,
- )
-
- writer.Write([]byte(logMsg))
+ RecoverPanicNoExit(err)
os.Exit(1)
}
}, nil
}
+
+func RecoverPanicNoExit(err any) {
+ if panicWriter == nil {
+ Errorf("panicWriter is nil, should not happen")
+ return
+ }
+ now := time.Now().Format("2006-01-02 15:04:05")
+ stack := debug.Stack()
+ logMsg := "\n\n====================\n[" + now + "] PANIC OCCURRED: " + fmt.Sprintf(
+ "%v",
+ err,
+ ) + "\n" + string(
+ stack,
+ )
+
+ panicWriter.Write([]byte(logMsg))
+}
diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go
index 7e2c6b892..f6728330f 100644
--- a/pkg/memory/jsonl.go
+++ b/pkg/memory/jsonl.go
@@ -613,6 +613,33 @@ func (s *JSONLStore) rewriteJSONL(
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
}
+// ListSessions returns all known session keys by reading .meta.json files.
+func (s *JSONLStore) ListSessions() []string {
+ entries, err := os.ReadDir(s.dir)
+ if err != nil {
+ return nil
+ }
+ var keys []string
+ for _, entry := range entries {
+ if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
+ continue
+ }
+ // Read the meta file to get the original key
+ data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
+ if err != nil {
+ continue
+ }
+ var meta SessionMeta
+ if err := json.Unmarshal(data, &meta); err != nil {
+ continue
+ }
+ if meta.Key != "" {
+ keys = append(keys, meta.Key)
+ }
+ }
+ return keys
+}
+
func (s *JSONLStore) Close() error {
return nil
}
diff --git a/pkg/memory/store.go b/pkg/memory/store.go
index b6e11707d..11526b27c 100644
--- a/pkg/memory/store.go
+++ b/pkg/memory/store.go
@@ -37,6 +37,9 @@ type Store interface {
// data. Backends that do not accumulate dead data may return nil.
Compact(ctx context.Context, sessionKey string) error
+ // ListSessions returns all known session keys.
+ ListSessions() []string
+
// Close releases any resources held by the store.
Close() error
}
diff --git a/pkg/pid/pidfile.go b/pkg/pid/pidfile.go
index 584b9b2b5..0b6d461c2 100644
--- a/pkg/pid/pidfile.go
+++ b/pkg/pid/pidfile.go
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"os"
"path/filepath"
@@ -16,6 +17,8 @@ import (
const pidFileName = ".picoclaw.pid"
+var errInvalidPidFile = errors.New("invalid pid file")
+
// PidFileData is the JSON structure stored in the PID file.
type PidFileData struct {
PID int `json:"pid"`
@@ -94,6 +97,7 @@ func WritePidFile(homePath, host string, port int) (*PidFileData, error) {
os.Remove(tmp)
return nil, fmt.Errorf("failed to rename pid file: %w", err)
}
+ logger.Debugf("wrote pid file: %s success", pidPath)
return data, nil
}
@@ -108,10 +112,20 @@ func ReadPidFileWithCheck(homePath string) *PidFileData {
pidPath := pidFilePath(homePath)
data, err := readPidFileUnlocked(pidPath)
if err != nil {
+ if os.IsNotExist(err) {
+ return nil
+ }
+ if errors.Is(err, errInvalidPidFile) {
+ logger.Warnf("invalid pid file, remove it: %s (%v)", pidPath, err)
+ _ = os.Remove(pidPath)
+ return nil
+ }
+ logger.Debugf("failed to read pid file: %s", err)
return nil
}
if !isProcessRunning(data.PID) {
+ logger.Debugf("process not running, remove pid file: %s", pidPath)
os.Remove(pidPath)
return nil
}
@@ -147,12 +161,12 @@ func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
var data PidFileData
if err := json.Unmarshal(raw, &data); err != nil {
- return nil, err
+ return nil, fmt.Errorf("%w: %v", errInvalidPidFile, err)
}
// Validate PID is a positive integer.
if data.PID <= 0 {
- return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
+ return nil, fmt.Errorf("%w: pid=%d", errInvalidPidFile, data.PID)
}
return &data, nil
diff --git a/pkg/pid/pidfile_test.go b/pkg/pid/pidfile_test.go
index 921f590ad..e54b93f4f 100644
--- a/pkg/pid/pidfile_test.go
+++ b/pkg/pid/pidfile_test.go
@@ -191,6 +191,22 @@ func TestReadPidFileWithCheckStalePID(t *testing.T) {
}
}
+// TestReadPidFileWithCheckInvalidFile auto-cleans malformed PID file.
+func TestReadPidFileWithCheckInvalidFile(t *testing.T) {
+ dir := tmpDir(t)
+ path := filepath.Join(dir, pidFileName)
+ os.WriteFile(path, []byte("not json"), 0o600)
+
+ data := ReadPidFileWithCheck(dir)
+ if data != nil {
+ t.Error("expected nil for malformed pid file")
+ }
+
+ if _, err := os.Stat(path); !os.IsNotExist(err) {
+ t.Error("malformed PID file should be removed")
+ }
+}
+
// TestRemovePidFile removes the PID file for the current process.
func TestRemovePidFile(t *testing.T) {
dir := tmpDir(t)
diff --git a/pkg/pid/pidfile_unix.go b/pkg/pid/pidfile_unix.go
index 5459d8370..7bc53b752 100644
--- a/pkg/pid/pidfile_unix.go
+++ b/pkg/pid/pidfile_unix.go
@@ -3,6 +3,7 @@
package pid
import (
+ "errors"
"os"
"syscall"
)
@@ -18,5 +19,11 @@ func isProcessRunning(pid int) bool {
return false
}
// Signal(nil) does not kill the process but checks existence on Unix.
- return p.Signal(syscall.Signal(0)) == nil
+ err = p.Signal(syscall.Signal(0))
+ if err == nil {
+ return true
+ }
+ var errno syscall.Errno
+ // EPERM means the process exists but we are not allowed to signal it.
+ return errors.As(err, &errno) && errno == syscall.EPERM
}
diff --git a/pkg/pid/pidfile_windows.go b/pkg/pid/pidfile_windows.go
index 6a2cce793..6d8b79552 100644
--- a/pkg/pid/pidfile_windows.go
+++ b/pkg/pid/pidfile_windows.go
@@ -23,19 +23,19 @@ func isProcessRunning(pid int) bool {
return false
}
- handle, _, err := procOpenProcess.Call(
+ handle, _, _ := procOpenProcess.Call(
uintptr(processQueryLimitedInformation),
0,
uintptr(pid),
)
- if handle == 0 || err != nil {
+ if handle == 0 {
return false
}
defer procCloseHandle.Call(handle)
var exitCode uint32
- ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
- if ret == 0 || err != nil {
+ ret, _, _ := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
+ if ret == 0 {
return false
}
return exitCode == stillActive
diff --git a/pkg/providers/anthropic_messages/provider.go b/pkg/providers/anthropic_messages/provider.go
index 6a1c473dd..1e865b709 100644
--- a/pkg/providers/anthropic_messages/provider.go
+++ b/pkg/providers/anthropic_messages/provider.go
@@ -41,15 +41,16 @@ type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
+ userAgent string
}
// NewProvider creates a new Anthropic Messages API provider.
-func NewProvider(apiKey, apiBase string) *Provider {
- return NewProviderWithTimeout(apiKey, apiBase, 0)
+func NewProvider(apiKey, apiBase, userAgent string) *Provider {
+ return NewProviderWithTimeout(apiKey, apiBase, userAgent, 0)
}
// NewProviderWithTimeout creates a provider with custom request timeout.
-func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provider {
+func NewProviderWithTimeout(apiKey, apiBase, userAgent string, timeoutSeconds int) *Provider {
baseURL := normalizeBaseURL(apiBase)
timeout := defaultRequestTimeout
if timeoutSeconds > 0 {
@@ -57,8 +58,9 @@ func NewProviderWithTimeout(apiKey, apiBase string, timeoutSeconds int) *Provide
}
return &Provider{
- apiKey: apiKey,
- apiBase: baseURL,
+ apiKey: apiKey,
+ apiBase: baseURL,
+ userAgent: userAgent,
httpClient: &http.Client{
Timeout: timeout,
},
@@ -105,6 +107,9 @@ func (p *Provider) Chat(
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-API-Key", p.apiKey) //nolint:canonicalheader // Anthropic API requires exact header name
req.Header.Set("Anthropic-Version", defaultAPIVersion)
+ if p.userAgent != "" {
+ req.Header.Set("User-Agent", p.userAgent)
+ }
// Execute request
resp, err := p.httpClient.Do(req)
diff --git a/pkg/providers/anthropic_messages/provider_test.go b/pkg/providers/anthropic_messages/provider_test.go
index 39bc48117..ba9d24b66 100644
--- a/pkg/providers/anthropic_messages/provider_test.go
+++ b/pkg/providers/anthropic_messages/provider_test.go
@@ -411,7 +411,7 @@ func TestNormalizeBaseURL(t *testing.T) {
}
func TestNewProvider(t *testing.T) {
- provider := NewProvider("test-key", "https://api.example.com")
+ provider := NewProvider("test-key", "https://api.example.com", "")
if provider == nil {
t.Fatal("NewProvider() returned nil")
}
@@ -424,7 +424,7 @@ func TestNewProvider(t *testing.T) {
}
func TestGetDefaultModel(t *testing.T) {
- provider := NewProvider("test-key", "")
+ provider := NewProvider("test-key", "", "")
got := provider.GetDefaultModel()
expected := "claude-sonnet-4.6"
if got != expected {
@@ -743,7 +743,7 @@ func TestProviderChatErrors(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create provider using constructor to ensure proper initialization
- provider := NewProvider(tt.apiKey, "https://api.example.com")
+ provider := NewProvider(tt.apiKey, "https://api.example.com", "")
_, err := provider.Chat(context.Background(), tt.messages, nil, "test-model", nil)
if err == nil {
diff --git a/pkg/providers/azure/provider.go b/pkg/providers/azure/provider.go
index 429b26798..7de703248 100644
--- a/pkg/providers/azure/provider.go
+++ b/pkg/providers/azure/provider.go
@@ -36,6 +36,7 @@ type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
+ userAgent string
}
// Option configures the Azure Provider.
@@ -50,11 +51,19 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
}
+// WithUserAgent sets the User-Agent header for requests.
+func WithUserAgent(userAgent string) Option {
+ return func(p *Provider) {
+ p.userAgent = userAgent
+ }
+}
+
// NewProvider creates a new Azure OpenAI provider.
-func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
+func NewProvider(apiKey, apiBase, proxy, userAgent string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
+ userAgent: userAgent,
httpClient: common.NewHTTPClient(proxy),
}
@@ -68,9 +77,9 @@ func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
}
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
-func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
+func NewProviderWithTimeout(apiKey, apiBase, proxy, userAgent string, requestTimeoutSeconds int) *Provider {
return NewProvider(
- apiKey, apiBase, proxy,
+ apiKey, apiBase, proxy, userAgent,
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
)
}
@@ -141,6 +150,9 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
+ if p.userAgent != "" {
+ req.Header.Set("User-Agent", p.userAgent)
+ }
resp, err := p.httpClient.Do(req)
if err != nil {
diff --git a/pkg/providers/azure/provider_test.go b/pkg/providers/azure/provider_test.go
index b3752ea50..816ae97dc 100644
--- a/pkg/providers/azure/provider_test.go
+++ b/pkg/providers/azure/provider_test.go
@@ -46,7 +46,7 @@ func TestProviderChat_AzureURLConstruction(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -69,7 +69,7 @@ func TestProviderChat_AzureAuthHeader(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-azure-key", server.URL, "")
+ p := NewProvider("test-azure-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -92,7 +92,7 @@ func TestProviderChat_AzureRequestBodyContainsModel(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -112,7 +112,7 @@ func TestProviderChat_AzureUsesMaxOutputTokens(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
@@ -144,7 +144,7 @@ func TestProviderChat_AzureStoreIsFalse(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -161,7 +161,7 @@ func TestProviderChat_AzureHTTPError(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("bad-key", server.URL, "")
+ p := NewProvider("bad-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error, got nil")
@@ -176,7 +176,7 @@ func TestProviderChat_AzureRateLimitError(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 429, got nil")
@@ -194,7 +194,7 @@ func TestProviderChat_AzureServerError(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for 500, got nil")
@@ -229,7 +229,7 @@ func TestProviderChat_AzureParseTextOutput(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -270,7 +270,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
}))
defer server.Close()
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -287,7 +287,7 @@ func TestProviderChat_AzureParseToolCalls(t *testing.T) {
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
- p := NewProvider("test-key", "", "")
+ p := NewProvider("test-key", "", "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for empty API base")
@@ -295,21 +295,21 @@ func TestProvider_AzureEmptyAPIBase(t *testing.T) {
}
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
- p := NewProvider("test-key", "https://example.com", "")
+ p := NewProvider("test-key", "https://example.com", "", "")
if p.httpClient.Timeout != defaultRequestTimeout {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
- p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
+ p := NewProvider("test-key", "https://example.com", "", "", WithRequestTimeout(300*time.Second))
if p.httpClient.Timeout != 300*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
- p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
+ p := NewProviderWithTimeout("test-key", "https://example.com", "", "", 180)
if p.httpClient.Timeout != 180*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
}
@@ -343,7 +343,7 @@ func TestProviderChat_AzureNativeWebSearchInjection(t *testing.T) {
},
}
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
// With native_search=true: user-defined web_search should be replaced by built-in
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment",
@@ -393,7 +393,7 @@ func TestProviderChat_AzureNoNativeWebSearch(t *testing.T) {
},
}
- p := NewProvider("test-key", server.URL, "")
+ p := NewProvider("test-key", server.URL, "", "")
// Without native_search: user-defined web_search should be kept as-is
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, tools, "deployment", nil)
diff --git a/pkg/providers/common/common_test.go b/pkg/providers/common/common_test.go
index 79a637d48..0a4d5f34a 100644
--- a/pkg/providers/common/common_test.go
+++ b/pkg/providers/common/common_test.go
@@ -262,6 +262,22 @@ func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
}
}
+func TestDecodeToolCallArguments_StringJSON_NewlineEscape(t *testing.T) {
+ raw := json.RawMessage(`"{\"content\":\"line1\\nline2\"}"`)
+ args := DecodeToolCallArguments(raw, "write_file")
+ if args["content"] != "line1\nline2" {
+ t.Errorf("content = %q, want newline-expanded string", args["content"])
+ }
+}
+
+func TestDecodeToolCallArguments_StringJSON_LiteralBackslashN(t *testing.T) {
+ raw := json.RawMessage(`"{\"content\":\"line1\\\\nline2\"}"`)
+ args := DecodeToolCallArguments(raw, "write_file")
+ if args["content"] != `line1\nline2` {
+ t.Errorf("content = %q, want literal backslash-n", args["content"])
+ }
+}
+
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
args := DecodeToolCallArguments(nil, "test")
if len(args) != 0 {
diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go
index e956db209..f13dc646c 100644
--- a/pkg/providers/factory_provider.go
+++ b/pkg/providers/factory_provider.go
@@ -24,6 +24,7 @@ type protocolMeta struct {
var protocolMetaByName = map[string]protocolMeta{
"openai": {defaultAPIBase: "https://api.openai.com/v1"},
+ "venice": {defaultAPIBase: "https://api.venice.ai/api/v1"},
"openrouter": {defaultAPIBase: "https://openrouter.ai/api/v1"},
"litellm": {defaultAPIBase: "http://localhost:4000/v1"},
"lmstudio": {defaultAPIBase: "http://localhost:1234/v1", emptyAPIKeyAllowed: true},
@@ -98,6 +99,19 @@ func ExtractProtocol(model string) (protocol, modelID string) {
return protocol, modelID
}
+// ResolveAPIBase returns the configured API base, or the protocol default when
+// the model uses an HTTP-based provider family with a known default endpoint.
+func ResolveAPIBase(cfg *config.ModelConfig) string {
+ if cfg == nil {
+ return ""
+ }
+ if apiBase := strings.TrimSpace(cfg.APIBase); apiBase != "" {
+ return strings.TrimRight(apiBase, "/")
+ }
+ protocol, _ := ExtractProtocol(cfg.Model)
+ return strings.TrimRight(getDefaultAPIBase(protocol), "/")
+}
+
// CreateProviderFromConfig creates a provider based on the ModelConfig.
// It uses the protocol prefix in the Model field to determine which provider to create.
// Supported protocol families include OpenAI-compatible prefixes (e.g., openai, openrouter, groq, gemini),
@@ -115,6 +129,11 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
protocol, modelID := ExtractProtocol(cfg.Model)
+ userAgent := cfg.UserAgent
+ if userAgent == "" {
+ userAgent = fmt.Sprintf("PicoClaw/%s", config.Version)
+ }
+
switch protocol {
case "openai":
// OpenAI with OAuth/token auth (Codex-style)
@@ -138,8 +157,10 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
+ userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
+ cfg.CustomHeaders,
), modelID, nil
case "azure", "azure-openai":
@@ -157,6 +178,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.APIKey(),
cfg.APIBase,
cfg.Proxy,
+ userAgent,
cfg.RequestTimeout,
), modelID, nil
@@ -196,7 +218,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
}
return provider, modelID, nil
- case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia",
+ case "litellm", "lmstudio", "openrouter", "groq", "zhipu", "gemini", "nvidia", "venice",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
"qwen-us", "dashscope-us", "mistral", "avian", "longcat", "modelscope", "novita",
@@ -214,8 +236,10 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
+ userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
+ cfg.CustomHeaders,
), modelID, nil
case "minimax":
@@ -239,8 +263,10 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
+ userAgent,
cfg.RequestTimeout,
extraBody,
+ cfg.CustomHeaders,
), modelID, nil
case "anthropic":
@@ -265,8 +291,10 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
apiBase,
cfg.Proxy,
cfg.MaxTokensField,
+ userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
+ cfg.CustomHeaders,
), modelID, nil
case "anthropic-messages":
@@ -281,6 +309,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
+ userAgent,
cfg.RequestTimeout,
), modelID, nil
@@ -296,6 +325,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
return anthropicmessages.NewProviderWithTimeout(
cfg.APIKey(),
apiBase,
+ userAgent,
cfg.RequestTimeout,
), modelID, nil
diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go
index 588b81650..c362463ae 100644
--- a/pkg/providers/factory_provider_test.go
+++ b/pkg/providers/factory_provider_test.go
@@ -112,6 +112,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
protocol string
}{
{"openai", "openai"},
+ {"venice", "venice"},
{"groq", "groq"},
{"novita", "novita"},
{"openrouter", "openrouter"},
@@ -160,6 +161,12 @@ func TestGetDefaultAPIBase_LMStudio(t *testing.T) {
}
}
+func TestGetDefaultAPIBase_Venice(t *testing.T) {
+ if got := getDefaultAPIBase("venice"); got != "https://api.venice.ai/api/v1" {
+ t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "venice", got, "https://api.venice.ai/api/v1")
+ }
+}
+
func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "test-litellm",
@@ -362,6 +369,28 @@ func TestCreateProviderFromConfig_Mimo(t *testing.T) {
}
}
+func TestCreateProviderFromConfig_Venice(t *testing.T) {
+ cfg := &config.ModelConfig{
+ ModelName: "test-venice",
+ Model: "venice/venice-uncensored",
+ }
+ cfg.SetAPIKey("test-key")
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+ if modelID != "venice-uncensored" {
+ t.Errorf("modelID = %q, want %q", modelID, "venice-uncensored")
+ }
+ if _, ok := provider.(*HTTPProvider); !ok {
+ t.Fatalf("expected *HTTPProvider, got %T", provider)
+ }
+}
+
func TestGetDefaultAPIBase_Mimo(t *testing.T) {
if got := getDefaultAPIBase("mimo"); got != "https://api.xiaomimimo.com/v1" {
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "mimo", got, "https://api.xiaomimimo.com/v1")
@@ -817,6 +846,150 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
}
}
+func TestCreateProviderFromConfig_CustomHeaders(t *testing.T) {
+ var gotSource, gotAuth string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotSource = r.Header.Get("X-Source")
+ gotAuth = r.Header.Get("Authorization")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
+ }))
+ defer server.Close()
+
+ cfg := &config.ModelConfig{
+ ModelName: "test-headers",
+ Model: "openai/gpt-4o",
+ APIBase: server.URL,
+ CustomHeaders: map[string]string{"X-Source": "coding-plan", "Authorization": "Token config-auth"},
+ }
+ cfg.SetAPIKey("test-key")
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+
+ _, err = provider.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ modelID,
+ nil,
+ )
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if gotSource != "coding-plan" {
+ t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
+ }
+ if gotAuth != "Token config-auth" {
+ t.Fatalf("Authorization = %q, want %q", gotAuth, "Token config-auth")
+ }
+}
+
+// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
+const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
+
+// anthropicResponse is the JSON response used by Anthropic providers.
+const anthropicResponse = `{"content":[{"type":"text","text":"ok"}],"stop_reason":"end_turn","model":"claude-sonnet-4-20250514","usage":{"input_tokens":10,"output_tokens":5}}`
+
+func TestCreateProviderFromConfig_UserAgent(t *testing.T) {
+ defaultUA := "PicoClaw/" + config.Version
+
+ tests := []struct {
+ name string
+ model string
+ userAgent string
+ apiKey string
+ response string
+ wantUA string
+ chatOpts map[string]any
+ }{
+ {
+ name: "openai default user agent",
+ model: "openai/gpt-4o",
+ apiKey: "test-key",
+ response: openaiCompatResponse,
+ wantUA: defaultUA,
+ },
+ {
+ name: "openai custom user agent",
+ model: "openai/gpt-4o",
+ apiKey: "test-key",
+ userAgent: "MyAgent/1.2.3",
+ response: openaiCompatResponse,
+ wantUA: "MyAgent/1.2.3",
+ },
+ {
+ name: "anthropic default user agent",
+ model: "anthropic/claude-sonnet-4-20250514",
+ apiKey: "test-key",
+ response: anthropicResponse,
+ wantUA: defaultUA,
+ },
+ {
+ name: "anthropic-messages default user agent",
+ model: "anthropic-messages/claude-sonnet-4-20250514",
+ apiKey: "test-key",
+ response: anthropicResponse,
+ wantUA: defaultUA,
+ chatOpts: map[string]any{"max_tokens": 1024},
+ },
+ {
+ name: "azure default user agent",
+ model: "azure/my-deployment",
+ apiKey: "test-azure-key",
+ response: openaiCompatResponse,
+ wantUA: defaultUA,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var receivedUA string
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedUA = r.Header.Get("User-Agent")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(tt.response))
+ }))
+ defer server.Close()
+
+ cfg := &config.ModelConfig{
+ ModelName: "test-ua-" + tt.name,
+ Model: tt.model,
+ APIBase: server.URL,
+ UserAgent: tt.userAgent,
+ }
+ cfg.SetAPIKey(tt.apiKey)
+
+ provider, modelID, err := CreateProviderFromConfig(cfg)
+ if err != nil {
+ t.Fatalf("CreateProviderFromConfig() error = %v", err)
+ }
+ if provider == nil {
+ t.Fatal("CreateProviderFromConfig() returned nil provider")
+ }
+
+ _, err = provider.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ modelID,
+ tt.chatOpts,
+ )
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if receivedUA != tt.wantUA {
+ t.Errorf("User-Agent = %q, want %q", receivedUA, tt.wantUA)
+ }
+ })
+ }
+}
+
func TestCreateProviderFromConfig_Bedrock(t *testing.T) {
// Set dummy AWS env vars to make test deterministic
t.Setenv("AWS_ACCESS_KEY_ID", "test-key")
diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go
index 549ec7837..36092105b 100644
--- a/pkg/providers/fallback.go
+++ b/pkg/providers/fallback.go
@@ -10,12 +10,24 @@ import (
// FallbackChain orchestrates model fallback across multiple candidates.
type FallbackChain struct {
cooldown *CooldownTracker
+ rl *RateLimiterRegistry
}
// FallbackCandidate represents one model/provider to try.
type FallbackCandidate struct {
- Provider string
- Model string
+ Provider string
+ Model string
+ RPM int // requests per minute; 0 means unrestricted
+ IdentityKey string // optional stable config identity for cooldown/rate limiting
+}
+
+// StableKey returns the candidate's config-level identity when available,
+// otherwise it falls back to the runtime provider/model key.
+func (c FallbackCandidate) StableKey() string {
+ if key := strings.TrimSpace(c.IdentityKey); key != "" {
+ return key
+ }
+ return ModelKey(c.Provider, c.Model)
}
// FallbackResult contains the successful response and metadata about all attempts.
@@ -36,9 +48,10 @@ type FallbackAttempt struct {
Skipped bool // true if skipped due to cooldown
}
-// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
-func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
- return &FallbackChain{cooldown: cooldown}
+// NewFallbackChain creates a new fallback chain with the given cooldown tracker
+// and rate limiter registry.
+func NewFallbackChain(cooldown *CooldownTracker, rl *RateLimiterRegistry) *FallbackChain {
+ return &FallbackChain{cooldown: cooldown, rl: rl}
}
// ResolveCandidates parses model config into a deduplicated candidate list.
@@ -117,9 +130,9 @@ func (fc *FallbackChain) Execute(
return nil, context.Canceled
}
- // Check cooldown (per provider/model, not just provider).
- // This allows multi-key failover where different keys use different model names.
- cooldownKey := ModelKey(candidate.Provider, candidate.Model)
+ // Check cooldown per stable candidate identity, not just provider/model.
+ // This allows aliases and multi-key configs to fail over independently.
+ cooldownKey := candidate.StableKey()
if !fc.cooldown.IsAvailable(cooldownKey) {
remaining := fc.cooldown.CooldownRemaining(cooldownKey)
result.Attempts = append(result.Attempts, FallbackAttempt{
@@ -136,6 +149,33 @@ func (fc *FallbackChain) Execute(
continue
}
+ // Enforce per-candidate rate limit before calling the provider.
+ // If this candidate is locally saturated, try other candidates first.
+ if fc.rl != nil {
+ if !fc.rl.TryAcquire(cooldownKey) {
+ if i < len(candidates)-1 {
+ result.Attempts = append(result.Attempts, FallbackAttempt{
+ Provider: candidate.Provider,
+ Model: candidate.Model,
+ Skipped: true,
+ Reason: FailoverRateLimit,
+ Error: fmt.Errorf("%s waiting for local rate limit token", cooldownKey),
+ })
+ continue
+ }
+ if waitErr := fc.rl.Wait(ctx, cooldownKey); waitErr != nil {
+ result.Attempts = append(result.Attempts, FallbackAttempt{
+ Provider: candidate.Provider,
+ Model: candidate.Model,
+ Skipped: true,
+ Reason: FailoverRateLimit,
+ Error: waitErr,
+ })
+ return nil, waitErr
+ }
+ }
+ }
+
// Execute the run function.
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
@@ -229,6 +269,34 @@ func (fc *FallbackChain) ExecuteImage(
return nil, context.Canceled
}
+ // Enforce per-candidate rate limit before calling the provider.
+ // If this candidate is locally saturated, try other candidates first.
+ imageKey := candidate.StableKey()
+ if fc.rl != nil {
+ if !fc.rl.TryAcquire(imageKey) {
+ if i < len(candidates)-1 {
+ result.Attempts = append(result.Attempts, FallbackAttempt{
+ Provider: candidate.Provider,
+ Model: candidate.Model,
+ Skipped: true,
+ Reason: FailoverRateLimit,
+ Error: fmt.Errorf("%s waiting for local rate limit token", imageKey),
+ })
+ continue
+ }
+ if waitErr := fc.rl.Wait(ctx, imageKey); waitErr != nil {
+ result.Attempts = append(result.Attempts, FallbackAttempt{
+ Provider: candidate.Provider,
+ Model: candidate.Model,
+ Skipped: true,
+ Reason: FailoverRateLimit,
+ Error: waitErr,
+ })
+ return nil, waitErr
+ }
+ }
+ }
+
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
elapsed := time.Since(start)
diff --git a/pkg/providers/fallback_multikey_test.go b/pkg/providers/fallback_multikey_test.go
index 9ed8fa73c..10481ec61 100644
--- a/pkg/providers/fallback_multikey_test.go
+++ b/pkg/providers/fallback_multikey_test.go
@@ -25,7 +25,7 @@ func TestMultiKeyFailover(t *testing.T) {
// Create fallback chain
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Mock run function: first call fails with 429, second succeeds
callCount := 0
@@ -82,7 +82,7 @@ func TestMultiKeyFailoverAllFail(t *testing.T) {
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Mock run function: all calls fail with rate limit
callCount := 0
@@ -127,7 +127,7 @@ func TestMultiKeyFailoverCooldown(t *testing.T) {
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Put the first model in cooldown (using ModelKey now, not just provider)
cooldownKey := ModelKey(candidates[0].Provider, candidates[0].Model)
@@ -183,7 +183,7 @@ func TestMultiKeyFailoverWithFormatError(t *testing.T) {
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Mock run function: first call fails with format error (bad request)
callCount := 0
@@ -263,7 +263,7 @@ func TestMultiKeyWithModelFallback(t *testing.T) {
}
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Mock run function: first two fail, third succeeds (model fallback)
callCount := 0
@@ -337,7 +337,7 @@ func TestMultiKeyFailoverMixedErrors(t *testing.T) {
candidates := ResolveCandidates(cfg, "zhipu")
cooldown := NewCooldownTracker()
- chain := NewFallbackChain(cooldown)
+ chain := NewFallbackChain(cooldown, nil)
// Mock run function: different errors for each key
callCount := 0
diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go
index 1a1118e33..54fb9b6ea 100644
--- a/pkg/providers/fallback_test.go
+++ b/pkg/providers/fallback_test.go
@@ -19,7 +19,7 @@ func successRun(content string) func(ctx context.Context, provider, model string
func TestFallback_SingleCandidate_Success(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
@@ -36,7 +36,7 @@ func TestFallback_SingleCandidate_Success(t *testing.T) {
func TestFallback_SecondCandidateSuccess(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -69,7 +69,7 @@ func TestFallback_SecondCandidateSuccess(t *testing.T) {
func TestFallback_AllFail(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -96,7 +96,7 @@ func TestFallback_AllFail(t *testing.T) {
func TestFallback_ContextCanceled(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
ctx, cancel := context.WithCancel(context.Background())
candidates := []FallbackCandidate{
@@ -123,7 +123,7 @@ func TestFallback_ContextCanceled(t *testing.T) {
func TestFallback_NonRetriableError(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -155,7 +155,7 @@ func TestFallback_NonRetriableError(t *testing.T) {
func TestFallback_CooldownSkip(t *testing.T) {
now := time.Now()
ct, _ := newTestTracker(now)
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
// Put openai/gpt-4 in cooldown (using ModelKey now)
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
@@ -193,7 +193,7 @@ func TestFallback_CooldownSkip(t *testing.T) {
func TestFallback_AllInCooldown(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
// Put all models in cooldown (using ModelKey now)
ct.MarkFailure(ModelKey("openai", "gpt-4"), FailoverRateLimit)
@@ -221,7 +221,7 @@ func TestFallback_AllInCooldown(t *testing.T) {
func TestFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
if err == nil {
@@ -232,7 +232,7 @@ func TestFallback_NoCandidates(t *testing.T) {
func TestFallback_EmptyFallbacks(t *testing.T) {
// Single primary, no fallbacks: should work like direct call
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
@@ -246,7 +246,7 @@ func TestFallback_EmptyFallbacks(t *testing.T) {
func TestFallback_UnclassifiedError(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
@@ -270,7 +270,7 @@ func TestFallback_UnclassifiedError(t *testing.T) {
func TestFallback_SuccessResetsCooldown(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
modelKey := ModelKey("openai", "gpt-4")
@@ -293,11 +293,78 @@ func TestFallback_SuccessResetsCooldown(t *testing.T) {
}
}
+func assertLocalRateLimitSkipsToHealthyFallback(
+ t *testing.T,
+ primaryKey string,
+ fallbackKey string,
+ fallbackProvider string,
+ fallbackModel string,
+ execute func(context.Context, *FallbackChain, []FallbackCandidate,
+ func(context.Context, string, string) (*LLMResponse, error),
+ ) (*FallbackResult, error),
+ responseContent string,
+) {
+ t.Helper()
+
+ ct := NewCooldownTracker()
+ rl := NewRateLimiterRegistry()
+ rl.Register(primaryKey, 1)
+ if err := rl.Wait(context.Background(), primaryKey); err != nil {
+ t.Fatalf("failed to pre-drain primary limiter: %v", err)
+ }
+
+ fc := NewFallbackChain(ct, rl)
+ candidates := []FallbackCandidate{
+ {Provider: "openai", Model: "gpt-4o", IdentityKey: primaryKey},
+ {Provider: fallbackProvider, Model: fallbackModel, IdentityKey: fallbackKey},
+ }
+
+ run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
+ if provider != fallbackProvider || model != fallbackModel {
+ t.Fatalf("expected fallback candidate to run, got %s/%s", provider, model)
+ }
+ return &LLMResponse{Content: responseContent, FinishReason: "stop"}, nil
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
+ defer cancel()
+
+ result, err := execute(ctx, fc, candidates, run)
+ if err != nil {
+ t.Fatalf("expected fallback success, got error: %v", err)
+ }
+ if result.Provider != fallbackProvider || result.Model != fallbackModel {
+ t.Fatalf("result = %s/%s, want %s/%s", result.Provider, result.Model, fallbackProvider, fallbackModel)
+ }
+ if len(result.Attempts) != 1 || !result.Attempts[0].Skipped {
+ t.Fatalf("expected one skipped primary attempt, got %+v", result.Attempts)
+ }
+}
+
+func TestFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) {
+ assertLocalRateLimitSkipsToHealthyFallback(
+ t,
+ "model_name:primary",
+ "model_name:fallback",
+ "anthropic",
+ "claude",
+ func(
+ ctx context.Context,
+ fc *FallbackChain,
+ candidates []FallbackCandidate,
+ run func(context.Context, string, string) (*LLMResponse, error),
+ ) (*FallbackResult, error) {
+ return fc.Execute(ctx, candidates, run)
+ },
+ "fallback ok",
+ )
+}
+
// --- Image Fallback Tests ---
func TestImageFallback_Success(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
@@ -311,7 +378,7 @@ func TestImageFallback_Success(t *testing.T) {
func TestImageFallback_DimensionError(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
@@ -335,7 +402,7 @@ func TestImageFallback_DimensionError(t *testing.T) {
func TestImageFallback_SizeError(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
@@ -359,7 +426,7 @@ func TestImageFallback_SizeError(t *testing.T) {
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
@@ -384,9 +451,28 @@ func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
}
}
+func TestImageFallback_LocalRateLimitSkipsToHealthyFallback(t *testing.T) {
+ assertLocalRateLimitSkipsToHealthyFallback(
+ t,
+ "model_name:primary-image",
+ "model_name:fallback-image",
+ "anthropic",
+ "claude-sonnet",
+ func(
+ ctx context.Context,
+ fc *FallbackChain,
+ candidates []FallbackCandidate,
+ run func(context.Context, string, string) (*LLMResponse, error),
+ ) (*FallbackResult, error) {
+ return fc.ExecuteImage(ctx, candidates, run)
+ },
+ "image fallback ok",
+ )
+}
+
func TestImageFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
- fc := NewFallbackChain(ct)
+ fc := NewFallbackChain(ct, nil)
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
if err == nil {
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index f2ff52f1d..ac91f15f6 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -24,13 +24,14 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
- return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, 0, nil)
+ return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil, nil)
}
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
- apiKey, apiBase, proxy, maxTokensField string,
+ apiKey, apiBase, proxy, maxTokensField, userAgent string,
requestTimeoutSeconds int,
extraBody map[string]any,
+ customHeaders map[string]string,
) *HTTPProvider {
return &HTTPProvider{
delegate: openai_compat.NewProvider(
@@ -40,6 +41,8 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
openai_compat.WithExtraBody(extraBody),
+ openai_compat.WithCustomHeaders(customHeaders),
+ openai_compat.WithUserAgent(userAgent),
),
}
}
diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index aa9473731..d25a0fce4 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -36,6 +36,8 @@ type Provider struct {
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
extraBody map[string]any // Additional fields to inject into request body
+ customHeaders map[string]string
+ userAgent string
}
type Option func(*Provider)
@@ -44,6 +46,7 @@ const defaultRequestTimeout = common.DefaultRequestTimeout
var stripModelPrefixProviders = map[string]struct{}{
"litellm": {},
+ "venice": {},
"moonshot": {},
"nvidia": {},
"groq": {},
@@ -65,6 +68,12 @@ func WithMaxTokensField(maxTokensField string) Option {
}
}
+func WithUserAgent(userAgent string) Option {
+ return func(p *Provider) {
+ p.userAgent = userAgent
+ }
+}
+
func WithRequestTimeout(timeout time.Duration) Option {
return func(p *Provider) {
if timeout > 0 {
@@ -79,6 +88,12 @@ func WithExtraBody(extraBody map[string]any) Option {
}
}
+func WithCustomHeaders(customHeaders map[string]string) Option {
+ return func(p *Provider) {
+ p.customHeaders = customHeaders
+ }
+}
+
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
@@ -173,6 +188,15 @@ func (p *Provider) buildRequestBody(
return requestBody
}
+func (p *Provider) applyCustomHeaders(req *http.Request) {
+ for k, v := range p.customHeaders {
+ if strings.TrimSpace(k) == "" {
+ continue
+ }
+ req.Header.Set(k, v)
+ }
+}
+
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
@@ -197,9 +221,13 @@ func (p *Provider) Chat(
}
req.Header.Set("Content-Type", "application/json")
+ if p.userAgent != "" {
+ req.Header.Set("User-Agent", p.userAgent)
+ }
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
+ p.applyCustomHeaders(req)
resp, err := p.httpClient.Do(req)
if err != nil {
@@ -243,9 +271,13 @@ func (p *Provider) ChatStream(
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
+ if p.userAgent != "" {
+ req.Header.Set("User-Agent", p.userAgent)
+ }
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
+ p.applyCustomHeaders(req)
// Use a client without Timeout for streaming — the http.Client.Timeout covers
// the entire request lifecycle including body reads, which would kill long streams.
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 823b0ff28..d140d63d6 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -479,6 +479,11 @@ func TestProviderChat_StripsKnownProviderPrefixes(t *testing.T) {
input: "lmstudio/openai/gpt-oss-20b",
wantModel: "openai/gpt-oss-20b",
},
+ {
+ name: "strips venice prefix",
+ input: "venice/venice-uncensored",
+ wantModel: "venice-uncensored",
+ },
{
name: "strips deepseek prefix",
input: "deepseek/deepseek-chat",
@@ -587,6 +592,9 @@ func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("lmstudio/openai/gpt-oss-20b", "http://localhost:1234/v1"); got != "openai/gpt-oss-20b" {
t.Fatalf("normalizeModel(lmstudio) = %q, want %q", got, "openai/gpt-oss-20b")
}
+ if got := normalizeModel("venice/venice-uncensored", "https://api.venice.ai/api/v1"); got != "venice-uncensored" {
+ t.Fatalf("normalizeModel(venice) = %q, want %q", got, "venice-uncensored")
+ }
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
@@ -702,6 +710,111 @@ func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
}
}
+func TestProviderChat_CustomHeadersInjected(t *testing.T) {
+ var gotSource, gotAuth, gotUserAgent string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotSource = r.Header.Get("X-Source")
+ gotAuth = r.Header.Get("Authorization")
+ gotUserAgent = r.Header.Get("User-Agent")
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{"content": "ok"},
+ "finish_reason": "stop",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer server.Close()
+
+ p := NewProvider(
+ "key",
+ server.URL,
+ "",
+ WithUserAgent("PicoClaw/Test"),
+ WithCustomHeaders(map[string]string{
+ "X-Source": "coding-plan",
+ "Authorization": "Token custom-auth",
+ "User-Agent": "Custom-UA/1.0",
+ }),
+ )
+
+ _, err := p.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ "gpt-4o",
+ nil,
+ )
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+
+ if gotSource != "coding-plan" {
+ t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
+ }
+ if gotAuth != "Token custom-auth" {
+ t.Fatalf("Authorization = %q, want %q", gotAuth, "Token custom-auth")
+ }
+ if gotUserAgent != "Custom-UA/1.0" {
+ t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/1.0")
+ }
+}
+
+func TestProviderChatStream_CustomHeadersInjected(t *testing.T) {
+ var gotSource, gotAuth, gotUserAgent string
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotSource = r.Header.Get("X-Source")
+ gotAuth = r.Header.Get("Authorization")
+ gotUserAgent = r.Header.Get("User-Agent")
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ _, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":\"stop\"}]}\n\n"))
+ _, _ = w.Write([]byte("data: [DONE]\n\n"))
+ }))
+ defer server.Close()
+
+ p := NewProvider(
+ "key",
+ server.URL,
+ "",
+ WithUserAgent("PicoClaw/Test"),
+ WithCustomHeaders(map[string]string{
+ "X-Source": "coding-plan",
+ "Authorization": "Token stream-auth",
+ "User-Agent": "Custom-UA/Stream",
+ }),
+ )
+
+ out, err := p.ChatStream(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ "gpt-4o",
+ nil,
+ nil,
+ )
+ if err != nil {
+ t.Fatalf("ChatStream() error = %v", err)
+ }
+ if out.Content != "ok" {
+ t.Fatalf("Content = %q, want %q", out.Content, "ok")
+ }
+ if gotSource != "coding-plan" {
+ t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
+ }
+ if gotAuth != "Token stream-auth" {
+ t.Fatalf("Authorization = %q, want %q", gotAuth, "Token stream-auth")
+ }
+ if gotUserAgent != "Custom-UA/Stream" {
+ t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/Stream")
+ }
+}
+
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
diff --git a/pkg/providers/ratelimiter.go b/pkg/providers/ratelimiter.go
new file mode 100644
index 000000000..f475b58fb
--- /dev/null
+++ b/pkg/providers/ratelimiter.go
@@ -0,0 +1,144 @@
+package providers
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// RateLimiter implements a token-bucket rate limiter for a single key.
+// Allows up to RPM requests per minute with a burst equal to RPM.
+// Thread-safe.
+type RateLimiter struct {
+ mu sync.Mutex
+ rpm int
+ tokens float64
+ maxBurst float64
+ lastTick time.Time
+ nowFunc func() time.Time // for testing
+}
+
+func (rl *RateLimiter) refillLocked(now time.Time) {
+ elapsed := now.Sub(rl.lastTick).Seconds()
+ rl.lastTick = now
+
+ // Refill tokens proportional to elapsed time.
+ refill := elapsed * float64(rl.rpm) / 60.0
+ rl.tokens = min(rl.maxBurst, rl.tokens+refill)
+}
+
+// newRateLimiter creates a RateLimiter that allows rpm requests/minute.
+func newRateLimiter(rpm int) *RateLimiter {
+ return &RateLimiter{
+ rpm: rpm,
+ tokens: float64(rpm), // start full
+ maxBurst: float64(rpm),
+ lastTick: time.Now(),
+ nowFunc: time.Now,
+ }
+}
+
+// Wait blocks until a token is available or ctx is canceled.
+// Returns ctx.Err() if canceled while waiting.
+func (rl *RateLimiter) Wait(ctx context.Context) error {
+ for {
+ rl.mu.Lock()
+ now := rl.nowFunc()
+ rl.refillLocked(now)
+
+ if rl.tokens >= 1.0 {
+ rl.tokens--
+ rl.mu.Unlock()
+ return nil
+ }
+
+ // Calculate how long until a token is available.
+ deficit := 1.0 - rl.tokens
+ waitSec := deficit / (float64(rl.rpm) / 60.0)
+ rl.mu.Unlock()
+
+ timer := time.NewTimer(time.Duration(waitSec * float64(time.Second)))
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ return ctx.Err()
+ case <-timer.C:
+ // Loop to re-check (another goroutine may have consumed the token).
+ }
+ }
+}
+
+// TryAcquire attempts to consume a token without blocking.
+func (rl *RateLimiter) TryAcquire() bool {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ rl.refillLocked(rl.nowFunc())
+ if rl.tokens < 1.0 {
+ return false
+ }
+ rl.tokens--
+ return true
+}
+
+// RateLimiterRegistry holds per-candidate rate limiters.
+// Candidates with RPM=0 are unrestricted.
+// Thread-safe for concurrent reads/writes.
+type RateLimiterRegistry struct {
+ mu sync.RWMutex
+ limiters map[string]*RateLimiter
+}
+
+// NewRateLimiterRegistry creates an empty registry.
+func NewRateLimiterRegistry() *RateLimiterRegistry {
+ return &RateLimiterRegistry{
+ limiters: make(map[string]*RateLimiter),
+ }
+}
+
+// Register adds a rate limiter for the given key at the given RPM.
+// If rpm <= 0, no limiter is registered (unrestricted).
+func (r *RateLimiterRegistry) Register(key string, rpm int) {
+ if rpm <= 0 {
+ return
+ }
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.limiters[key] = newRateLimiter(rpm)
+}
+
+// Wait acquires a token for the given key, blocking if needed.
+// If no limiter is registered for key, returns immediately.
+func (r *RateLimiterRegistry) Wait(ctx context.Context, key string) error {
+ r.mu.RLock()
+ rl := r.limiters[key]
+ r.mu.RUnlock()
+ if rl == nil {
+ return nil
+ }
+ return rl.Wait(ctx)
+}
+
+// TryAcquire attempts to consume a token for the given key without blocking.
+// If no limiter is registered for key, it returns true.
+func (r *RateLimiterRegistry) TryAcquire(key string) bool {
+ r.mu.RLock()
+ rl := r.limiters[key]
+ r.mu.RUnlock()
+ if rl == nil {
+ return true
+ }
+ return rl.TryAcquire()
+}
+
+// RegisterCandidates registers rate limiters for all candidates that have RPM > 0.
+// Candidates with RPM == 0 are ignored (no restriction).
+func (r *RateLimiterRegistry) RegisterCandidates(candidates []FallbackCandidate) {
+ for _, c := range candidates {
+ if c.RPM > 0 {
+ r.Register(c.StableKey(), c.RPM)
+ }
+ }
+}
diff --git a/pkg/providers/ratelimiter_test.go b/pkg/providers/ratelimiter_test.go
new file mode 100644
index 000000000..9972616e9
--- /dev/null
+++ b/pkg/providers/ratelimiter_test.go
@@ -0,0 +1,209 @@
+package providers
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// TestRateLimiter_AllowsUpToRPM verifies that up to RPM requests pass immediately
+// (burst capacity) and the (RPM+1)-th request is delayed.
+func TestRateLimiter_AllowsUpToRPM(t *testing.T) {
+ rpm := 5
+ rl := newRateLimiter(rpm)
+
+ // All rpm tokens should be available immediately (bucket starts full).
+ for i := 0; i < rpm; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ if err := rl.Wait(ctx); err != nil {
+ t.Fatalf("request %d should pass immediately, got: %v", i+1, err)
+ }
+ cancel()
+ }
+
+ // The next request must wait; cancel it to confirm it blocks.
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+ err := rl.Wait(ctx)
+ if err == nil {
+ t.Fatal("expected request beyond RPM to block, but it passed immediately")
+ }
+}
+
+// TestRateLimiter_ContextCancellation verifies that a blocked Wait respects cancellation.
+func TestRateLimiter_ContextCancellation(t *testing.T) {
+ rl := newRateLimiter(1)
+
+ // Drain the one token.
+ ctx := context.Background()
+ if err := rl.Wait(ctx); err != nil {
+ t.Fatalf("first request failed: %v", err)
+ }
+
+ // Second request should block; cancel it.
+ cancelCtx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
+ defer cancel()
+ err := rl.Wait(cancelCtx)
+ if err == nil {
+ t.Fatal("expected cancellation error, got nil")
+ }
+}
+
+// TestRateLimiter_TokenRefill verifies that tokens refill over time.
+func TestRateLimiter_TokenRefill(t *testing.T) {
+ rpm := 60 // 1 token per second
+ rl := newRateLimiter(rpm)
+
+ // Drain all tokens.
+ for i := 0; i < rpm; i++ {
+ rl.Wait(context.Background()) //nolint:errcheck
+ }
+
+ // Advance time via nowFunc: simulate 2 seconds passing (should give 2 tokens).
+ start := time.Now()
+ rl.nowFunc = func() time.Time { return start.Add(2 * time.Second) }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+ if err := rl.Wait(ctx); err != nil {
+ t.Fatalf("expected refilled token to be available: %v", err)
+ }
+}
+
+// TestRateLimiterRegistry_NoLimiter verifies that keys without a registered limiter pass freely.
+func TestRateLimiterRegistry_NoLimiter(t *testing.T) {
+ r := NewRateLimiterRegistry()
+ ctx := context.Background()
+ for i := 0; i < 100; i++ {
+ if err := r.Wait(ctx, "unregistered/key"); err != nil {
+ t.Fatalf("unregistered key should not block: %v", err)
+ }
+ }
+}
+
+// TestRateLimiterRegistry_ZeroRPM verifies that RPM=0 means no limiter is registered.
+func TestRateLimiterRegistry_ZeroRPM(t *testing.T) {
+ r := NewRateLimiterRegistry()
+ r.Register("some/key", 0)
+ ctx := context.Background()
+ for i := 0; i < 50; i++ {
+ if err := r.Wait(ctx, "some/key"); err != nil {
+ t.Fatalf("zero-RPM key should not block: %v", err)
+ }
+ }
+}
+
+// TestRateLimiterRegistry_Enforcement verifies the registry enforces RPM per key.
+func TestRateLimiterRegistry_Enforcement(t *testing.T) {
+ r := NewRateLimiterRegistry()
+ r.Register("openai/gpt-4o", 3)
+
+ // First 3 calls should pass (burst = RPM).
+ for i := 0; i < 3; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
+ t.Fatalf("call %d should pass: %v", i+1, err)
+ }
+ cancel()
+ }
+
+ // 4th call should block.
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+ if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
+ t.Fatal("4th call should have been rate-limited")
+ }
+}
+
+// TestRateLimiterRegistry_RegisterCandidates verifies that RegisterCandidates
+// correctly picks up RPM from FallbackCandidate.
+func TestRateLimiterRegistry_RegisterCandidates(t *testing.T) {
+ r := NewRateLimiterRegistry()
+ candidates := []FallbackCandidate{
+ {Provider: "openai", Model: "gpt-4o", RPM: 2},
+ {Provider: "anthropic", Model: "claude-3", RPM: 0}, // no limit
+ }
+ r.RegisterCandidates(candidates)
+
+ // openai/gpt-4o: 2 tokens burst, 3rd should block.
+ for i := 0; i < 2; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ if err := r.Wait(ctx, "openai/gpt-4o"); err != nil {
+ t.Fatalf("openai call %d should pass: %v", i+1, err)
+ }
+ cancel()
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+ if err := r.Wait(ctx, "openai/gpt-4o"); err == nil {
+ t.Fatal("openai 3rd call should have been limited")
+ }
+
+ // anthropic/claude-3: no limit, should always pass.
+ for i := 0; i < 10; i++ {
+ if err := r.Wait(context.Background(), "anthropic/claude-3"); err != nil {
+ t.Fatalf("anthropic call should not be limited: %v", err)
+ }
+ }
+}
+
+func TestRateLimiterRegistry_RegisterCandidatesUsesStableIdentity(t *testing.T) {
+ r := NewRateLimiterRegistry()
+ candidates := []FallbackCandidate{
+ {Provider: "openai", Model: "gpt-4o", RPM: 1, IdentityKey: "model_name:primary"},
+ {Provider: "openai", Model: "gpt-4o", RPM: 2, IdentityKey: "model_name:fallback"},
+ }
+ r.RegisterCandidates(candidates)
+
+ if err := r.Wait(context.Background(), "model_name:primary"); err != nil {
+ t.Fatalf("primary first call should pass: %v", err)
+ }
+ if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
+ t.Fatalf("fallback first call should pass: %v", err)
+ }
+ if err := r.Wait(context.Background(), "model_name:fallback"); err != nil {
+ t.Fatalf("fallback second call should pass: %v", err)
+ }
+
+ ctxPrimary, cancelPrimary := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancelPrimary()
+ if err := r.Wait(ctxPrimary, "model_name:primary"); err == nil {
+ t.Fatal("primary second call should have been limited")
+ }
+
+ ctxFallback, cancelFallback := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancelFallback()
+ if err := r.Wait(ctxFallback, "model_name:fallback"); err == nil {
+ t.Fatal("fallback third call should have been limited")
+ }
+}
+
+// TestRateLimiter_Concurrency verifies thread safety under concurrent access.
+func TestRateLimiter_Concurrency(t *testing.T) {
+ rpm := 20
+ rl := newRateLimiter(rpm)
+ var passed atomic.Int64
+ var wg sync.WaitGroup
+
+ // Launch 30 goroutines; only ~20 should pass immediately.
+ for i := 0; i < 30; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+ defer cancel()
+ if rl.Wait(ctx) == nil {
+ passed.Add(1)
+ }
+ }()
+ }
+ wg.Wait()
+
+ got := passed.Load()
+ // Allow small timing slack: between rpm-2 and rpm+2.
+ if got < int64(rpm-2) || got > int64(rpm+2) {
+ t.Fatalf("expected ~%d immediate passes, got %d", rpm, got)
+ }
+}
diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go
index a33e1eb5c..7d0908158 100644
--- a/pkg/providers/toolcall_utils.go
+++ b/pkg/providers/toolcall_utils.go
@@ -23,6 +23,12 @@ func buildCLIToolsPrompt(tools []ToolDefinition) string {
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
+ sb.WriteString("Escaping rules (what to type in `function.arguments`):\n")
+ sb.WriteString("- Use `\\n` to represent a real newline character.\n")
+ sb.WriteString("- Use `\\\\n` to represent a literal backslash+n sequence (`\\n`).\n")
+ sb.WriteString(
+ "- `function.arguments` is a JSON-encoded string, so quotes/backslashes must be escaped in the outer payload.\n\n",
+ )
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
diff --git a/pkg/seahorse/.omc/state/last-tool-error.json b/pkg/seahorse/.omc/state/last-tool-error.json
new file mode 100644
index 000000000..2e7273e23
--- /dev/null
+++ b/pkg/seahorse/.omc/state/last-tool-error.json
@@ -0,0 +1,7 @@
+{
+ "tool_name": "Bash",
+ "tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
+ "error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
+ "timestamp": "2026-04-04T02:38:32.067Z",
+ "retry_count": 6
+}
\ No newline at end of file
diff --git a/pkg/seahorse/compact_until_under_test.go b/pkg/seahorse/compact_until_under_test.go
new file mode 100644
index 000000000..2bb96c263
--- /dev/null
+++ b/pkg/seahorse/compact_until_under_test.go
@@ -0,0 +1,58 @@
+package seahorse
+
+import (
+ "context"
+ "testing"
+)
+
+// =============================================================================
+// CompactUntilUnder iteration cap
+// =============================================================================
+
+func TestCompactUntilUnderIterationCap(t *testing.T) {
+ // Setup: create a conversation with so many tokens that compaction
+ // will never reach the budget. The iteration cap prevents infinite loops.
+ //
+ // We use a mock CompleteFn that always returns the same content,
+ // and a budget of 0 which tokens can never reach.
+ // Without the cap, this would loop forever.
+
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+ s := &Store{db: db}
+
+ conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
+ convID := conv.ConversationID
+
+ // Add many messages to ensure there's plenty to compact
+ for i := 0; i < 40; i++ {
+ m, _ := s.AddMessage(context.Background(), convID, "user",
+ "this is a long message with lots of tokens to push context over budget", 100)
+ s.AppendContextMessage(context.Background(), convID, m.ID)
+ }
+
+ // A completeFn that always succeeds but returns non-reducing content
+ mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ return "Summary that doesn't reduce tokens much.", nil
+ }
+
+ ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
+ defer cancel()
+
+ // Use budget=1 so tokens can never reach budget
+ // (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
+ // The function should stop after maxCompactIterations, not loop forever
+ ce.config = Config{} // ensure defaults
+
+ result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
+ if err != nil {
+ // Should not error — should stop gracefully
+ t.Fatalf("CompactUntilUnder with budget=0: %v", err)
+ }
+
+ // The function should have completed within reasonable time
+ // If it exceeded the cap, it would still return (not hang)
+ _ = result
+}
diff --git a/pkg/seahorse/parts_roundtrip_test.go b/pkg/seahorse/parts_roundtrip_test.go
new file mode 100644
index 000000000..02df8a9ea
--- /dev/null
+++ b/pkg/seahorse/parts_roundtrip_test.go
@@ -0,0 +1,144 @@
+package seahorse
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+// =============================================================================
+// Bug 1: formatMessagesForSummary ignores Parts
+// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
+// - truncateSummary has same issue
+// =============================================================================
+
+func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
+ ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
+
+ messages := []Message{
+ {ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
+ {
+ ID: 2,
+ Role: "assistant",
+ Content: "", // empty — real content is in Parts
+ Parts: []MessagePart{
+ {Type: "text", Text: "I will run a command"},
+ {Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
+ },
+ CreatedAt: ts.Add(time.Minute),
+ },
+ {
+ ID: 3,
+ Role: "tool",
+ Content: "", // empty — real content is in Parts
+ Parts: []MessagePart{
+ {Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
+ },
+ CreatedAt: ts.Add(2 * time.Minute),
+ },
+ }
+
+ result := formatMessagesForSummary(messages)
+
+ // Must contain the plain text message
+ if !contains(result, "hello world") {
+ t.Error("formatMessagesForSummary: missing plain text content")
+ }
+
+ // Must contain tool_use info (not blank)
+ if !contains(result, "bash") || !contains(result, "ls -la") {
+ t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
+ }
+
+ // Must contain tool_result info (not blank)
+ if !contains(result, "file1.txt") {
+ t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
+ }
+}
+
+func TestTruncateSummaryIncludesParts(t *testing.T) {
+ messages := []Message{
+ {ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
+ {
+ ID: 2,
+ Role: "assistant",
+ Content: "", // empty
+ Parts: []MessagePart{
+ {Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
+ },
+ CreatedAt: time.Now(),
+ },
+ {
+ ID: 3,
+ Role: "tool",
+ Content: "", // empty
+ Parts: []MessagePart{
+ {Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
+ },
+ CreatedAt: time.Now(),
+ },
+ }
+
+ result := truncateSummary(messages)
+
+ // Must contain plain text
+ if !contains(result, "run the tests") {
+ t.Error("truncateSummary: missing plain text content")
+ }
+
+ // Must contain tool info from Parts (not blank)
+ if !contains(result, "bash") || !contains(result, "go test") {
+ t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
+ }
+
+ // Must contain tool_result from Parts
+ if !contains(result, "PASS") {
+ t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
+ }
+}
+
+// =============================================================================
+// Bug 2: SearchMessages cannot find Part-based messages
+// - FTS5 indexes empty content, LIKE queries empty content
+// =============================================================================
+
+func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
+ convID := conv.ConversationID
+
+ // Add a plain message (searchable)
+ s.AddMessage(ctx, convID, "user", "list the files please", 5)
+
+ // Add a Part-based message (tool_use) — currently NOT searchable
+ parts := []MessagePart{
+ {Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
+ }
+ s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
+
+ // Add a Part-based message (tool_result) — currently NOT searchable
+ resultParts := []MessagePart{
+ {Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
+ }
+ s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
+
+ // Search for "grep" — should find the tool_use message
+ results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
+ if err != nil {
+ t.Fatalf("SearchMessages: %v", err)
+ }
+ if len(results) == 0 {
+ t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
+ }
+
+ // Search for "TODO fix" — should find the tool_result message
+ results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
+ if err != nil {
+ t.Fatalf("SearchMessages: %v", err)
+ }
+ if len(results2) == 0 {
+ t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
+ }
+}
diff --git a/pkg/seahorse/schema.go b/pkg/seahorse/schema.go
new file mode 100644
index 000000000..effa6d60d
--- /dev/null
+++ b/pkg/seahorse/schema.go
@@ -0,0 +1,185 @@
+package seahorse
+
+import (
+ "database/sql"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// SQL statements for FTS5 tables with trigram tokenizer.
+const (
+ sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
+ summary_id,
+ content,
+ tokenize="trigram"
+ )`
+ sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
+ message_id,
+ content,
+ tokenize="trigram"
+ )`
+ sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
+ sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
+ sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
+ sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
+)
+
+// runSchema creates or upgrades the database schema.
+// All schemas are idempotent (safe to run multiple times).
+func runSchema(db *sql.DB) error {
+ // Check FTS5 support before creating tables
+ if err := checkFTS5Support(db); err != nil {
+ return fmt.Errorf("FTS5 check: %w", err)
+ }
+
+ stmts := []string{
+ `CREATE TABLE IF NOT EXISTS conversations (
+ conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_key TEXT NOT NULL UNIQUE,
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
+ updated_at TEXT NOT NULL DEFAULT (datetime('now'))
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS messages (
+ message_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
+ role TEXT NOT NULL,
+ content TEXT NOT NULL DEFAULT '',
+ token_count INTEGER NOT NULL DEFAULT 0,
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS message_parts (
+ part_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ message_id INTEGER NOT NULL REFERENCES messages(message_id),
+ type TEXT NOT NULL,
+ text TEXT,
+ name TEXT,
+ arguments TEXT,
+ tool_call_id TEXT,
+ media_uri TEXT,
+ mime_type TEXT,
+ ordinal INTEGER NOT NULL DEFAULT 0
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS summaries (
+ summary_id TEXT PRIMARY KEY,
+ conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
+ kind TEXT NOT NULL,
+ depth INTEGER NOT NULL DEFAULT 0,
+ content TEXT NOT NULL,
+ token_count INTEGER NOT NULL DEFAULT 0,
+ earliest_at TEXT,
+ latest_at TEXT,
+ descendant_count INTEGER NOT NULL DEFAULT 0,
+ descendant_token_count INTEGER NOT NULL DEFAULT 0,
+ source_message_token_count INTEGER NOT NULL DEFAULT 0,
+ model TEXT,
+ created_at TEXT NOT NULL DEFAULT (datetime('now'))
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS summary_parents (
+ summary_id TEXT NOT NULL,
+ parent_summary_id TEXT NOT NULL,
+ PRIMARY KEY (summary_id, parent_summary_id)
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS summary_messages (
+ summary_id TEXT NOT NULL,
+ message_id INTEGER NOT NULL,
+ ordinal INTEGER NOT NULL DEFAULT 0,
+ PRIMARY KEY (summary_id, message_id)
+ )`,
+
+ `CREATE TABLE IF NOT EXISTS context_items (
+ conversation_id INTEGER NOT NULL,
+ ordinal INTEGER NOT NULL,
+ item_type TEXT NOT NULL,
+ summary_id TEXT,
+ message_id INTEGER,
+ token_count INTEGER NOT NULL DEFAULT 0,
+ created_at TEXT NOT NULL DEFAULT (datetime('now')),
+ PRIMARY KEY (conversation_id, ordinal)
+ )`,
+
+ // FTS5 virtual table with trigram tokenizer for CJK support
+ sqlCreateSummariesFTS,
+
+ // FTS5 virtual table for message search with trigram tokenizer
+ sqlCreateMessagesFTS,
+
+ // Indexes for common query patterns
+ `CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
+ `CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
+ `CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
+ `CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
+ `CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
+ `CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
+ `CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
+
+ // FTS5 triggers to keep summaries_fts in sync with summaries table
+ `CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
+ INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
+ END`,
+ `CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
+ INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
+ END`,
+ `CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
+ INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
+ INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
+ END`,
+
+ // FTS5 triggers to keep messages_fts in sync with messages table
+ `CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
+ INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
+ END`,
+ `CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
+ DELETE FROM messages_fts WHERE message_id = old.message_id;
+ END`,
+ `CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
+ DELETE FROM messages_fts WHERE message_id = old.message_id;
+ INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
+ END`,
+ }
+
+ for _, s := range stmts {
+ if _, err := db.Exec(s); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
+// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
+func checkFTS5Support(db *sql.DB) error {
+ // Check if FTS5 is compiled in
+ var fts5Enabled int
+ err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
+ if err != nil {
+ // sqlite_compileoption_used might not exist in older SQLite
+ // Try a different approach: create a test FTS5 table
+ _, testErr := db.Exec(sqlCheckFTS5Available)
+ if testErr != nil {
+ return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
+ }
+ db.Exec(sqlDropFTS5Check)
+ } else if fts5Enabled == 0 {
+ return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
+ }
+
+ // Check if trigram tokenizer is available by trying to create a test table
+ // Not all SQLite builds include the trigram tokenizer
+ _, err = db.Exec(sqlCheckTrigramAvailable)
+ if err != nil {
+ logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
+ map[string]any{"error": err.Error()})
+ // Trigram is not strictly required, just better for CJK
+ // Don't return error, just log warning
+ } else {
+ db.Exec(sqlDropTrigramCheck)
+ }
+
+ return nil
+}
diff --git a/pkg/seahorse/schema_test.go b/pkg/seahorse/schema_test.go
new file mode 100644
index 000000000..17879f66c
--- /dev/null
+++ b/pkg/seahorse/schema_test.go
@@ -0,0 +1,211 @@
+package seahorse
+
+import (
+ "database/sql"
+ "testing"
+
+ _ "modernc.org/sqlite"
+)
+
+func openTestDB(t *testing.T) *sql.DB {
+ t.Helper()
+ db, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ t.Fatalf("open test db: %v", err)
+ }
+ t.Cleanup(func() { db.Close() })
+ return db
+}
+
+func TestRunMigrations(t *testing.T) {
+ db := openTestDB(t)
+
+ if err := runSchema(db); err != nil {
+ t.Fatalf("runSchema: %v", err)
+ }
+
+ // Verify all tables exist
+ tables := []string{
+ "conversations",
+ "messages",
+ "message_parts",
+ "summaries",
+ "summary_parents",
+ "summary_messages",
+ "context_items",
+ }
+ for _, tbl := range tables {
+ var name string
+ err := db.QueryRow(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
+ ).Scan(&name)
+ if err != nil {
+ t.Errorf("table %q not found: %v", tbl, err)
+ }
+ }
+
+ // Verify FTS5 virtual table exists
+ var ftsName string
+ err := db.QueryRow(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
+ ).Scan(&ftsName)
+ if err != nil {
+ t.Errorf("FTS5 table summaries_fts not found: %v", err)
+ }
+}
+
+func TestRunMigrationsIdempotent(t *testing.T) {
+ db := openTestDB(t)
+
+ // Run migrations twice — should succeed both times
+ if err := runSchema(db); err != nil {
+ t.Fatalf("first migration: %v", err)
+ }
+ if err := runSchema(db); err != nil {
+ t.Fatalf("second migration (idempotent): %v", err)
+ }
+
+ // Verify we can still insert data after double migration
+ res, err := db.Exec(
+ "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
+ "test-session",
+ )
+ if err != nil {
+ t.Fatalf("insert after double migration: %v", err)
+ }
+ id, _ := res.LastInsertId()
+ if id == 0 {
+ t.Error("expected non-zero conversation id")
+ }
+}
+
+func TestMigrationConversationUnique(t *testing.T) {
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+
+ // Insert first
+ _, err := db.Exec(
+ "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
+ "unique-key",
+ )
+ if err != nil {
+ t.Fatalf("first insert: %v", err)
+ }
+
+ // Duplicate should fail
+ _, err = db.Exec(
+ "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
+ "unique-key",
+ )
+ if err == nil {
+ t.Error("expected unique constraint violation for duplicate session_key")
+ }
+}
+
+func TestMigrationSummaryFTSInsert(t *testing.T) {
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+
+ // Insert a conversation first
+ _, err := db.Exec(
+ "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
+ "fts-test",
+ )
+ if err != nil {
+ t.Fatalf("insert conversation: %v", err)
+ }
+
+ // Insert a summary
+ _, err = db.Exec(
+ `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
+ VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
+ if err != nil {
+ t.Fatalf("insert summary: %v", err)
+ }
+
+ // FTS should find it — trigram tokenizer requires >= 3 chars
+ rows, err := db.Query(
+ "SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
+ "你好世",
+ )
+ if err != nil {
+ t.Fatalf("FTS query: %v", err)
+ }
+ defer rows.Close()
+
+ var found string
+ if rows.Next() {
+ if err := rows.Scan(&found); err != nil {
+ t.Fatalf("scan: %v", err)
+ }
+ }
+ if err := rows.Err(); err != nil {
+ t.Fatalf("rows.Err: %v", err)
+ }
+ if found != "sum_test1" {
+ t.Errorf("FTS: expected 'sum_test1', got %q", found)
+ }
+}
+
+func TestMigrationSummaryParentsPK(t *testing.T) {
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+
+ // Insert two summaries
+ for _, id := range []string{"sum_a", "sum_b"} {
+ _, err := db.Exec(
+ `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
+ VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
+ if err != nil {
+ t.Fatalf("insert summary %s: %v", id, err)
+ }
+ }
+
+ // Link child to parent
+ _, err := db.Exec(
+ "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
+ if err != nil {
+ t.Fatalf("link: %v", err)
+ }
+
+ // Duplicate link should fail (composite PK)
+ _, err = db.Exec(
+ "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
+ if err == nil {
+ t.Error("expected unique constraint violation for duplicate summary_parents link")
+ }
+}
+
+func TestFTS5SQLConstants(t *testing.T) {
+ db := openTestDB(t)
+
+ // Verify FTS5 check SQL executes without error
+ _, err := db.Exec(sqlCheckFTS5Available)
+ if err != nil {
+ t.Errorf("sqlCheckFTS5Available failed: %v", err)
+ }
+
+ // Verify trigram check SQL executes without error
+ _, err = db.Exec(sqlCheckTrigramAvailable)
+ if err != nil {
+ t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
+ }
+
+ // Verify summaries_fts SQL executes without error
+ _, err = db.Exec(sqlCreateSummariesFTS)
+ if err != nil {
+ t.Errorf("sqlCreateSummariesFTS failed: %v", err)
+ }
+
+ // Verify messages_fts SQL executes without error
+ _, err = db.Exec(sqlCreateMessagesFTS)
+ if err != nil {
+ t.Errorf("sqlCreateMessagesFTS failed: %v", err)
+ }
+}
diff --git a/pkg/seahorse/short_assembler.go b/pkg/seahorse/short_assembler.go
new file mode 100644
index 000000000..f0fd323ba
--- /dev/null
+++ b/pkg/seahorse/short_assembler.go
@@ -0,0 +1,261 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// escapeXML escapes special characters for safe inclusion in XML content.
+func escapeXML(s string) string {
+ s = strings.ReplaceAll(s, "&", "&")
+ s = strings.ReplaceAll(s, "<", "<")
+ s = strings.ReplaceAll(s, ">", ">")
+ s = strings.ReplaceAll(s, "\"", """)
+ s = strings.ReplaceAll(s, "'", "'")
+ return s
+}
+
+// resolvedItem is a context item resolved to its full content with token count.
+type resolvedItem struct {
+ ordinal int
+ itemType string // "message" or "summary"
+ message *Message
+ summary *Summary
+ tokenCount int
+}
+
+// Assemble builds budget-constrained context from summaries + messages.
+//
+// Algorithm:
+// 1. Fetch context_items, resolve to full content
+// 2. Split into evictable prefix + protected fresh tail
+// 3. If evictable fits in remaining budget → include all
+// 4. Else walk evictable from newest to oldest, keep while fits
+func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
+ items, err := a.store.GetContextItems(ctx, convID)
+ if err != nil {
+ return nil, fmt.Errorf("get context items: %w", err)
+ }
+ if len(items) == 0 {
+ return &AssembleResult{}, nil
+ }
+
+ // Resolve all items
+ resolved := make([]resolvedItem, len(items))
+ for i, item := range items {
+ r, err := a.resolveItem(ctx, item)
+ if err != nil {
+ return nil, err
+ }
+ resolved[i] = r
+ }
+
+ // Split into evictable prefix and protected fresh tail
+ tailStart := len(resolved) - FreshTailCount
+ if tailStart < 0 {
+ tailStart = 0
+ }
+ evictable := resolved[:tailStart]
+ freshTail := resolved[tailStart:]
+
+ // Calculate fresh tail tokens
+ freshTailTokens := 0
+ for _, r := range freshTail {
+ freshTailTokens += r.tokenCount
+ }
+
+ // Budget-aware selection of evictable items
+ remainingBudget := input.Budget - freshTailTokens
+ if remainingBudget < 0 {
+ // Fresh tail alone exceeds budget - we keep it anyway (design decision)
+ // Log for debugging retry/overflow issues
+ logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
+ "budget": input.Budget,
+ "fresh_tail_tokens": freshTailTokens,
+ "fresh_tail_count": len(freshTail),
+ "over_budget_by": freshTailTokens - input.Budget,
+ })
+ remainingBudget = 0
+ }
+
+ var selected []resolvedItem
+ evictableTokens := 0
+ for _, r := range evictable {
+ evictableTokens += r.tokenCount
+ }
+
+ if evictableTokens <= remainingBudget {
+ // All evictable fit
+ selected = append(selected, evictable...)
+ } else {
+ // Walk from newest to oldest, keep while fits
+ var kept []resolvedItem
+ accum := 0
+ for i := len(evictable) - 1; i >= 0; i-- {
+ if accum+evictable[i].tokenCount <= remainingBudget {
+ kept = append(kept, evictable[i])
+ accum += evictable[i].tokenCount
+ } else {
+ break
+ }
+ }
+ // Reverse to restore chronological order
+ for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
+ kept[i], kept[j] = kept[j], kept[i]
+ }
+ selected = append(selected, kept...)
+ }
+
+ // Combine: selected evictable + fresh tail
+ final := append(selected, freshTail...)
+
+ // Build result
+ var messages []Message
+ var summaries []Summary
+ var sourceIDs []string
+ totalTokens := 0
+ maxDepth := 0
+ condensedCount := 0
+
+ for _, r := range final {
+ totalTokens += r.tokenCount
+ if r.itemType == "message" && r.message != nil {
+ messages = append(messages, *r.message)
+ sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
+ } else if r.itemType == "summary" && r.summary != nil {
+ summaries = append(summaries, *r.summary)
+ if r.summary.Depth > maxDepth {
+ maxDepth = r.summary.Depth
+ }
+ if r.summary.Kind == SummaryKindCondensed {
+ condensedCount++
+ }
+ }
+ }
+
+ // Build depth-aware system prompt addition
+ systemPromptAddition := ""
+ if len(summaries) > 0 {
+ if maxDepth >= 2 || condensedCount >= 2 {
+ systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
+ "- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
+ "- When uncertain, use expand to recover original detail before making claims.\n" +
+ "- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
+ } else {
+ systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
+ }
+ }
+
+ // Build Summary field: all XML summaries + system prompt addition
+ var summaryParts []string
+ for _, sum := range summaries {
+ if sum.Content == "" {
+ continue
+ }
+ // Load parent IDs for XML formatting
+ parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
+ if err != nil {
+ logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
+ "summary_id": sum.SummaryID,
+ "error": err.Error(),
+ })
+ }
+ var parentIDs []string
+ for _, ps := range parentSummaries {
+ parentIDs = append(parentIDs, ps.SummaryID)
+ }
+ summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
+ }
+ summary := strings.Join(summaryParts, "\n\n")
+ if systemPromptAddition != "" {
+ if summary != "" {
+ summary += "\n\n"
+ }
+ summary += systemPromptAddition
+ }
+
+ return &AssembleResult{
+ Messages: messages,
+ Summary: summary,
+ }, nil
+}
+
+// resolveItem loads the full message or summary for a context item.
+func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
+ if item.ItemType == "message" {
+ msg, err := a.store.GetMessageByID(ctx, item.MessageID)
+ if err != nil {
+ return resolvedItem{}, err
+ }
+ tokens := item.TokenCount
+ if tokens == 0 {
+ tokens = msg.TokenCount
+ }
+ return resolvedItem{
+ ordinal: item.Ordinal,
+ itemType: "message",
+ message: msg,
+ tokenCount: tokens,
+ }, nil
+ }
+
+ if item.ItemType == "summary" {
+ sum, err := a.store.GetSummary(ctx, item.SummaryID)
+ if err != nil {
+ return resolvedItem{}, err
+ }
+ tokens := item.TokenCount
+ if tokens == 0 {
+ tokens = sum.TokenCount
+ }
+ return resolvedItem{
+ ordinal: item.Ordinal,
+ itemType: "summary",
+ summary: sum,
+ tokenCount: tokens,
+ }, nil
+ }
+
+ return resolvedItem{
+ ordinal: item.Ordinal,
+ itemType: item.ItemType,
+ tokenCount: item.TokenCount,
+ }, nil
+}
+
+// FormatSummaryXML formats a summary as XML for LLM context.
+// This is exported so context managers can format summaries consistently.
+func FormatSummaryXML(s *Summary, parentIDs []string) string {
+ // Build time attributes if available
+ var attrs string
+ if s.EarliestAt != nil {
+ attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
+ }
+ if s.LatestAt != nil {
+ attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
+ }
+
+ var parentsSection string
+ if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
+ parents := "\n"
+ for _, pid := range parentIDs {
+ parents += fmt.Sprintf(" \n", pid)
+ }
+ parents += " \n"
+ parentsSection = parents
+ }
+ return fmt.Sprintf(
+ "\n \n %s\n \n%s",
+ s.SummaryID,
+ string(s.Kind),
+ s.Depth,
+ s.DescendantCount,
+ attrs,
+ escapeXML(s.Content),
+ parentsSection,
+ )
+}
diff --git a/pkg/seahorse/short_assembler_test.go b/pkg/seahorse/short_assembler_test.go
new file mode 100644
index 000000000..88a05e64c
--- /dev/null
+++ b/pkg/seahorse/short_assembler_test.go
@@ -0,0 +1,536 @@
+package seahorse
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+)
+
+// --- Assembler Tests ---
+
+// helper: create a store with messages and summaries for assembly tests
+func setupAssemblerStore(t *testing.T) (*Store, int64) {
+ t.Helper()
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
+ if err != nil {
+ t.Fatalf("create conversation: %v", err)
+ }
+
+ return s, conv.ConversationID
+}
+
+func TestAssemblerAssembleEmpty(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+ if len(result.Messages) != 0 {
+ t.Errorf("Messages = %d, want 0", len(result.Messages))
+ }
+ if result.Summary != "" {
+ t.Errorf("Summary = %q, want empty", result.Summary)
+ }
+}
+
+func TestAssemblerAssembleMessagesOnly(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Create messages
+ msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
+ msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
+
+ // Create context items
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
+ {Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ if len(result.Messages) != 2 {
+ t.Fatalf("Messages = %d, want 2", len(result.Messages))
+ }
+ if result.Messages[0].Content != "hello" {
+ t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
+ }
+ if result.Messages[1].Content != "world" {
+ t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
+ }
+ // No summaries, so Summary should be empty
+ if result.Summary != "" {
+ t.Errorf("Summary = %q, want empty", result.Summary)
+ }
+}
+
+func TestAssemblerAssembleWithSummary(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Create a summary
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "summary of early messages",
+ TokenCount: 50,
+ })
+
+ // Create recent messages
+ msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
+ msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
+
+ // Context: summary + recent messages
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
+ {Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
+ {Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Messages = 2 raw messages (summaries are in Summary field, not Messages)
+ if len(result.Messages) != 2 {
+ t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
+ }
+ // Summary should contain XML with summary content
+ if result.Summary == "" {
+ t.Error("Summary should not be empty when summary exists")
+ }
+ if !strings.Contains(result.Summary, summary.Content) {
+ t.Errorf("Summary should contain summary content %q", summary.Content)
+ }
+ if !strings.Contains(result.Summary, "`,
+ TokenCount: 20,
+ })
+
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Summary field should contain XML with escaped special characters
+ if result.Summary == "" {
+ t.Fatal("Summary should not be empty")
+ }
+
+ // Check that special characters are escaped
+ if strings.Contains(result.Summary, "") {
+ t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
+ }
+ if strings.Contains(result.Summary, `"hello"`) {
+ t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
+ }
+ // & should be escaped as &
+ if strings.Contains(result.Summary, " & ") {
+ t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
+ }
+}
+
+func TestAssemblerSummaryXMLWithParents(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Create a leaf and a condensed summary (condensed has parent)
+ leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf content",
+ TokenCount: 20,
+ })
+ condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindCondensed,
+ Depth: 1,
+ Content: "condensed content",
+ TokenCount: 15,
+ ParentIDs: []string{leaf.SummaryID},
+ })
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
+ {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Summary field should contain XML with parent information
+ if result.Summary == "" {
+ t.Fatal("Summary should not be empty")
+ }
+ xmlContent := result.Summary
+
+ // Should contain section with parent ID
+ if !contains(xmlContent, "") {
+ t.Errorf("condensed summary XML missing section: %q", xmlContent)
+ }
+ if !contains(xmlContent, leaf.SummaryID) {
+ t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
+ }
+
+ // Should contain kind="condensed"
+ if !contains(xmlContent, `kind="condensed"`) {
+ t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
+ }
+}
+
+func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Create a leaf summary with specific descendant count
+ leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf content",
+ TokenCount: 20,
+ DescendantCount: 8,
+ DescendantTokenCount: 1200,
+ })
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
+ {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ if result.Summary == "" {
+ t.Fatal("Summary should not be empty")
+ }
+ xmlContent := result.Summary
+
+ // Should contain descendant_count="8"
+ if !contains(xmlContent, `descendant_count="8"`) {
+ t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
+ }
+}
+
+func TestAssemblerLeafSummaryNoParents(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Leaf summary has no parents
+ leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf content",
+ TokenCount: 20,
+ })
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
+ {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ if result.Summary == "" {
+ t.Fatal("Summary should not be empty")
+ }
+ xmlContent := result.Summary
+
+ // Leaf summary should NOT have section
+ if contains(xmlContent, "") {
+ t.Errorf("leaf summary XML should not have section: %q", xmlContent)
+ }
+}
+
+func TestAssemblerDepthAwarePrompt(t *testing.T) {
+ s, convID := setupAssemblerStore(t)
+ ctx := context.Background()
+
+ // Create a condensed summary (depth >= 2) to trigger full guidance
+ now := time.Now().UTC()
+ leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf summary",
+ TokenCount: 20,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindCondensed,
+ Depth: 2,
+ Content: "condensed summary",
+ TokenCount: 15,
+ ParentIDs: []string{leaf.SummaryID},
+ DescendantCount: 1,
+ DescendantTokenCount: 20,
+ })
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+
+ s.UpsertContextItems(ctx, convID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
+ {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
+ })
+
+ a := &Assembler{store: s, config: Config{}}
+ result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Should have a depth-aware prompt in Summary field
+ if result.Summary == "" {
+ t.Error("expected non-empty Summary when depth >= 2")
+ }
+ // SystemPromptAddition is embedded in Summary field
+ if !strings.Contains(result.Summary, "multi-level summarization") {
+ t.Error("Summary should contain system prompt addition about multi-level summarization")
+ }
+}
+
+func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
+ // Spec: condensed summaries use not parentId
+ now := time.Now().UTC()
+ s := Summary{
+ SummaryID: "sum_condensed1",
+ Kind: SummaryKindCondensed,
+ Depth: 1,
+ Content: "condensed content",
+ TokenCount: 50,
+ DescendantCount: 2,
+ EarliestAt: &now,
+ LatestAt: &now,
+ }
+ parentIDs := []string{"sum_leaf1", "sum_leaf2"}
+
+ xml := FormatSummaryXML(&s, parentIDs)
+
+ // Must use per spec
+ if !contains(xml, ``) {
+ t.Errorf("expected , got: %s", xml)
+ }
+ if !contains(xml, ``) {
+ t.Errorf("expected , got: %s", xml)
+ }
+ // Must NOT use old tag
+ if contains(xml, "") {
+ t.Errorf("should not use tag, got: %s", xml)
+ }
+}
+
+func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
+ // Spec: summary XML includes earliest_at and latest_at attributes
+ earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
+ latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
+ s := Summary{
+ SummaryID: "sum_leaf1",
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf content",
+ TokenCount: 30,
+ DescendantCount: 0,
+ EarliestAt: &earliest,
+ LatestAt: &latest,
+ }
+
+ xml := FormatSummaryXML(&s, nil)
+
+ if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
+ t.Errorf("missing earliest_at attribute, got: %s", xml)
+ }
+ if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
+ t.Errorf("missing latest_at attribute, got: %s", xml)
+ }
+}
+
+func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
+ // When EarliestAt/LatestAt are nil, attributes should be omitted
+ s := Summary{
+ SummaryID: "sum_leaf1",
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf content",
+ TokenCount: 30,
+ DescendantCount: 0,
+ }
+
+ xml := FormatSummaryXML(&s, nil)
+
+ if contains(xml, "earliest_at=") {
+ t.Errorf("should not have earliest_at when nil, got: %s", xml)
+ }
+ if contains(xml, "latest_at=") {
+ t.Errorf("should not have latest_at when nil, got: %s", xml)
+ }
+}
diff --git a/pkg/seahorse/short_bench_test.go b/pkg/seahorse/short_bench_test.go
new file mode 100644
index 000000000..b7e47bcff
--- /dev/null
+++ b/pkg/seahorse/short_bench_test.go
@@ -0,0 +1,336 @@
+package seahorse
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "testing"
+ "time"
+
+ _ "modernc.org/sqlite"
+)
+
+// newBenchStore creates a test store for benchmarks.
+func newBenchStore(b *testing.B) (*Store, func()) {
+ b.Helper()
+ db, err := sql.Open("sqlite", ":memory:")
+ if err != nil {
+ b.Fatalf("open test db: %v", err)
+ }
+ if err := runSchema(db); err != nil {
+ db.Close()
+ b.Fatalf("migration: %v", err)
+ }
+ return &Store{db: db}, func() { db.Close() }
+}
+
+// --- Ingest benchmarks ---
+
+func BenchmarkIngest_SingleMessage(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
+ convID := conv.ConversationID
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkIngest_BatchMessages(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
+ convID := conv.ConversationID
+
+ for j := 0; j < 10; j++ {
+ added, err := s.AddMessage(ctx, convID, "user",
+ fmt.Sprintf("Message %d in batch", j), 10)
+ if err != nil {
+ b.Fatal(err)
+ }
+ s.AppendContextMessage(ctx, convID, added.ID)
+ }
+ }
+}
+
+// --- Assemble benchmarks ---
+
+func BenchmarkAssemble_MessagesOnly(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
+ convID := conv.ConversationID
+
+ // Add 100 messages
+ for i := 0; i < 100; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user",
+ fmt.Sprintf("Message content %d with some text", i), 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ a := &Assembler{store: s}
+ input := AssembleInput{Budget: 50000}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := a.Assemble(ctx, convID, input)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkAssemble_WithSummaries(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
+ convID := conv.ConversationID
+
+ now := time.Now().UTC()
+
+ // Add 10 leaf summaries
+ for i := 0; i < 10; i++ {
+ sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("Leaf summary %d", i),
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ s.AppendContextSummary(ctx, convID, sum.SummaryID)
+ }
+
+ // Add 20 fresh messages
+ for i := 0; i < 20; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ a := &Assembler{store: s}
+ input := AssembleInput{Budget: 10000}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := a.Assemble(ctx, convID, input)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkAssemble_BudgetEviction(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
+ convID := conv.ConversationID
+
+ now := time.Now().UTC()
+
+ // Add 50 leaf summaries (more than budget can hold)
+ for i := 0; i < 50; i++ {
+ sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("Summary %d", i),
+ TokenCount: 300,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ s.AppendContextSummary(ctx, convID, sum.SummaryID)
+ }
+
+ // Add fresh tail
+ for i := 0; i < FreshTailCount; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ a := &Assembler{store: s}
+ input := AssembleInput{Budget: 5000} // Force eviction
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := a.Assemble(ctx, convID, input)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// --- Search (FTS5) benchmarks ---
+
+// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
+func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
+ b.Helper()
+ now := time.Now().UTC()
+ for i := 0; i < n; i++ {
+ sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf(contentTpl, i),
+ TokenCount: 200,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ if err != nil {
+ b.Fatalf("create summary: %v", err)
+ }
+ s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
+ }
+}
+
+func BenchmarkSearchSummaries_FTS5(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
+ convID := conv.ConversationID
+
+ benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := s.SearchSummaries(ctx, SearchInput{
+ Pattern: "database",
+ Mode: "full_text",
+ ConversationID: convID,
+ })
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkSearchSummaries_Like(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
+ convID := conv.ConversationID
+
+ benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := s.SearchSummaries(ctx, SearchInput{
+ Pattern: "config",
+ Mode: "like",
+ ConversationID: convID,
+ })
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkSearchMessages_FTS5(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
+ convID := conv.ConversationID
+
+ // Add 500 messages
+ for i := 0; i < 500; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user",
+ fmt.Sprintf("User message about API and database integration %d", i), 20)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := s.SearchMessages(ctx, SearchInput{
+ Pattern: "API database",
+ Mode: "full_text",
+ ConversationID: convID,
+ })
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// --- Bootstrap benchmarks ---
+
+func BenchmarkBootstrap_Empty(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
+ convID := conv.ConversationID
+ _ = convID // Bootstrap with empty history
+ }
+}
+
+func BenchmarkBootstrap_100Messages(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+
+ // Prepare 100 messages
+ msgs := make([]Message, 100)
+ for i := 0; i < 100; i++ {
+ msgs[i] = Message{
+ Role: "user",
+ Content: fmt.Sprintf("Bootstrap message %d", i),
+ TokenCount: 15,
+ }
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
+ convID := conv.ConversationID
+
+ for _, m := range msgs {
+ added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
+ s.AppendContextMessage(ctx, convID, added.ID)
+ }
+ }
+}
+
+func BenchmarkBootstrap_500Messages(b *testing.B) {
+ s, cleanup := newBenchStore(b)
+ defer cleanup()
+ ctx := context.Background()
+
+ msgs := make([]Message, 500)
+ for i := 0; i < 500; i++ {
+ msgs[i] = Message{
+ Role: "user",
+ Content: fmt.Sprintf("Bootstrap message %d", i),
+ TokenCount: 15,
+ }
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
+ convID := conv.ConversationID
+
+ for _, m := range msgs {
+ added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
+ s.AppendContextMessage(ctx, convID, added.ID)
+ }
+ }
+}
diff --git a/pkg/seahorse/short_compaction.go b/pkg/seahorse/short_compaction.go
new file mode 100644
index 000000000..30e290926
--- /dev/null
+++ b/pkg/seahorse/short_compaction.go
@@ -0,0 +1,898 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tokenizer"
+)
+
+// CompactInput controls compaction behavior.
+type CompactInput struct {
+ Budget *int // Token budget override
+ Force bool // Force compaction even if below threshold
+}
+
+// CompactResult describes what was compacted.
+type CompactResult struct {
+ SummariesCreated []string `json:"summariesCreated"`
+ TokensSaved int `json:"tokensSaved"`
+ LeafSummaries int `json:"leafSummaries"`
+ CondensedSummaries int `json:"condensedSummaries"`
+}
+
+// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
+func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
+ tokens, err := e.store.GetContextTokenCount(ctx, convID)
+ if err != nil {
+ return false, fmt.Errorf("get token count: %w", err)
+ }
+ threshold := int(float64(contextWindow) * ContextThreshold)
+ return tokens >= threshold, nil
+}
+
+// Close cancels the shutdown context, stopping async goroutines.
+func (e *CompactionEngine) Close() {
+ if e.shutdownCancel != nil {
+ e.shutdownCancel()
+ }
+}
+
+// Compact runs leaf compaction (sync) and optionally condensed compaction.
+func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
+ result := &CompactResult{}
+
+ // Phase 1: leaf compaction (synchronous, every turn)
+ summaryID, err := e.compactLeaf(ctx, convID)
+ if err != nil {
+ return nil, fmt.Errorf("compact leaf: %w", err)
+ }
+ if summaryID != nil {
+ result.SummariesCreated = append(result.SummariesCreated, *summaryID)
+ result.LeafSummaries++
+ logger.InfoCF("seahorse", "compact: leaf", map[string]any{
+ "conv_id": convID,
+ "summary_id": *summaryID,
+ })
+ }
+
+ // Phase 2: condensed compaction if over threshold
+ tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
+ var budget int
+ if input.Budget != nil {
+ budget = *input.Budget
+ if budget == 0 {
+ logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
+ "conv_id": convID,
+ })
+ }
+ } else {
+ budget = int(float64(tokensBefore) * ContextThreshold)
+ }
+
+ if input.Force || (tokensBefore > budget && budget > 0) {
+ // Launch async condensed compaction with dedup
+ if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
+ go func() {
+ defer e.condensing.Delete(convID)
+ e.runCondensedLoop(e.shutdownCtx, convID)
+ }()
+ }
+ }
+
+ tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
+ if tokensAfter < tokensBefore {
+ result.TokensSaved = tokensBefore - tokensAfter
+ }
+
+ return result, nil
+}
+
+// CompactUntilUnder aggressively compacts until context is under budget.
+func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
+ result := &CompactResult{}
+ prevTokens := 0
+ logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
+
+ for iter := 0; iter < MaxCompactIterations; iter++ {
+ tokens, err := e.store.GetContextTokenCount(ctx, convID)
+ if err != nil {
+ return result, fmt.Errorf("get tokens: %w", err)
+ }
+ if tokens <= budget {
+ logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
+ "conv_id": convID,
+ "budget": budget,
+ "tokens": tokens,
+ "leaf": result.LeafSummaries,
+ "condensed": result.CondensedSummaries,
+ })
+ return result, nil
+ }
+
+ // Try leaf first
+ summaryID, err := e.compactLeaf(ctx, convID, true)
+ if err != nil {
+ return result, err
+ }
+ if summaryID != nil {
+ result.SummariesCreated = append(result.SummariesCreated, *summaryID)
+ result.LeafSummaries++
+ logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
+ "conv_id": convID,
+ "summary_id": *summaryID,
+ })
+ continue
+ }
+
+ // Try condensed with forced fanout
+ condensedID, err := e.compactCondensed(ctx, convID)
+ if err != nil {
+ return result, err
+ }
+ if condensedID != nil {
+ result.SummariesCreated = append(result.SummariesCreated, *condensedID)
+ result.CondensedSummaries++
+ logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
+ "conv_id": convID,
+ "summary_id": *condensedID,
+ })
+ continue
+ }
+
+ // No progress
+ newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
+ if newTokens >= prevTokens {
+ logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
+ "conv_id": convID,
+ "tokens": newTokens,
+ })
+ return result, nil
+ }
+ prevTokens = newTokens
+ }
+
+ // Safety cap exceeded — see MaxCompactIterations doc for rationale.
+ logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
+ "conv_id": convID,
+ "budget": budget,
+ "iterations": MaxCompactIterations,
+ "tokens": prevTokens,
+ })
+ return result, nil
+}
+
+// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
+// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
+func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
+ items, err := e.store.GetContextItems(ctx, convID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find oldest contiguous message chunk outside fresh tail
+ msgCount := 0
+ msgTokens := 0
+ for _, item := range items {
+ if item.ItemType == "message" {
+ msgCount++
+ msgTokens += item.TokenCount
+ }
+ }
+
+ // Trigger if either message count or token threshold is met
+ if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
+ return nil, nil
+ }
+
+ // Calculate fresh tail boundary (bypass when forced)
+ useForce := len(force) > 0 && force[0]
+ tailStartIdx := len(items) - FreshTailCount
+ if useForce {
+ tailStartIdx = len(items) // allow compacting everything
+ }
+ if tailStartIdx < 0 {
+ tailStartIdx = 0
+ }
+
+ // Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
+ var chunk []ContextItem
+ chunkStart := -1
+ chunkEnd := -1
+ accumTokens := 0
+ for i := 0; i < tailStartIdx; i++ {
+ if items[i].ItemType == "message" {
+ if chunkStart == -1 {
+ chunkStart = i
+ }
+ chunkEnd = i
+ accumTokens += items[i].TokenCount
+ // Stop accumulating once we reach the token budget
+ if accumTokens >= LeafChunkTokens {
+ break
+ }
+ } else {
+ // Non-message breaks the chunk
+ if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
+ break
+ }
+ chunkStart = -1
+ chunkEnd = -1
+ accumTokens = 0
+ }
+ }
+
+ if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
+ return nil, nil
+ }
+
+ chunk = items[chunkStart : chunkEnd+1]
+
+ // Collect messages for the chunk
+ var messages []Message
+ for _, item := range chunk {
+ msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
+ if innerErr != nil {
+ return nil, innerErr
+ }
+ messages = append(messages, *msg)
+ }
+
+ // Get prior summaries for context
+ priorSummary := ""
+ priorCount := 0
+ for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
+ if items[i].ItemType == "summary" {
+ sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
+ if innerErr2 == nil {
+ priorSummary = sum.Content + "\n" + priorSummary
+ priorCount++
+ }
+ }
+ }
+
+ // Generate summary
+ content, err := e.generateLeafSummary(ctx, messages, priorSummary)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create summary in store
+ tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
+
+ var earliestAt, latestAt *time.Time
+ if len(messages) > 0 {
+ earliestAt = &messages[0].CreatedAt
+ latestAt = &messages[len(messages)-1].CreatedAt
+ }
+
+ summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: content,
+ TokenCount: tokenCount,
+ EarliestAt: earliestAt,
+ LatestAt: latestAt,
+ SourceMessageTokens: sumMessageTokens(messages),
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Link to source messages
+ msgIDs := make([]int64, len(messages))
+ for i, m := range messages {
+ msgIDs[i] = m.ID
+ }
+ if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
+ return nil, err
+ }
+
+ // Replace context range with summary
+ if err := e.store.ReplaceContextRangeWithSummary(
+ ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
+ ); err != nil {
+ return nil, err
+ }
+
+ return &summary.SummaryID, nil
+}
+
+// compactCondensed compresses multiple summaries into one higher-level summary.
+func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
+ // Try ordinal-aware selection first (respects consecutive ordering)
+ var candidates []Summary
+
+ depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
+ if err != nil {
+ return nil, err
+ }
+ for _, depth := range depths {
+ var chunkAtDepth []Summary
+ var err2 error
+ chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
+ if err2 != nil {
+ continue
+ }
+ if len(chunkAtDepth) > 0 {
+ candidates = chunkAtDepth
+ break
+ }
+ }
+
+ // Fallback to depth-grouping selection
+ if len(candidates) == 0 {
+ candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if len(candidates) == 0 {
+ return nil, nil
+ }
+
+ // Generate condensed summary
+ content, err := e.generateCondensedSummary(ctx, candidates)
+ if err != nil {
+ return nil, err
+ }
+
+ // Merge metadata
+ maxDepth := 0
+ descendantCount := 0
+ descendantTokenCount := 0
+ sourceMessageTokens := 0
+ var earliestAt, latestAt *time.Time
+
+ parentIDs := make([]string, len(candidates))
+ for i, c := range candidates {
+ parentIDs[i] = c.SummaryID
+ if c.Depth > maxDepth {
+ maxDepth = c.Depth
+ }
+ descendantCount += c.DescendantCount + 1
+ descendantTokenCount += c.TokenCount + c.DescendantTokenCount
+ sourceMessageTokens += c.SourceMessageTokenCount
+ if c.EarliestAt != nil {
+ if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
+ earliestAt = c.EarliestAt
+ }
+ }
+ if c.LatestAt != nil {
+ if latestAt == nil || c.LatestAt.After(*latestAt) {
+ latestAt = c.LatestAt
+ }
+ }
+ }
+
+ tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
+
+ summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindCondensed,
+ Depth: maxDepth + 1,
+ Content: content,
+ TokenCount: tokenCount,
+ EarliestAt: earliestAt,
+ LatestAt: latestAt,
+ DescendantCount: descendantCount,
+ DescendantTokenCount: descendantTokenCount,
+ SourceMessageTokens: sourceMessageTokens,
+ ParentIDs: parentIDs,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the ordinal range for the candidate summaries in context
+ items, err := e.store.GetContextItems(ctx, convID)
+ if err != nil {
+ return nil, err
+ }
+
+ candidateSet := make(map[string]bool)
+ for _, c := range candidates {
+ candidateSet[c.SummaryID] = true
+ }
+
+ startOrd := -1
+ endOrd := -1
+ hasNonCandidate := false
+ for _, item := range items {
+ if item.ItemType == "summary" && candidateSet[item.SummaryID] {
+ if startOrd == -1 {
+ startOrd, endOrd = item.Ordinal, item.Ordinal
+ } else {
+ // Check for non-candidate items between endOrd and current ordinal
+ for _, it := range items {
+ if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
+ if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
+ hasNonCandidate = true
+ break
+ }
+ }
+ }
+ if hasNonCandidate {
+ break
+ }
+ if item.Ordinal < startOrd {
+ startOrd = item.Ordinal
+ }
+ if item.Ordinal > endOrd {
+ endOrd = item.Ordinal
+ }
+ }
+ }
+ }
+
+ if startOrd == -1 || endOrd == -1 {
+ return nil, nil
+ }
+
+ // Collect candidate summary IDs
+ candidateIDs := make([]string, 0, len(candidates))
+ for _, c := range candidates {
+ candidateIDs = append(candidateIDs, c.SummaryID)
+ }
+
+ if hasNonCandidate {
+ // Use safe per-item deletion to avoid deleting non-candidate items
+ if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
+ return nil, err
+ }
+ } else {
+ // Candidates are consecutive, use efficient range deletion
+ if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
+ return nil, err
+ }
+ }
+
+ return &summary.SummaryID, nil
+}
+
+// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
+func (e *CompactionEngine) selectShallowestCondensationCandidate(
+ ctx context.Context, convID int64, forced bool,
+) ([]Summary, error) {
+ items, err := e.store.GetContextItems(ctx, convID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Group by depth, find consecutive runs
+ tailStartIdx := len(items) - FreshTailCount
+ if tailStartIdx < 0 {
+ tailStartIdx = 0
+ }
+
+ minFanout := CondensedMinFanout
+ if forced {
+ minFanout = CondensedMinFanoutHard
+ }
+
+ // Track depth groups
+ depthGroups := make(map[int][]ContextItem)
+ for i := 0; i < tailStartIdx; i++ {
+ item := items[i]
+ if item.ItemType != "summary" {
+ continue
+ }
+ sum, err := e.store.GetSummary(ctx, item.SummaryID)
+ if err != nil {
+ continue
+ }
+ depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
+ }
+
+ // Find shallowest depth with enough candidates
+ // Collect all depths and sort to handle non-consecutive depths
+ var depths []int
+ for depth := range depthGroups {
+ depths = append(depths, depth)
+ }
+ sort.Ints(depths)
+
+ for _, depth := range depths {
+ group := depthGroups[depth]
+ if len(group) >= minFanout {
+ // Load summaries
+ var result []Summary
+ for _, item := range group[:minFanout] {
+ sum, err := e.store.GetSummary(ctx, item.SummaryID)
+ if err != nil {
+ continue
+ }
+ result = append(result, *sum)
+ }
+ return result, nil
+ }
+ }
+
+ return nil, nil
+}
+
+// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
+// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
+// token overflow. Returns contiguous chunk of summaries.
+func (e *CompactionEngine) selectOldestChunkAtDepth(
+ ctx context.Context, convID int64, targetDepth int,
+) ([]Summary, error) {
+ items, err := e.store.GetContextItems(ctx, convID)
+ if err != nil {
+ return nil, err
+ }
+
+ tailStartIdx := len(items) - FreshTailCount
+ if tailStartIdx < 0 {
+ tailStartIdx = 0
+ }
+
+ var chunk []Summary
+ accumTokens := 0
+
+ for i := 0; i < tailStartIdx; i++ {
+ item := items[i]
+ if item.ItemType != "summary" {
+ // Non-summary breaks the chunk
+ break
+ }
+ sum, err := e.store.GetSummary(ctx, item.SummaryID)
+ if err != nil {
+ break
+ }
+ if sum.Depth != targetDepth {
+ // Different depth breaks the chunk
+ break
+ }
+ if accumTokens+sum.TokenCount > LeafChunkTokens {
+ // Token overflow stops collection
+ break
+ }
+ chunk = append(chunk, *sum)
+ accumTokens += sum.TokenCount
+ }
+
+ // Min tokens check: spec line 808
+ // chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
+ minTokens := CondensedTargetTokens // 2000
+ if accumTokens < minTokens {
+ return nil, nil
+ }
+
+ return chunk, nil
+}
+
+// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
+// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
+func (e *CompactionEngine) generateLeafSummary(
+ ctx context.Context,
+ messages []Message,
+ previousSummary string,
+) (string, error) {
+ if e.complete == nil {
+ return truncateSummary(messages), nil
+ }
+
+ sourceText := formatMessagesForSummary(messages)
+ inputTokens := sumMessageTokens(messages)
+ targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
+
+ // Level 1: normal prompt
+ prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
+ content, err := e.complete(ctx, prompt, CompleteOptions{
+ MaxTokens: LeafTargetTokens * 2,
+ Temperature: 0.3,
+ })
+ if err != nil {
+ return "", err
+ }
+ if content == "" {
+ // Retry with temperature=0
+ content, err = e.complete(ctx, prompt, CompleteOptions{
+ MaxTokens: LeafTargetTokens * 2,
+ Temperature: 0,
+ })
+ if err != nil {
+ return "", err
+ }
+ }
+
+ // Check if level 1 succeeded
+ if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
+ return content, nil
+ }
+
+ // Level 2: aggressive prompt
+ aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
+ aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
+ content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
+ MaxTokens: aggressiveTarget * 2,
+ Temperature: 0.3,
+ })
+ if err != nil {
+ return "", err
+ }
+ if content == "" {
+ // Retry with temperature=0
+ content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
+ MaxTokens: aggressiveTarget * 2,
+ Temperature: 0,
+ })
+ if err != nil {
+ return "", err
+ }
+ }
+ if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
+ return content, nil
+ }
+
+ // Level 3: deterministic truncation
+ return truncateSummary(messages), nil
+}
+
+// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
+func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
+ if e.complete == nil {
+ return truncateCondensedSummaries(summaries), nil
+ }
+
+ sourceText := formatSummariesForCondensation(summaries)
+ inputTokens := sumSummaryTokens(summaries)
+ targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
+
+ // Level 1: normal prompt
+ prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
+ content, err := e.complete(ctx, prompt, CompleteOptions{
+ MaxTokens: CondensedTargetTokens * 2,
+ Temperature: 0.3,
+ })
+ if err != nil {
+ return "", err
+ }
+ if content == "" {
+ content, err = e.complete(ctx, prompt, CompleteOptions{
+ MaxTokens: CondensedTargetTokens * 2,
+ Temperature: 0,
+ })
+ if err != nil {
+ return "", err
+ }
+ }
+ if content != "" {
+ return content, nil
+ }
+
+ // Level 2: aggressive prompt
+ aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
+ aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
+ content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
+ MaxTokens: aggressiveTarget * 2,
+ Temperature: 0.3,
+ })
+ if err != nil {
+ return "", err
+ }
+ if content != "" {
+ return content, nil
+ }
+
+ // Level 3: deterministic fallback
+ return truncateCondensedSummaries(summaries), nil
+}
+
+// runCondensedLoop runs condensed compaction in a loop until:
+// a) context tokens <= threshold (success), OR
+// b) No candidate found (nothing to condense), OR
+// c) tokensAfter >= tokensBefore (no progress this iteration), OR
+// d) tokensAfter >= previousTokens (no improvement over last iteration)
+func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
+ var prevTokens int
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
+ if err != nil {
+ logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
+ return
+ }
+
+ condensedID, err := e.compactCondensed(ctx, convID)
+ if err != nil {
+ logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
+ return
+ }
+ if condensedID == nil {
+ // No candidate found
+ logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
+ return
+ }
+
+ tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
+
+ if tokensAfter >= tokensBefore {
+ // No progress this iteration
+ logger.DebugCF(
+ "seahorse",
+ "condensed: no progress",
+ map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
+ )
+ return
+ }
+ if tokensAfter >= prevTokens && prevTokens > 0 {
+ // No improvement over last iteration
+ logger.DebugCF(
+ "seahorse",
+ "condensed: no improvement",
+ map[string]any{"conv_id": convID, "tokens": tokensAfter},
+ )
+ return
+ }
+
+ prevTokens = tokensAfter
+ }
+}
+
+// --- Helper functions ---
+
+func formatMessagesForSummary(messages []Message) string {
+ var result string
+ for _, m := range messages {
+ ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
+ content := m.Content
+ if content == "" && len(m.Parts) > 0 {
+ content = partsToReadableContent(m.Parts)
+ }
+ result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
+ }
+ return result
+}
+
+func formatSummariesForCondensation(summaries []Summary) string {
+ var result string
+ for _, s := range summaries {
+ earliest := ""
+ if s.EarliestAt != nil {
+ earliest = s.EarliestAt.Format("2006-01-02")
+ }
+ latest := ""
+ if s.LatestAt != nil {
+ latest = s.LatestAt.Format("2006-01-02")
+ }
+ result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
+ }
+ return result
+}
+
+func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
+ prev := "(none)"
+ if previousSummary != "" {
+ prev = previousSummary
+ }
+ return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
+Treat this as incremental memory compaction input, not a full-conversation summary.
+
+Normal summary policy:
+- Preserve key decisions, rationale, constraints, and active tasks.
+- Keep essential technical details needed to continue work safely.
+- Remove obvious repetition and conversational filler.
+
+Output requirements:
+- Plain text only.
+- No preamble, headings, or markdown formatting.
+- Track file operations (created, modified, deleted, renamed) with file paths and current status.
+- If no file operations appear, include exactly: "Files: none".
+- End with exactly: "Expand for details about: ".
+- Target length: about %d tokens or less.
+
+
+%s
+
+
+
+%s
+`, targetTokens, prev, sourceText)
+}
+
+func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
+ return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
+Preserve all important decisions, constraints, and outcomes.
+Merge overlapping topics. Keep technical details intact.
+
+Output requirements:
+- Plain text only.
+- No preamble, headings, or markdown formatting.
+- End with exactly: "Expand for details about: ".
+- Target length: about %d tokens or less.
+
+
+%s
+`, targetTokens, sourceText)
+}
+
+func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
+ prev := "(none)"
+ if previousSummary != "" {
+ prev = previousSummary
+ }
+ return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
+Aggressive summary policy:
+- Keep only durable facts and current task state.
+- Remove examples, repetition, and low-value narrative details.
+- Preserve explicit TODOs, blockers, decisions, and constraints.
+
+Output requirements:
+- Plain text only.
+- No preamble, headings, or markdown formatting.
+- Track file operations (created, modified, deleted, renamed) with file paths and current status.
+- If no file operations appear, include exactly: "Files: none".
+- End with exactly: "Expand for details about: ".
+- Target length: about %d tokens or less.
+
+
+%s
+
+
+
+%s
+`, targetTokens, prev, sourceText)
+}
+
+func truncateSummary(messages []Message) string {
+ content := ""
+ for _, m := range messages {
+ c := m.Content
+ if c == "" && len(m.Parts) > 0 {
+ c = partsToReadableContent(m.Parts)
+ }
+ content += c + "\n"
+ }
+ if len(content) > 2048 {
+ content = content[:2048]
+ }
+ content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
+ return content
+}
+
+func truncateCondensedSummaries(summaries []Summary) string {
+ content := ""
+ for _, s := range summaries {
+ content += s.Content + "\n"
+ }
+ if len(content) > 2048 {
+ content = content[:2048]
+ }
+ content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
+ return content
+}
+
+func sumMessageTokens(messages []Message) int {
+ total := 0
+ for _, m := range messages {
+ total += m.TokenCount
+ }
+ return total
+}
+
+func sumSummaryTokens(summaries []Summary) int {
+ total := 0
+ for _, s := range summaries {
+ total += s.TokenCount
+ }
+ return total
+}
+
+func minInt(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
diff --git a/pkg/seahorse/short_compaction_test.go b/pkg/seahorse/short_compaction_test.go
new file mode 100644
index 000000000..ea7dcb52d
--- /dev/null
+++ b/pkg/seahorse/short_compaction_test.go
@@ -0,0 +1,974 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// --- Test Helpers ---
+
+// waitForCondensed blocks until the async condensed goroutine for convID finishes.
+// Returns false if timeout is reached.
+func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ if _, exists := ce.condensing.Load(convID); !exists {
+ return true
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+ return false
+}
+
+// --- Compaction Tests ---
+
+func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
+ t.Helper()
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+ s := &Store{db: db}
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
+ shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
+ ce := &CompactionEngine{
+ store: s,
+ config: Config{},
+ complete: mockCompleteFn,
+ shutdownCtx: shutdownCtx,
+ shutdownCancel: shutdownCancel,
+ }
+ convID := conv.ConversationID
+ // Ensure async goroutines are stopped before database is closed.
+ // Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
+ t.Cleanup(func() {
+ shutdownCancel()
+ // Wait for async condensed goroutine to finish (poll condensing map)
+ deadline := time.Now().Add(2 * time.Second)
+ for time.Now().Before(deadline) {
+ if _, exists := ce.condensing.Load(convID); !exists {
+ break
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+ })
+ return ce, s, conv.ConversationID
+}
+
+// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
+// Note: Caller is responsible for calling shutdownCancel when test ends.
+func newTestCompactionEngineWithStore(
+ s *Store, complete CompleteFn,
+) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
+ shutdownCtx, cancel := context.WithCancel(context.Background())
+ return &CompactionEngine{
+ store: s,
+ config: Config{},
+ complete: complete,
+ shutdownCtx: shutdownCtx,
+ shutdownCancel: cancel,
+ }, cancel
+}
+
+// mockCompleteFn returns a simple summary for testing
+var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ return "Mock summary of the conversation segment.", nil
+}
+
+func TestNeedsCompaction(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Empty context — no compaction needed
+ needed, err := ce.NeedsCompaction(ctx, convID, 10000)
+ if err != nil {
+ t.Fatalf("NeedsCompaction: %v", err)
+ }
+ if needed {
+ t.Error("expected no compaction for empty context")
+ }
+
+ // Add messages to context, total tokens = 8000
+ for i := 0; i < 8; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
+ needed, err = ce.NeedsCompaction(ctx, convID, 10000)
+ if err != nil {
+ t.Fatalf("NeedsCompaction: %v", err)
+ }
+ if !needed {
+ t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
+ }
+
+ // Below threshold: 5000 / 10000 → no compaction
+ s.UpsertContextItems(ctx, convID, nil) // clear
+ for i := 0; i < 5; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+ needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
+ if needed {
+ t.Error("expected no compaction at 5000/10000 tokens")
+ }
+}
+
+func TestCompactLeaf(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create enough messages to trigger leaf compaction:
+ // Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
+ for i := 0; i < 40; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Compact
+ result, err := ce.Compact(ctx, convID, CompactInput{})
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+ if result == nil {
+ t.Fatal("expected non-nil result")
+ }
+
+ // Should have created at least one leaf summary
+ if result.LeafSummaries == 0 {
+ t.Error("expected at least 1 leaf summary")
+ }
+
+ // Context should now contain a summary item
+ items, _ := s.GetContextItems(ctx, convID)
+ foundSummary := false
+ for _, item := range items {
+ if item.ItemType == "summary" {
+ foundSummary = true
+ break
+ }
+ }
+ if !foundSummary {
+ t.Error("expected a summary in context_items after leaf compaction")
+ }
+
+ // Some messages should have been replaced
+ if len(result.SummariesCreated) == 0 {
+ t.Error("expected at least 1 summary created")
+ }
+}
+
+func TestCompactLeafNoCandidate(t *testing.T) {
+ ce, _, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Too few messages to trigger leaf compaction
+ m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
+ ce.store.AppendContextMessage(ctx, convID, m.ID)
+
+ result, err := ce.Compact(ctx, convID, CompactInput{})
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+ if result == nil {
+ t.Fatal("expected non-nil result even with no candidate")
+ }
+ if result.LeafSummaries != 0 {
+ t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
+ }
+}
+
+func TestCompactCondensed(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create enough leaf summaries and fresh messages to enable condensation
+ leafIDs := make([]string, CondensedMinFanout)
+ for i := 0; i < CondensedMinFanout; i++ {
+ now := time.Now().UTC()
+ summary, err := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf summary content " + time.Now().String(),
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ if err != nil {
+ t.Fatalf("CreateSummary %d: %v", i, err)
+ }
+ leafIDs[i] = summary.SummaryID
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Add enough fresh messages to have a fresh tail (>= FreshTailCount)
+ for i := 0; i < FreshTailCount; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Compact with force to trigger condensation
+ _, err := ce.Compact(ctx, convID, CompactInput{Force: true})
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+
+ // Wait for async condensed goroutine to complete
+ if !waitForCondensed(ce, convID, 2*time.Second) {
+ t.Fatal("timeout waiting for condensed compaction")
+ }
+
+ // Should have created a condensed summary in the DB
+ summaries, _ := s.GetSummariesByConversation(ctx, convID)
+ foundCondensed := false
+ for _, sum := range summaries {
+ if sum.Kind == SummaryKindCondensed {
+ foundCondensed = true
+ break
+ }
+ }
+ if !foundCondensed {
+ t.Error("expected at least 1 condensed summary")
+ }
+}
+
+func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
+ // Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
+ // from context_items between candidate selection and ordinal range scan.
+ // Use a slow CompleteFn with barrier sync to control timing.
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
+ convID := conv.ConversationID
+
+ // Create leaf summaries with enough tokens for condensation
+ var leafIDs []string
+ for i := 0; i < CondensedMinFanout; i++ {
+ now := time.Now().UTC()
+ sum, err := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf summary %d", i),
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ if err != nil {
+ t.Fatalf("CreateSummary: %v", err)
+ }
+ leafIDs = append(leafIDs, sum.SummaryID)
+ s.AppendContextSummary(ctx, convID, sum.SummaryID)
+ }
+
+ // Add fresh tail so leaf summaries are in evictable range
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Barrier: CompleteFn waits until test removes context_items, then returns
+ var barrier1, barrier2 sync.WaitGroup
+ barrier1.Add(1) // CompleteFn signals when called
+ barrier2.Add(1) // test signals when context_items removed
+
+ slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ barrier1.Done() // signal: LLM called, candidates selected
+ barrier2.Wait() // wait: test removes context_items
+ return "Condensed summary.", nil
+ }
+
+ ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
+ t.Cleanup(func() {
+ cancel()
+ time.Sleep(100 * time.Millisecond)
+ })
+
+ // Run compactCondensed in background
+ type compactResult struct {
+ summaryID *string
+ err error
+ }
+ resultCh := make(chan compactResult, 1)
+ go func() {
+ sid, err := ce.compactCondensed(context.Background(), convID)
+ resultCh <- compactResult{summaryID: sid, err: err}
+ }()
+
+ // Wait for CompleteFn to be called (candidates selected)
+ barrier1.Wait()
+
+ // Remove leaf summaries from context_items (simulating concurrent replacement)
+ items, _ := s.GetContextItems(ctx, convID)
+ var preserved []ContextItem
+ for _, item := range items {
+ isLeaf := false
+ for _, lid := range leafIDs {
+ if item.SummaryID == lid {
+ isLeaf = true
+ break
+ }
+ }
+ if !isLeaf {
+ preserved = append(preserved, item)
+ }
+ }
+ s.UpsertContextItems(ctx, convID, preserved)
+
+ // Let CompleteFn return
+ barrier2.Done()
+
+ // Get result
+ res := <-resultCh
+ if res.err != nil {
+ t.Fatalf("compactCondensed: %v", res.err)
+ }
+
+ // With the bug: returns non-nil summaryID even though context_items has no matching ordinals
+ // The fix: should return nil when startOrd == -1
+ if res.summaryID != nil {
+ t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
+
+ // Verify the orphan exists in DB
+ summary, _ := s.GetSummary(context.Background(), *res.summaryID)
+ if summary != nil && summary.Kind == SummaryKindCondensed {
+ // Check it's NOT in context_items (orphan)
+ items2, _ := s.GetContextItems(context.Background(), convID)
+ found := false
+ for _, item := range items2 {
+ if item.SummaryID == *res.summaryID {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
+ }
+ }
+ }
+}
+
+func TestCompactUntilUnder(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create many leaf summaries to ensure we can condense
+ for i := 0; i < 8; i++ {
+ now := time.Now().UTC()
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf summary for condensation test",
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Force compact until under budget
+ result, err := ce.CompactUntilUnder(ctx, convID, 2000)
+ if err != nil {
+ t.Fatalf("CompactUntilUnder: %v", err)
+ }
+
+ if result == nil {
+ t.Fatal("expected non-nil result")
+ }
+}
+
+func TestSelectShallowestCondensationCandidate(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create enough leaf summaries + fresh messages for candidates
+ for i := 0; i < LeafMinFanout; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf",
+ TokenCount: 100,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Add fresh tail messages so summaries are in evictable range
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
+ if err != nil {
+ t.Fatalf("selectShallowestCondensationCandidate: %v", err)
+ }
+
+ // Should find leaf summaries at depth 0
+ if len(candidates) < CondensedMinFanout {
+ t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
+ }
+}
+
+func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
+ ce, _, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
+ if err != nil {
+ t.Fatalf("selectShallowestCondensationCandidate: %v", err)
+ }
+ if len(candidates) != 0 {
+ t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
+ }
+}
+
+func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
+ // Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
+ // rather than just grouping by depth without regard to order
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create interleaved summaries at depth 0 with a message in between:
+ // sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
+
+ for i := 0; i < LeafMinFanout+2; i++ {
+ now := time.Now().UTC()
+
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf summary %d", i),
+ TokenCount: 100,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ }
+
+ // Insert a message between first two summaries to break contiguity
+ // for selectShallowestCondensationCandidate but would still find all 3
+ // but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
+ s.AppendContextMessage(ctx, convID, msg.ID)
+
+ // Run compactCondensed
+ result, err := ce.compactCondensed(ctx, convID)
+ if err != nil {
+ t.Fatalf("compactCondensed: %v", err)
+ }
+
+ // The result should have merged the two summaries at the start
+ // (skipping the message in between), This proves ordinal-aware selection works.
+
+ _ = result // verify summary was created
+
+ if result != nil {
+ summaries, _ := s.GetSummariesByConversation(ctx, convID)
+ found := false
+ for _, sum := range summaries {
+ if sum.Kind == SummaryKindCondensed {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected condensed summary to be created via ordinal-aware selection")
+ }
+ }
+}
+
+func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
+ for i := 0; i < 5; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf summary %d", i),
+ TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Add fresh tail
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
+ if err != nil {
+ t.Fatalf("selectOldestChunkAtDepth: %v", err)
+ }
+ if len(chunk) < 2 {
+ t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
+ }
+ for _, s := range chunk {
+ if s.Depth != 0 {
+ t.Errorf("got depth %d, want 0", s.Depth)
+ }
+ }
+}
+
+func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create 3 summaries, then a message, then 3 more summaries
+ for i := 0; i < 3; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf %d", i),
+ TokenCount: 100,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+ msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
+ s.AppendContextMessage(ctx, convID, msg.ID)
+ for i := 0; i < 3; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("leaf-after %d", i),
+ TokenCount: 100,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
+ if len(chunk) > 3 {
+ t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
+ }
+}
+
+func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create summaries with very low token counts (total < 2000)
+ for i := 0; i < 5; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("tiny summary %d", i),
+ TokenCount: 50, // very small
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Add fresh tail to protect from compaction
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Should return nil because total tokens (250) < 2000 minimum
+ chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
+ if err != nil {
+ t.Fatalf("selectOldestChunkAtDepth: %v", err)
+ }
+ if len(chunk) > 0 {
+ t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
+ }
+}
+
+func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create summaries with enough tokens (total >= 2000)
+ for i := 0; i < 5; i++ {
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf(
+ "substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
+ i,
+ ),
+ TokenCount: 500, // 5 × 500 = 2500 >= 2000
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+
+ // Add fresh tail
+ for i := 0; i < FreshTailCount+1; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Should return chunk because total tokens (2500) >= 2000
+ chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
+ if err != nil {
+ t.Fatalf("selectOldestChunkAtDepth: %v", err)
+ }
+ if len(chunk) == 0 {
+ t.Error("expected non-empty chunk when tokens >= 2000")
+ }
+}
+
+func TestGenerateLeafSummary(t *testing.T) {
+ ce, _, _ := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ msgs := []Message{
+ {Role: "user", Content: "hello world", TokenCount: 5},
+ {Role: "assistant", Content: "hi there", TokenCount: 5},
+ }
+
+ content, err := ce.generateLeafSummary(ctx, msgs, "")
+ if err != nil {
+ t.Fatalf("generateLeafSummary: %v", err)
+ }
+ if content == "" {
+ t.Error("expected non-empty summary content")
+ }
+}
+
+func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
+ // Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
+ var calls []string
+ escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ if contains(prompt, "Aggressive summary policy") {
+ calls = append(calls, "aggressive")
+ return "Short aggressive summary.", nil
+ }
+ calls = append(calls, "normal")
+ // Return a very long summary to trigger escalation
+ longContent := make([]byte, 5000)
+ for i := range longContent {
+ longContent[i] = 'x'
+ }
+ return string(longContent), nil
+ }
+
+ s := openTestStore(t)
+ ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
+
+ msgs := []Message{
+ {Role: "user", Content: "hello world", TokenCount: 10},
+ {Role: "assistant", Content: "response", TokenCount: 10},
+ }
+
+ content, err := ce.generateLeafSummary(context.Background(), msgs, "")
+ if err != nil {
+ t.Fatalf("generateLeafSummary: %v", err)
+ }
+ if content == "" {
+ t.Error("expected non-empty summary content")
+ }
+ // Should have called both normal and aggressive
+ foundNormal := false
+ foundAggressive := false
+ for _, c := range calls {
+ if c == "normal" {
+ foundNormal = true
+ }
+ if c == "aggressive" {
+ foundAggressive = true
+ }
+ }
+ if !foundNormal {
+ t.Error("expected normal LLM call")
+ }
+ if !foundAggressive {
+ t.Error("expected aggressive LLM call (level 2 escalation)")
+ }
+}
+
+func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
+ // Both normal and aggressive return empty, should escalate to level 3 truncation
+ emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ return "", nil
+ }
+
+ s := openTestStore(t)
+ ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
+
+ msgs := []Message{
+ {Role: "user", Content: "hello world from test", TokenCount: 10},
+ {Role: "assistant", Content: "response text here", TokenCount: 10},
+ }
+
+ content, err := ce.generateLeafSummary(context.Background(), msgs, "")
+ if err != nil {
+ t.Fatalf("generateLeafSummary: %v", err)
+ }
+ // Level 3 truncation should have produced something
+ if content == "" {
+ t.Error("expected non-empty content from level 3 truncation fallback")
+ }
+ if !contains(content, "Truncated from") {
+ t.Errorf("expected truncation marker in content: %q", content)
+ }
+}
+
+func TestGenerateCondensedSummary(t *testing.T) {
+ ce, _, _ := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ summaries := []Summary{
+ {SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
+ {SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
+ }
+
+ content, err := ce.generateCondensedSummary(ctx, summaries)
+ if err != nil {
+ t.Fatalf("generateCondensedSummary: %v", err)
+ }
+ if content == "" {
+ t.Error("expected non-empty condensed summary content")
+ }
+}
+
+func TestGenerateCondensedSummaryEscalation(t *testing.T) {
+ // When LLM returns empty, should fall back to deterministic concatenation
+ emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ return "", nil
+ }
+
+ s := openTestStore(t)
+ ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
+
+ summaries := []Summary{
+ {SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
+ {SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
+ }
+
+ content, err := ce.generateCondensedSummary(context.Background(), summaries)
+ if err != nil {
+ t.Fatalf("generateCondensedSummary: %v", err)
+ }
+ // Should fall back to concatenation
+ if content == "" {
+ t.Error("expected non-empty content from fallback")
+ }
+}
+
+// --- Async Condensed Compaction (Phase 2) ---
+
+func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
+ // Use a slow CompleteFn to verify Compact returns before condensed finishes
+ var callCount int32
+ slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ atomic.AddInt32(&callCount, 1)
+ time.Sleep(500 * time.Millisecond) // simulate slow LLM
+ return "Slow condensed summary.", nil
+ }
+
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:async")
+ convID := conv.ConversationID
+
+ ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
+ t.Cleanup(func() {
+ cancel()
+ time.Sleep(100 * time.Millisecond)
+ })
+
+ // Create enough leaf summaries for condensation + fresh tail
+ for i := 0; i < CondensedMinFanout; i++ {
+ now := time.Now().UTC()
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf for async test",
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+ for i := 0; i < FreshTailCount; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Compact with force — should return quickly, condensed runs async
+ start := time.Now()
+ result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
+ elapsed := time.Since(start)
+
+ if err != nil {
+ t.Fatalf("Compact: %v", err)
+ }
+ if result == nil {
+ t.Fatal("expected non-nil result")
+ }
+
+ // Should return well before the 500ms LLM call
+ if elapsed > 200*time.Millisecond {
+ t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
+ }
+
+ // Wait for async to complete
+ time.Sleep(800 * time.Millisecond)
+
+ // Verify condensed summary was created by background goroutine
+ summaries, _ := s.GetSummariesByConversation(ctx, convID)
+ foundCondensed := false
+ for _, sum := range summaries {
+ if sum.Kind == SummaryKindCondensed {
+ foundCondensed = true
+ break
+ }
+ }
+ if !foundCondensed {
+ t.Error("expected at least one condensed summary from async Phase 2")
+ }
+}
+
+func TestCompactAsyncDedup(t *testing.T) {
+ var callCount int32
+ slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
+ atomic.AddInt32(&callCount, 1)
+ time.Sleep(300 * time.Millisecond)
+ return "Slow condensed summary.", nil
+ }
+
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
+ convID := conv.ConversationID
+
+ ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
+ t.Cleanup(func() {
+ cancel()
+ waitForCondensed(ce, convID, 2*time.Second)
+ })
+
+ // Create conditions for condensed compaction
+ for i := 0; i < CondensedMinFanout; i++ {
+ now := time.Now().UTC()
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf for dedup",
+ TokenCount: 500,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ s.AppendContextSummary(ctx, convID, summary.SummaryID)
+ }
+ for i := 0; i < FreshTailCount; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Call Compact twice rapidly
+ ce.Compact(ctx, convID, CompactInput{Force: true})
+ ce.Compact(ctx, convID, CompactInput{Force: true})
+
+ // Wait for async to finish
+ time.Sleep(600 * time.Millisecond)
+
+ // LLM should only be called once for condensed (dedup)
+ // callCount may be 0 if no leaf was created (only condensed in goroutine)
+ // The key is that we don't get 2+ condensed calls
+ if atomic.LoadInt32(&callCount) > 1 {
+ t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
+ }
+}
+
+func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
+ // Spec: compactLeaf with force=true should bypass FreshTailCount protection
+ // so CompactUntilUnder can compress messages inside the fresh tail
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create exactly FreshTailCount+4 messages (36 total)
+ // Without force: all messages are in fresh tail → no candidate
+ // With force: should compact the oldest messages
+ total := FreshTailCount + 4
+ for i := 0; i < total; i++ {
+ m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ // Without force: should return nil (all in fresh tail)
+ summaryID, err := ce.compactLeaf(ctx, convID)
+ if err != nil {
+ t.Fatalf("compactLeaf no-force: %v", err)
+ }
+ if summaryID != nil {
+ t.Error("expected nil without force (all messages in fresh tail)")
+ }
+
+ // With force: should compact despite fresh tail protection
+ summaryID, err = ce.compactLeaf(ctx, convID, true)
+ if err != nil {
+ t.Fatalf("compactLeaf force: %v", err)
+ }
+ if summaryID == nil {
+ t.Error("expected summary with force=true (bypasses fresh tail)")
+ }
+}
+
+func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
+ // Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
+ // It should NOT take the entire contiguous chunk regardless of token count
+ ce, s, convID := newTestCompactionEngine(t)
+ ctx := context.Background()
+
+ // Create messages totaling far more than LeafChunkTokens (20000)
+ // Each message is ~500 tokens, create 80 messages = 40000 tokens
+ for i := 0; i < 80; i++ {
+ m, _ := s.AddMessage(
+ ctx,
+ convID,
+ "user",
+ fmt.Sprintf(
+ "message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
+ i,
+ ),
+ 500,
+ )
+ s.AppendContextMessage(ctx, convID, m.ID)
+ }
+
+ summaryID, err := ce.compactLeaf(ctx, convID)
+ if err != nil {
+ t.Fatalf("compactLeaf: %v", err)
+ }
+ if summaryID == nil {
+ t.Fatal("expected a summary to be created")
+ }
+
+ // The source messages that were compacted should total roughly LeafChunkTokens (20000),
+ // not the entire 40000 tokens worth of messages
+ summary, _ := s.GetSummary(ctx, *summaryID)
+ if summary == nil {
+ t.Fatal("summary not found")
+ }
+
+ // Source message tokens should be roughly <= LeafChunkTokens (20000)
+ // Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
+ if summary.SourceMessageTokenCount > LeafChunkTokens {
+ t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
+ summary.SourceMessageTokenCount, LeafChunkTokens)
+ }
+}
diff --git a/pkg/seahorse/short_constants.go b/pkg/seahorse/short_constants.go
new file mode 100644
index 000000000..943d7931e
--- /dev/null
+++ b/pkg/seahorse/short_constants.go
@@ -0,0 +1,30 @@
+package seahorse
+
+// Short-term memory configuration constants — all are experience-based defaults.
+
+const (
+ // OrdinalStep is the gap between ordinals in context_items.
+ // Insert at midpoint; resequence only when precision exhausted.
+ OrdinalStep = 100
+
+ // ContextThreshold is the compaction trigger for the context window.
+ ContextThreshold float64 = 0.75 // Compact at 75% of context window
+ FreshTailCount int = 32 // Recent messages protected from compaction
+
+ // LeafMinFanout is the fanout parameter.
+ LeafMinFanout int = 8 // Min messages per leaf summary
+ CondensedMinFanout int = 4 // Min summaries per condensed
+ CondensedMinFanoutHard int = 2 // Min for forced compaction
+
+ // LeafChunkTokens is the token target.
+ LeafChunkTokens int = 20000 // Max tokens per leaf chunk
+ LeafTargetTokens int = 1200 // Target tokens for leaf summaries
+ CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
+ MaxExpandTokens int = 4000 // Token cap for expansion queries
+
+ // MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
+ // Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
+ // With a 200k token context window and 75% threshold, ~20 iterations is enough
+ // for any realistic scenario. If exceeded, the issue is logged as a warning.
+ MaxCompactIterations int = 20
+)
diff --git a/pkg/seahorse/short_engine.go b/pkg/seahorse/short_engine.go
new file mode 100644
index 000000000..4cd4d3887
--- /dev/null
+++ b/pkg/seahorse/short_engine.go
@@ -0,0 +1,568 @@
+package seahorse
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "sync"
+
+ _ "modernc.org/sqlite"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// Config holds engine configuration.
+type Config struct {
+ DBPath string `json:"dbPath"`
+ IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
+ StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
+}
+
+// CompleteFn is the LLM completion function type.
+type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
+
+// CompleteOptions holds LLM completion parameters.
+type CompleteOptions struct {
+ Model string
+ MaxTokens int
+ Temperature float64
+}
+
+// IngestResult is the result of message ingestion.
+type IngestResult struct {
+ MessageCount int `json:"messageCount"`
+ TokenCount int `json:"tokenCount"`
+}
+
+// AssembleInput controls context assembly.
+type AssembleInput struct {
+ Budget int `json:"budget"`
+ Query string `json:"query,omitempty"`
+}
+
+// AssembleResult contains assembled context.
+type AssembleResult struct {
+ Messages []Message `json:"messages"`
+ Summary string `json:"summary"` // formatted XML summaries + system prompt addition
+}
+
+const numSessionShards = 256
+
+// Engine is the main short-term memory engine.
+type Engine struct {
+ store *Store
+ compaction *CompactionEngine
+ compactionMu sync.Mutex
+ assembler *Assembler
+ assemblerMu sync.Mutex
+ retrieval *RetrievalEngine
+ config Config
+ complete CompleteFn
+ ignorePatterns []*regexp.Regexp
+ statelessPatterns []*regexp.Regexp
+ sessionShards [numSessionShards]struct {
+ mu sync.Mutex
+ }
+}
+
+// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
+type CompactionEngine struct {
+ store *Store
+ config Config
+ complete CompleteFn
+ condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
+ shutdownCtx context.Context
+ shutdownCancel context.CancelFunc
+}
+
+// Assembler handles budget-aware context assembly (defined in short_assembler.go).
+type Assembler struct {
+ store *Store
+ config Config
+}
+
+// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
+type RetrievalEngine struct {
+ store *Store
+ config Config
+}
+
+// Store returns the underlying store for direct access.
+func (r *RetrievalEngine) Store() *Store {
+ return r.store
+}
+
+// NewEngine creates a new short-term memory engine.
+func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
+ dir := filepath.Dir(config.DBPath)
+ if dir != "" && dir != "." {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return nil, fmt.Errorf("create db directory: %w", err)
+ }
+ }
+
+ db, err := sql.Open("sqlite", config.DBPath)
+ if err != nil {
+ return nil, fmt.Errorf("open db: %w", err)
+ }
+
+ // Configure SQLite for concurrent access
+ if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
+ db.Close()
+ return nil, fmt.Errorf("enable WAL: %w", err)
+ }
+ if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
+ db.Close()
+ return nil, fmt.Errorf("set busy_timeout: %w", err)
+ }
+ if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
+ db.Close()
+ return nil, fmt.Errorf("set synchronous: %w", err)
+ }
+
+ if err := runSchema(db); err != nil {
+ db.Close()
+ return nil, fmt.Errorf("migrations: %w", err)
+ }
+
+ store := &Store{db: db}
+
+ // Prepend hardcoded ignore patterns (spec lines 1326-1328)
+ ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
+ ignorePatterns = append(ignorePatterns, "heartbeat")
+ ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
+
+ retrieval := &RetrievalEngine{store: store, config: config}
+
+ return &Engine{
+ store: store,
+ compaction: nil,
+ assembler: nil,
+ retrieval: retrieval,
+ config: config,
+ complete: completeFn,
+ ignorePatterns: compileSessionPatterns(ignorePatterns),
+ statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
+ }, nil
+}
+
+// compileSessionPattern converts a glob pattern to a compiled regex.
+// Pattern rules:
+// - * matches any sequence of non-colon characters ([^:]*)
+// - ** matches any sequence of characters including colons (.*)
+// - All other characters are treated literally
+// - Pattern is anchored (^...$)
+func compileSessionPattern(pattern string) *regexp.Regexp {
+ var b strings.Builder
+ b.WriteByte('^')
+
+ i := 0
+ for i < len(pattern) {
+ if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
+ b.WriteString(".*")
+ i += 2
+ continue
+ }
+ if pattern[i] == '*' {
+ b.WriteString("[^:]*")
+ i++
+ continue
+ }
+ b.WriteString(regexp.QuoteMeta(string(pattern[i])))
+ i++
+ }
+
+ b.WriteByte('$')
+ return regexp.MustCompile(b.String())
+}
+
+// compileSessionPatterns compiles multiple glob patterns into regex patterns.
+func compileSessionPatterns(patterns []string) []*regexp.Regexp {
+ result := make([]*regexp.Regexp, 0, len(patterns))
+ for _, p := range patterns {
+ if p == "" {
+ continue
+ }
+ result = append(result, compileSessionPattern(p))
+ }
+ return result
+}
+
+// shouldIgnoreSession returns true if the session key matches any ignore pattern.
+func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
+ for _, p := range e.ignorePatterns {
+ if p.MatchString(sessionKey) {
+ return true
+ }
+ }
+ return false
+}
+
+// isStatelessSession returns true if the session key matches any stateless pattern.
+func (e *Engine) isStatelessSession(sessionKey string) bool {
+ for _, p := range e.statelessPatterns {
+ if p.MatchString(sessionKey) {
+ return true
+ }
+ }
+ return false
+}
+
+// fnv32 computes FNV-1a 32-bit hash for session key sharding.
+func fnv32(key string) uint32 {
+ h := uint32(2166136261)
+ for _, c := range key {
+ h ^= uint32(c)
+ h *= 16777619
+ }
+ return h
+}
+
+// getSessionMutex returns the sharded mutex for a session key.
+func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
+ h := fnv32(sessionKey)
+ shard := h % numSessionShards
+ return &e.sessionShards[shard].mu
+}
+
+// Ingest adds messages to a conversation identified by sessionKey.
+func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
+ if e.shouldIgnoreSession(sessionKey) {
+ return nil, nil
+ }
+ if e.isStatelessSession(sessionKey) {
+ return nil, nil
+ }
+
+ mu := e.getSessionMutex(sessionKey)
+ mu.Lock()
+ defer mu.Unlock()
+
+ conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ return nil, fmt.Errorf("get conversation: %w", err)
+ }
+
+ var totalTokens int
+ var msgIDs []int64
+ for _, msg := range messages {
+ var added *Message
+ var err error
+ if len(msg.Parts) > 0 {
+ added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
+ } else {
+ added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
+ }
+ if err != nil {
+ return nil, fmt.Errorf("add message: %w", err)
+ }
+ totalTokens += msg.TokenCount
+ msgIDs = append(msgIDs, added.ID)
+ }
+
+ // Append to context_items using actual inserted IDs
+ if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
+ return nil, fmt.Errorf("append context: %w", err)
+ }
+
+ logger.InfoCF("seahorse", "ingest", map[string]any{
+ "conv_id": conv.ConversationID,
+ "messages": len(messages),
+ "tokens": totalTokens,
+ })
+ return &IngestResult{
+ MessageCount: len(messages),
+ TokenCount: totalTokens,
+ }, nil
+}
+
+// Close releases resources.
+func (e *Engine) Close() error {
+ // Signal compaction goroutines to stop
+ if e.compaction != nil {
+ e.compaction.Close()
+ }
+ if e.store != nil && e.store.db != nil {
+ return e.store.db.Close()
+ }
+ return nil
+}
+
+// GetRetrieval returns the retrieval engine for tool implementations.
+func (e *Engine) GetRetrieval() *RetrievalEngine {
+ return e.retrieval
+}
+
+// Assemble builds budget-constrained context for a session.
+func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
+ if e.shouldIgnoreSession(sessionKey) {
+ return nil, nil
+ }
+
+ conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ return nil, fmt.Errorf("get conversation: %w", err)
+ }
+
+ e.initAssemblerOnce()
+ return e.assembler.Assemble(ctx, conv.ConversationID, input)
+}
+
+// Compact compresses conversation history for a session.
+func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
+ if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
+ return &CompactResult{}, nil
+ }
+
+ conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ return nil, fmt.Errorf("get conversation: %w", err)
+ }
+
+ e.initCompactionOnce()
+ return e.compaction.Compact(ctx, conv.ConversationID, input)
+}
+
+// CompactUntilUnder aggressively compacts until context is under budget.
+// Used for emergency compaction after LLM overflow (retry reason).
+func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
+ if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
+ return &CompactResult{}, nil
+ }
+
+ conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ return nil, fmt.Errorf("get conversation: %w", err)
+ }
+
+ e.initCompactionOnce()
+ return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
+}
+
+// initCompactionOnce lazily initializes the compaction engine.
+func (e *Engine) initCompactionOnce() {
+ if e.compaction == nil {
+ e.compactionMu.Lock()
+ defer e.compactionMu.Unlock()
+ if e.compaction == nil {
+ shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
+ e.compaction = &CompactionEngine{
+ store: e.store,
+ config: e.config,
+ complete: e.complete,
+ shutdownCtx: shutdownCtx,
+ shutdownCancel: shutdownCancel,
+ }
+ }
+ }
+}
+
+// initAssemblerOnce lazily initializes the assembler.
+func (e *Engine) initAssemblerOnce() {
+ if e.assembler == nil {
+ e.assemblerMu.Lock()
+ defer e.assemblerMu.Unlock()
+ if e.assembler == nil {
+ e.assembler = &Assembler{store: e.store, config: e.config}
+ }
+ }
+}
+
+// IngestMessages is an alias for Ingest.
+func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
+ return e.Ingest(ctx, sessionKey, messages)
+}
+
+// Bootstrap reconciles a session's messages with the database.
+// Called once at startup for each known session.
+// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
+// Simple approach: find longest matching prefix and append delta.
+// If any mismatch is detected, clear and rebuild.
+func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
+ if e.shouldIgnoreSession(sessionKey) {
+ return nil
+ }
+ if e.isStatelessSession(sessionKey) {
+ return nil
+ }
+ if len(messages) == 0 {
+ return nil
+ }
+
+ conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
+ if err != nil {
+ return fmt.Errorf("bootstrap: get conversation: %w", err)
+ }
+
+ // Get messages already in DB
+ dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
+ if err != nil {
+ return fmt.Errorf("bootstrap: get messages: %w", err)
+ }
+
+ // Fast path: DB has same count and exact match → no-op
+ if len(dbMsgs) == len(messages) {
+ matched := true
+ for i := 0; i < len(messages); i++ {
+ if !messageMatches(dbMsgs[i], messages[i]) {
+ matched = false
+ break
+ }
+ }
+ if matched {
+ return nil // DB is up to date
+ }
+ }
+
+ // Find longest matching prefix from the start
+ anchor := -1
+ compareLen := len(dbMsgs)
+ if compareLen > len(messages) {
+ compareLen = len(messages)
+ }
+
+ for i := 0; i < compareLen; i++ {
+ if messageMatches(dbMsgs[i], messages[i]) {
+ anchor = i
+ } else {
+ // Mismatch detected - log details and rebuild
+ logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
+ "conv_id": conv.ConversationID,
+ "index": i,
+ "db_role": dbMsgs[i].Role,
+ "db_content": truncate(dbMsgs[i].Content, 50),
+ "db_parts": len(dbMsgs[i].Parts),
+ "msg_role": messages[i].Role,
+ "msg_content": truncate(messages[i].Content, 50),
+ "msg_parts": len(messages[i].Parts),
+ })
+ break
+ }
+ }
+
+ // If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
+ // Note: anchor can be -1 if first message didn't match (history completely changed)
+ if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
+ anchorID := dbMsgs[anchor].ID
+ logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
+ "conv_id": conv.ConversationID,
+ "db_count": len(dbMsgs),
+ "anchor": anchor,
+ "anchor_id": anchorID,
+ "msg_count": len(messages),
+ "delta_start": anchor + 1,
+ })
+
+ // Delete messages after anchor (also clears context_items)
+ if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
+ return fmt.Errorf("bootstrap: delete messages: %w", err)
+ }
+
+ // Re-ingest from anchor+1 to end
+ delta := messages[anchor+1:]
+ if len(delta) > 0 {
+ _, err := e.Ingest(ctx, sessionKey, delta)
+ if err != nil {
+ return fmt.Errorf("bootstrap: re-ingest: %w", err)
+ }
+ }
+ return nil
+ }
+
+ // Normal case: append delta after anchor
+ if anchor >= 0 && anchor < len(messages)-1 {
+ delta := messages[anchor+1:]
+ if len(delta) > 0 {
+ _, err := e.Ingest(ctx, sessionKey, delta)
+ if err != nil {
+ return fmt.Errorf("bootstrap: ingest delta: %w", err)
+ }
+ }
+ } else if anchor == -1 && len(dbMsgs) > 0 {
+ // First message changed (history completely different) - rebuild from scratch
+ logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
+ "conv_id": conv.ConversationID,
+ "db_count": len(dbMsgs),
+ "msg_count": len(messages),
+ })
+ // Delete all existing messages
+ if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
+ return fmt.Errorf("bootstrap: delete all messages: %w", err)
+ }
+ // Re-ingest everything
+ if len(messages) > 0 {
+ _, err := e.Ingest(ctx, sessionKey, messages)
+ if err != nil {
+ return fmt.Errorf("bootstrap: re-ingest all: %w", err)
+ }
+ }
+ } else if anchor == -1 && len(dbMsgs) == 0 {
+ // DB is empty, ingest everything
+ _, err := e.Ingest(ctx, sessionKey, messages)
+ if err != nil {
+ return fmt.Errorf("bootstrap: ingest all: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// truncate shortens a string for logging.
+func truncate(s string, maxLen int) string {
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "..."
+}
+
+// messageMatches compares two messages using (role, content) or (role, parts).
+// TokenCount is NOT compared because it may be re-estimated differently
+// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
+// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
+// since AddMessageWithParts stores empty Content in DB.
+func messageMatches(a, b Message) bool {
+ if a.Role != b.Role {
+ return false
+ }
+ // If either message has Parts, compare Parts
+ if len(a.Parts) > 0 || len(b.Parts) > 0 {
+ return partsMatch(a.Parts, b.Parts)
+ }
+ // Simple text messages: compare Content
+ return a.Content == b.Content
+}
+
+// partsMatch compares two slices of MessagePart for equality.
+func partsMatch(a, b []MessagePart) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i].Type != b[i].Type {
+ return false
+ }
+ switch a[i].Type {
+ case "text":
+ if a[i].Text != b[i].Text {
+ return false
+ }
+ case "tool_use":
+ if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
+ return false
+ }
+ case "tool_result":
+ if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
+ return false
+ }
+ case "media":
+ if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
+ return false
+ }
+ }
+ }
+ return true
+}
diff --git a/pkg/seahorse/short_engine_test.go b/pkg/seahorse/short_engine_test.go
new file mode 100644
index 000000000..d64634fb7
--- /dev/null
+++ b/pkg/seahorse/short_engine_test.go
@@ -0,0 +1,1448 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+// helper: open a test engine with in-memory DB
+func newTestEngine(t *testing.T) *Engine {
+ t.Helper()
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+ store := &Store{db: db}
+ return &Engine{
+ store: store,
+ config: Config{},
+ }
+}
+
+// --- compileSessionPattern ---
+
+func TestCompileSessionPattern(t *testing.T) {
+ tests := []struct {
+ pattern string
+ input string
+ want bool
+ }{
+ // Exact match
+ {"agent:abc123", "agent:abc123", true},
+ {"agent:abc123", "agent:def456", false},
+ // Single * — matches non-colon chars
+ {"agent:*", "agent:abc123", true},
+ {"agent:*", "agent:abc:def", false}, // * doesn't match colons
+ // ** — matches everything including colons
+ {"cron:**", "cron:backup", true},
+ {"cron:**", "cron:backup:daily", true},
+ {"cron:**", "agent:abc", false},
+ // Mixed
+ {"agent:*:sub:**", "agent:abc:sub:def", true},
+ {"agent:*:sub:**", "agent:abc:sub:def:ghi", true},
+ {"agent:*:sub:**", "agent:abc:def", false},
+ // Empty pattern — matches nothing meaningful
+ {"", "", true},
+ {"", "agent:abc", false},
+ }
+
+ for _, tt := range tests {
+ re := compileSessionPattern(tt.pattern)
+ if re == nil && tt.pattern != "" {
+ t.Fatalf("compileSessionPattern(%q) returned nil", tt.pattern)
+ }
+ if tt.pattern == "" {
+ continue
+ }
+ got := re.MatchString(tt.input)
+ if got != tt.want {
+ t.Errorf("compileSessionPattern(%q).Match(%q) = %v, want %v", tt.pattern, tt.input, got, tt.want)
+ }
+ }
+}
+
+// --- Session Pattern Filtering ---
+
+func TestEngineShouldIgnoreSession(t *testing.T) {
+ eng := &Engine{
+ ignorePatterns: compileSessionPatterns([]string{"cron:**", "test:*"}),
+ }
+
+ tests := []struct {
+ key string
+ want bool
+ }{
+ {"cron:backup", true},
+ {"cron:backup:daily", true},
+ {"test:session", true},
+ {"agent:abc", false},
+ {"", false},
+ }
+
+ for _, tt := range tests {
+ got := eng.shouldIgnoreSession(tt.key)
+ if got != tt.want {
+ t.Errorf("shouldIgnoreSession(%q) = %v, want %v", tt.key, got, tt.want)
+ }
+ }
+}
+
+func TestEngineIsStatelessSession(t *testing.T) {
+ eng := &Engine{
+ statelessPatterns: compileSessionPatterns([]string{"agent:*:sub:**"}),
+ }
+
+ tests := []struct {
+ key string
+ want bool
+ }{
+ {"agent:abc:sub:def", true},
+ {"agent:abc:sub:def:ghi", true},
+ {"agent:abc", false},
+ {"cron:backup", false},
+ }
+
+ for _, tt := range tests {
+ got := eng.isStatelessSession(tt.key)
+ if got != tt.want {
+ t.Errorf("isStatelessSession(%q) = %v, want %v", tt.key, got, tt.want)
+ }
+ }
+}
+
+// --- NewEngine ---
+
+func TestNewEngine(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := filepath.Join(dir, "short.db")
+
+ eng, err := NewEngine(Config{DBPath: dbPath}, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer eng.Close()
+
+ // DB file should exist
+ if _, pathErr := os.Stat(dbPath); os.IsNotExist(pathErr) {
+ t.Error("expected DB file to be created")
+ }
+
+ // Store should be usable
+ ctx := context.Background()
+ conv, err := eng.store.GetOrCreateConversation(ctx, "test:session")
+ if err != nil {
+ t.Fatalf("store should work: %v", err)
+ }
+ if conv.ConversationID == 0 {
+ t.Error("expected valid conversation ID")
+ }
+
+ // GetRetrieval should return non-nil RetrievalEngine
+ retrieval := eng.GetRetrieval()
+ if retrieval == nil {
+ t.Error("expected GetRetrieval to return non-nil RetrievalEngine")
+ }
+}
+
+func TestNewEngineWithPatterns(t *testing.T) {
+ dir := t.TempDir()
+ dbPath := filepath.Join(dir, "short.db")
+
+ eng, err := NewEngine(Config{
+ DBPath: dbPath,
+ IgnoreSessionPatterns: []string{"cron:**"},
+ StatelessSessionPatterns: []string{"agent:*:sub:**"},
+ }, nil)
+ if err != nil {
+ t.Fatalf("NewEngine: %v", err)
+ }
+ defer eng.Close()
+
+ if !eng.shouldIgnoreSession("cron:backup") {
+ t.Error("expected cron:backup to be ignored")
+ }
+ if !eng.isStatelessSession("agent:abc:sub:def") {
+ t.Error("expected agent:abc:sub:def to be stateless")
+ }
+}
+
+// --- Ingest ---
+
+func TestEngineIngest(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ msgs := []Message{
+ {Role: "user", Content: "hello", TokenCount: 2},
+ {Role: "assistant", Content: "world", TokenCount: 2},
+ }
+
+ result, err := eng.Ingest(ctx, "agent:test", msgs)
+ if err != nil {
+ t.Fatalf("Ingest: %v", err)
+ }
+ if result.MessageCount != 2 {
+ t.Errorf("MessageCount = %d, want 2", result.MessageCount)
+ }
+ if result.TokenCount != 4 {
+ t.Errorf("TokenCount = %d, want 4", result.TokenCount)
+ }
+
+ // Verify messages were stored
+ conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:test")
+ stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(stored) != 2 {
+ t.Fatalf("stored messages = %d, want 2", len(stored))
+ }
+ if stored[0].Content != "hello" {
+ t.Errorf("stored[0].Content = %q, want 'hello'", stored[0].Content)
+ }
+
+ // Verify context_items were populated
+ items, _ := eng.store.GetContextItems(ctx, conv.ConversationID)
+ if len(items) != 2 {
+ t.Fatalf("context items = %d, want 2", len(items))
+ }
+ if items[0].ItemType != "message" {
+ t.Errorf("item[0].ItemType = %q, want 'message'", items[0].ItemType)
+ }
+}
+
+func TestEngineIngestIgnoresSession(t *testing.T) {
+ eng := newTestEngine(t)
+ eng.ignorePatterns = compileSessionPatterns([]string{"cron:**"})
+ ctx := context.Background()
+
+ msgs := []Message{{Role: "user", Content: "hello", TokenCount: 2}}
+ result, err := eng.Ingest(ctx, "cron:backup", msgs)
+ if err != nil {
+ t.Fatalf("Ingest: %v", err)
+ }
+ if result != nil {
+ t.Error("expected nil result for ignored session")
+ }
+
+ // Verify no data was stored
+ conv, _ := eng.store.GetConversationBySessionKey(ctx, "cron:backup")
+ if conv != nil {
+ t.Error("expected no conversation for ignored session")
+ }
+}
+
+func TestEngineIngestStatelessSession(t *testing.T) {
+ eng := newTestEngine(t)
+ eng.statelessPatterns = compileSessionPatterns([]string{"agent:*:ro"})
+ ctx := context.Background()
+
+ msgs := []Message{{Role: "user", Content: "hello", TokenCount: 2}}
+ result, err := eng.Ingest(ctx, "agent:abc:ro", msgs)
+ if err != nil {
+ t.Fatalf("Ingest: %v", err)
+ }
+ if result != nil {
+ t.Error("expected nil result for stateless session")
+ }
+}
+
+func TestEngineIngestIncremental(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ // First ingest
+ eng.Ingest(ctx, "agent:test", []Message{
+ {Role: "user", Content: "msg1", TokenCount: 1},
+ })
+ // Second ingest — should append, not replace
+ eng.Ingest(ctx, "agent:test", []Message{
+ {Role: "assistant", Content: "msg2", TokenCount: 1},
+ })
+
+ conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:test")
+ stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(stored) != 2 {
+ t.Errorf("stored messages = %d, want 2", len(stored))
+ }
+}
+
+func TestEngineIngestWithParts(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ msgs := []Message{
+ {
+ Role: "assistant",
+ Content: "",
+ TokenCount: 10,
+ Parts: []MessagePart{
+ {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"},
+ {Type: "text", Text: "here is the file content"},
+ },
+ },
+ }
+
+ result, err := eng.Ingest(ctx, "agent:parts-test", msgs)
+ if err != nil {
+ t.Fatalf("Ingest with parts: %v", err)
+ }
+ if result.MessageCount != 1 {
+ t.Errorf("MessageCount = %d, want 1", result.MessageCount)
+ }
+
+ // Verify message was stored WITH parts
+ conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:parts-test")
+ stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(stored) != 1 {
+ t.Fatalf("stored messages = %d, want 1", len(stored))
+ }
+ if len(stored[0].Parts) != 2 {
+ t.Fatalf("stored message parts = %d, want 2", len(stored[0].Parts))
+ }
+ if stored[0].Parts[0].Type != "tool_use" {
+ t.Errorf("part[0].Type = %q, want tool_use", stored[0].Parts[0].Type)
+ }
+ if stored[0].Parts[0].Name != "read_file" {
+ t.Errorf("part[0].Name = %q, want read_file", stored[0].Parts[0].Name)
+ }
+ if stored[0].Parts[0].ToolCallID != "tc_123" {
+ t.Errorf("part[0].ToolCallID = %q, want tc_123", stored[0].Parts[0].ToolCallID)
+ }
+ if stored[0].Parts[1].Type != "text" {
+ t.Errorf("part[1].Type = %q, want text", stored[0].Parts[1].Type)
+ }
+ if stored[0].Parts[1].Text != "here is the file content" {
+ t.Errorf("part[1].Text = %q, want 'here is the file content'", stored[0].Parts[1].Text)
+ }
+}
+
+func TestEngineIngestAssemblePreservesParts(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ // Ingest a message with tool_use parts
+ eng.Ingest(ctx, "agent:parts-roundtrip", []Message{
+ {Role: "user", Content: "list files", TokenCount: 3},
+ {
+ Role: "assistant",
+ Content: "",
+ TokenCount: 5,
+ Parts: []MessagePart{
+ {Type: "tool_use", Name: "bash", Arguments: `{"cmd":"ls"}`, ToolCallID: "tc_1"},
+ {Type: "text", Text: "found 3 files"},
+ },
+ },
+ })
+
+ // Assemble should return messages with parts intact
+ result, err := eng.Assemble(ctx, "agent:parts-roundtrip", AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ if len(result.Messages) != 2 {
+ t.Fatalf("Assemble returned %d messages, want 2", len(result.Messages))
+ }
+
+ // The second message should have Parts populated
+ assistantMsg := result.Messages[1]
+ if len(assistantMsg.Parts) != 2 {
+ t.Fatalf("Assembled assistant message Parts = %d, want 2", len(assistantMsg.Parts))
+ }
+ if assistantMsg.Parts[0].Type != "tool_use" {
+ t.Errorf("part[0].Type = %q, want tool_use", assistantMsg.Parts[0].Type)
+ }
+ if assistantMsg.Parts[0].ToolCallID != "tc_1" {
+ t.Errorf("part[0].ToolCallID = %q, want tc_1", assistantMsg.Parts[0].ToolCallID)
+ }
+}
+
+// --- Session Mutex ---
+
+func TestEngineSessionMutex(t *testing.T) {
+ eng := newTestEngine(t)
+
+ mu1 := eng.getSessionMutex("agent:test")
+ mu2 := eng.getSessionMutex("agent:test")
+ mu3 := eng.getSessionMutex("agent:other")
+
+ if mu1 != mu2 {
+ t.Error("expected same mutex for same session key")
+ }
+ if mu1 == mu3 {
+ t.Error("expected different mutex for different session key")
+ }
+}
+
+// --- Close ---
+
+func TestEngineClose(t *testing.T) {
+ eng := newTestEngine(t)
+ if err := eng.Close(); err != nil {
+ t.Errorf("Close: %v", err)
+ }
+}
+
+// --- compileSessionPatterns (batch) ---
+
+func TestCompileSessionPatterns(t *testing.T) {
+ patterns := compileSessionPatterns([]string{"cron:**", "agent:*:ro"})
+ if len(patterns) != 2 {
+ t.Fatalf("expected 2 patterns, got %d", len(patterns))
+ }
+
+ tests := []struct {
+ input string
+ want bool
+ }{
+ {"cron:backup", true},
+ {"agent:abc:ro", true},
+ {"agent:abc:def", false},
+ {"", false},
+ }
+
+ for _, tt := range tests {
+ matched := false
+ for _, p := range patterns {
+ if p.MatchString(tt.input) {
+ matched = true
+ break
+ }
+ }
+ if matched != tt.want {
+ t.Errorf("patterns.Match(%q) = %v, want %v", tt.input, matched, tt.want)
+ }
+ }
+}
+
+func TestCompileSessionPatternsEmpty(t *testing.T) {
+ patterns := compileSessionPatterns(nil)
+ if len(patterns) != 0 {
+ t.Errorf("expected 0 patterns for nil input, got %d", len(patterns))
+ }
+}
+
+// --- Bootstrap ---
+
+func TestEngineBootstrap(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ msgs := []Message{
+ {Role: "user", Content: "hello", TokenCount: 3},
+ {Role: "assistant", Content: "world", TokenCount: 3},
+ {Role: "user", Content: "how are you", TokenCount: 5},
+ }
+
+ err := eng.Bootstrap(ctx, "agent:boot1", msgs)
+ if err != nil {
+ t.Fatalf("Bootstrap: %v", err)
+ }
+
+ // Verify conversation was created
+ conv, err := eng.store.GetConversationBySessionKey(ctx, "agent:boot1")
+ if err != nil {
+ t.Fatalf("GetConversation: %v", err)
+ }
+ if conv == nil {
+ t.Fatal("expected conversation to exist after bootstrap")
+ }
+
+ // Verify messages were stored
+ stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if err != nil {
+ t.Fatalf("GetMessages: %v", err)
+ }
+ if len(stored) != 3 {
+ t.Fatalf("expected 3 stored messages, got %d", len(stored))
+ }
+ if stored[0].Content != "hello" {
+ t.Errorf("stored[0].Content = %q, want 'hello'", stored[0].Content)
+ }
+
+ // Verify context_items were populated
+ items, err := eng.store.GetContextItems(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetContextItems: %v", err)
+ }
+ if len(items) != 3 {
+ t.Fatalf("expected 3 context items, got %d", len(items))
+ }
+}
+
+func TestEngineBootstrapEmpty(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ err := eng.Bootstrap(ctx, "agent:empty", nil)
+ if err != nil {
+ t.Fatalf("Bootstrap empty: %v", err)
+ }
+
+ // No conversation should be created for empty messages
+ conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:empty")
+ if conv != nil {
+ t.Error("expected no conversation for empty bootstrap")
+ }
+}
+
+func TestEngineBootstrapIdempotent(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ msgs := []Message{
+ {Role: "user", Content: "hello", TokenCount: 3},
+ {Role: "assistant", Content: "world", TokenCount: 3},
+ }
+
+ // Bootstrap twice with same messages
+ eng.Bootstrap(ctx, "agent:idem", msgs)
+ eng.Bootstrap(ctx, "agent:idem", msgs)
+
+ // Should still have exactly 2 messages (no duplicates)
+ conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:idem")
+ if conv == nil {
+ t.Fatal("expected conversation")
+ }
+ stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(stored) != 2 {
+ t.Errorf("expected 2 messages (idempotent), got %d", len(stored))
+ }
+}
+
+func TestEngineBootstrapDelta(t *testing.T) {
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ // First bootstrap with 2 messages
+ msgs1 := []Message{
+ {Role: "user", Content: "hello", TokenCount: 3},
+ {Role: "assistant", Content: "world", TokenCount: 3},
+ }
+ eng.Bootstrap(ctx, "agent:delta", msgs1)
+
+ // Second bootstrap with 4 messages (2 existing + 2 new)
+ msgs2 := []Message{
+ {Role: "user", Content: "hello", TokenCount: 3},
+ {Role: "assistant", Content: "world", TokenCount: 3},
+ {Role: "user", Content: "new question", TokenCount: 5},
+ {Role: "assistant", Content: "new answer", TokenCount: 5},
+ }
+ eng.Bootstrap(ctx, "agent:delta", msgs2)
+
+ conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:delta")
+ if conv == nil {
+ t.Fatal("expected conversation")
+ }
+ stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(stored) != 4 {
+ t.Errorf("expected 4 messages (delta), got %d", len(stored))
+ }
+}
+
+func TestBootstrapPopulatesContextItems(t *testing.T) {
+ // Bootstrap ingests messages and populates context_items
+ e := newTestEngine(t)
+ ctx := context.Background()
+
+ messages := []Message{
+ {Role: "user", Content: "hello from bootstrap test", TokenCount: 10},
+ {Role: "assistant", Content: "hi there", TokenCount: 5},
+ {Role: "user", Content: "how are you", TokenCount: 5},
+ {Role: "assistant", Content: "doing well", TokenCount: 5},
+ {Role: "user", Content: "great news", TokenCount: 5},
+ {Role: "assistant", Content: "awesome", TokenCount: 5},
+ {Role: "user", Content: "lets code", TokenCount: 5},
+ {Role: "assistant", Content: "sure thing", TokenCount: 5},
+ }
+
+ // Bootstrap should ingest and rebuild context_items
+ err := e.Bootstrap(ctx, "test-bootstrap-rebuild", messages)
+ if err != nil {
+ t.Fatalf("Bootstrap: %v", err)
+ }
+
+ // After bootstrap, context_items should be populated
+ conv, _ := e.store.GetOrCreateConversation(ctx, "test-bootstrap-rebuild")
+ items, err := e.store.GetContextItems(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetContextItems: %v", err)
+ }
+
+ if len(items) == 0 {
+ t.Error("expected context_items to be populated after Bootstrap, got 0 items")
+ }
+
+ // Should have one item per message
+ if len(items) != len(messages) {
+ t.Errorf("expected %d context items, got %d", len(messages), len(items))
+ }
+}
+
+func TestBootstrapDeltaPreservesOrder(t *testing.T) {
+ // When Bootstrap does delta ingest, context_items should maintain
+ // correct order with new messages appended after anchor.
+ e := newTestEngine(t)
+ ctx := context.Background()
+ sessionKey := "test-bootstrap-delta-order"
+
+ // First: bootstrap with 4 messages
+ initialMsgs := []Message{
+ {Role: "user", Content: "msg1", TokenCount: 5},
+ {Role: "assistant", Content: "msg2", TokenCount: 5},
+ {Role: "user", Content: "msg3", TokenCount: 5},
+ {Role: "assistant", Content: "msg4", TokenCount: 5},
+ }
+ err := e.Bootstrap(ctx, sessionKey, initialMsgs)
+ if err != nil {
+ t.Fatalf("first Bootstrap: %v", err)
+ }
+
+ conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey)
+ items1, _ := e.store.GetContextItems(ctx, conv.ConversationID)
+ if len(items1) != 4 {
+ t.Fatalf("after first bootstrap: expected 4 items, got %d", len(items1))
+ }
+
+ // Now bootstrap again with 6 messages (4 existing + 2 new)
+ // The delta (msg5, msg6) should be appended
+ updatedMsgs := []Message{
+ {Role: "user", Content: "msg1", TokenCount: 5},
+ {Role: "assistant", Content: "msg2", TokenCount: 5},
+ {Role: "user", Content: "msg3", TokenCount: 5},
+ {Role: "assistant", Content: "msg4", TokenCount: 5},
+ {Role: "user", Content: "msg5", TokenCount: 5},
+ {Role: "assistant", Content: "msg6", TokenCount: 5},
+ }
+ err = e.Bootstrap(ctx, sessionKey, updatedMsgs)
+ if err != nil {
+ t.Fatalf("second Bootstrap: %v", err)
+ }
+
+ items2, _ := e.store.GetContextItems(ctx, conv.ConversationID)
+ if len(items2) != 6 {
+ t.Errorf("after delta bootstrap: expected 6 items, got %d", len(items2))
+ }
+}
+
+func TestBootstrapHistoryEditFirstMessageChanged(t *testing.T) {
+ // When the first message changes (anchor = -1), Bootstrap should rebuild
+ // from scratch without panicking (regression test for index out of range [-1])
+ e := newTestEngine(t)
+ ctx := context.Background()
+ sessionKey := "test-bootstrap-history-edit"
+
+ // First: bootstrap with some messages
+ initialMsgs := []Message{
+ {Role: "user", Content: "original first", TokenCount: 5},
+ {Role: "assistant", Content: "response", TokenCount: 5},
+ {Role: "user", Content: "question", TokenCount: 5},
+ }
+ err := e.Bootstrap(ctx, sessionKey, initialMsgs)
+ if err != nil {
+ t.Fatalf("first Bootstrap: %v", err)
+ }
+
+ // Now bootstrap with completely different messages (first message changed)
+ // This should NOT panic - it should rebuild from scratch
+ editedMsgs := []Message{
+ {Role: "user", Content: "DIFFERENT first message", TokenCount: 5},
+ {Role: "assistant", Content: "DIFFERENT response", TokenCount: 5},
+ {Role: "user", Content: "DIFFERENT question", TokenCount: 5},
+ }
+ err = e.Bootstrap(ctx, sessionKey, editedMsgs)
+ if err != nil {
+ t.Fatalf("second Bootstrap (history edit): %v", err)
+ }
+
+ conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey)
+ stored, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+
+ // Should have the NEW messages (history was rebuilt)
+ if len(stored) != 3 {
+ t.Errorf("expected 3 messages after history edit, got %d", len(stored))
+ }
+ if len(stored) > 0 && stored[0].Content != "DIFFERENT first message" {
+ t.Errorf("first message = %q, want 'DIFFERENT first message'", stored[0].Content)
+ }
+}
+
+func TestBootstrapSameContentDifferentTokenCountNoRebuild(t *testing.T) {
+ // Bootstrap should NOT rebuild when content is identical but TokenCount differs.
+ // This happens when TokenCount is re-estimated (e.g., via tokenizer.EstimateMessageTokens)
+ // during bootstrap, which may give slightly different values.
+ e := newTestEngine(t)
+ ctx := context.Background()
+ sessionKey := "test-bootstrap-token-diff"
+
+ // First: bootstrap with some messages
+ initialMsgs := []Message{
+ {Role: "user", Content: "hello world", TokenCount: 10},
+ {Role: "assistant", Content: "hi there", TokenCount: 5},
+ }
+ err := e.Bootstrap(ctx, sessionKey, initialMsgs)
+ if err != nil {
+ t.Fatalf("first Bootstrap: %v", err)
+ }
+
+ conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey)
+ storedBefore, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+
+ // Second: bootstrap with SAME content but DIFFERENT TokenCount
+ // This should be a no-op (not rebuild)
+ sameContentMsgs := []Message{
+ {Role: "user", Content: "hello world", TokenCount: 999}, // Different token count!
+ {Role: "assistant", Content: "hi there", TokenCount: 888}, // Different token count!
+ }
+ err = e.Bootstrap(ctx, sessionKey, sameContentMsgs)
+ if err != nil {
+ t.Fatalf("second Bootstrap: %v", err)
+ }
+
+ storedAfter, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0)
+
+ // Should have same number of messages (no rebuild)
+ if len(storedAfter) != len(storedBefore) {
+ t.Errorf("expected %d messages (no rebuild), got %d", len(storedBefore), len(storedAfter))
+ }
+
+ // Message IDs should be the same (no delete+re-ingest)
+ for i := range storedBefore {
+ if storedBefore[i].ID != storedAfter[i].ID {
+ t.Errorf("message %d ID changed: before=%d, after=%d (should be no-op)",
+ i, storedBefore[i].ID, storedAfter[i].ID)
+ }
+ }
+}
+
+// --- Session Mutex ---
+
+func TestEngineSessionMutexSharded(t *testing.T) {
+ eng := newTestEngine(t)
+
+ // Same session key should always return the same mutex (deterministic hash)
+ mu1 := eng.getSessionMutex("agent:test")
+ mu2 := eng.getSessionMutex("agent:test")
+ if mu1 != mu2 {
+ t.Error("expected same mutex for same session key")
+ }
+
+ // Different session keys may share the same shard (hash collision)
+ // This is expected behavior - we just need bounded memory, not unique locks
+ mu3 := eng.getSessionMutex("agent:other")
+
+ // Both mutexes should be valid and usable
+ mu1.Lock()
+ mu1.Unlock()
+ mu3.Lock()
+ mu3.Unlock()
+}
+
+func TestEngineSessionMutexBoundedMemory(t *testing.T) {
+ // Verify that session mutexes use bounded memory (256 shards)
+ eng := newTestEngine(t)
+
+ // Get mutexes for many different sessions
+ seen := make(map[*sync.Mutex]bool)
+ for i := 0; i < 1000; i++ {
+ sessionKey := fmt.Sprintf("agent:session-%d", i)
+ mu := eng.getSessionMutex(sessionKey)
+ seen[mu] = true
+ }
+
+ // With 256 shards and 1000 sessions, we should see at most 256 unique mutexes
+ // (likely fewer due to hash collisions)
+ if len(seen) > 256 {
+ t.Errorf("expected at most 256 unique mutexes (shards), got %d", len(seen))
+ }
+}
+
+func TestEngineSessionMutexConsistentHash(t *testing.T) {
+ // Same session key should always hash to the same shard
+ eng := newTestEngine(t)
+
+ sessionKey := "agent:consistent-hash-test"
+ mu1 := eng.getSessionMutex(sessionKey)
+ mu2 := eng.getSessionMutex(sessionKey)
+ mu3 := eng.getSessionMutex(sessionKey)
+
+ if mu1 != mu2 || mu2 != mu3 {
+ t.Error("hash function should be deterministic - same key must map to same shard")
+ }
+}
+
+// --- Summary Role ---
+
+func TestAssemblerSummaryRoleNotUser(t *testing.T) {
+ // Summaries should use "system" role, not "user"
+ eng := newTestEngine(t)
+ ctx := context.Background()
+
+ // Ingest messages
+ eng.Ingest(ctx, "agent:summary-role-test", []Message{
+ {Role: "user", Content: "hello", TokenCount: 5},
+ {Role: "assistant", Content: "world", TokenCount: 5},
+ })
+
+ conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:summary-role-test")
+
+ // Create a summary and add it to context
+ sum, err := eng.store.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Content: "Test summary content",
+ TokenCount: 10,
+ Kind: SummaryKindCondensed,
+ Depth: 1,
+ })
+ if err != nil {
+ t.Fatalf("CreateSummary: %v", err)
+ }
+ eng.store.AppendContextSummary(ctx, conv.ConversationID, sum.SummaryID)
+
+ // Assemble and check summary message role
+ result, err := eng.Assemble(ctx, "agent:summary-role-test", AssembleInput{Budget: 1000})
+ if err != nil {
+ t.Fatalf("Assemble: %v", err)
+ }
+
+ // Find the summary message (should have XML content with )
+ for _, msg := range result.Messages {
+ if strings.Contains(msg.Content, "= 5
+ // This tests the bug: when depth=2 is missing, the loop breaks and depth=3 is never checked
+ // Need > FreshTailCount(32) summaries so they are not all in fresh tail
+ // Depth 0: 3 summaries (not enough), Depth 1: 3 summaries (not enough)
+ // Depth 2: 0 summaries (missing), Depth 3: 40 summaries (enough)
+ depths := []int{0, 0, 0, 1, 1, 1}
+ for i := 0; i < 40; i++ {
+ depths = append(depths, 3)
+ }
+ now := time.Now().UTC()
+
+ for i, depth := range depths {
+ sum, createErr := e.store.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: depth,
+ Content: fmt.Sprintf("summary depth %d #%d", depth, i),
+ TokenCount: 10,
+ EarliestAt: &now,
+ LatestAt: &now,
+ })
+ if createErr != nil {
+ t.Fatalf("CreateSummary: %v", createErr)
+ }
+ // Add to context items (not in fresh tail)
+ if appendErr := e.store.AppendContextSummary(ctx, conv.ConversationID, sum.SummaryID); appendErr != nil {
+ t.Fatalf("AppendContextSummary: %v", appendErr)
+ }
+ }
+
+ // Initialize compaction engine (lazy init)
+ e.initCompactionOnce()
+
+ // Call selectShallowestCondensationCandidate
+ candidates, err := e.compaction.selectShallowestCondensationCandidate(ctx, conv.ConversationID, false)
+ if err != nil {
+ t.Fatalf("selectShallowestCondensationCandidate: %v", err)
+ }
+
+ // Should find depth=0 (shallowest) with 5 summaries
+ if candidates == nil {
+ t.Fatal("expected candidates, got nil")
+ }
+ if len(candidates) < CondensedMinFanout {
+ t.Errorf("expected at least %d candidates, got %d", CondensedMinFanout, len(candidates))
+ }
+
+ // Verify all returned summaries have the same depth
+ if len(candidates) > 0 {
+ expectedDepth := candidates[0].Depth
+ for _, c := range candidates[1:] {
+ if c.Depth != expectedDepth {
+ t.Errorf("candidates have mixed depths: %d vs %d", expectedDepth, c.Depth)
+ }
+ }
+ }
+}
diff --git a/pkg/seahorse/short_retrieval.go b/pkg/seahorse/short_retrieval.go
new file mode 100644
index 000000000..3e94eec14
--- /dev/null
+++ b/pkg/seahorse/short_retrieval.go
@@ -0,0 +1,212 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
+// Returns the duration and nil error, or zero and error if invalid.
+func ParseLastDuration(s string) (time.Duration, error) {
+ if s == "" {
+ return 0, fmt.Errorf("empty duration")
+ }
+
+ re := regexp.MustCompile(`^(\d+)([hdwm])$`)
+ matches := re.FindStringSubmatch(s)
+ if matches == nil {
+ return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
+ }
+
+ value, _ := strconv.Atoi(matches[1])
+ unit := matches[2]
+
+ switch unit {
+ case "h":
+ return time.Duration(value) * time.Hour, nil
+ case "d":
+ return time.Duration(value) * 24 * time.Hour, nil
+ case "w":
+ return time.Duration(value) * 7 * 24 * time.Hour, nil
+ case "m":
+ return time.Duration(value) * 30 * 24 * time.Hour, nil
+ default:
+ return 0, fmt.Errorf("unknown unit: %q", unit)
+ }
+}
+
+// GrepInput controls search across summaries and messages.
+type GrepInput struct {
+ Pattern string `json:"pattern"`
+ Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
+ Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
+ AllConversations bool `json:"allConversations,omitempty"`
+ Since *time.Time `json:"since,omitempty"`
+ Before *time.Time `json:"before,omitempty"`
+ Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
+ Limit int `json:"limit,omitempty"`
+}
+
+// GrepResult contains search results.
+type GrepResult struct {
+ Success bool `json:"success"`
+ Summaries []GrepSummaryResult `json:"summaries"`
+ Messages []GrepMessageResult `json:"messages"`
+ TotalSummaries int `json:"totalSummaries"`
+ TotalMessages int `json:"totalMessages"`
+ Hint string `json:"hint,omitempty"`
+}
+
+// GrepSummaryResult is a summary match from grep.
+type GrepSummaryResult struct {
+ ID string `json:"id"`
+ Content string `json:"content"`
+ Depth int `json:"depth"`
+ Kind SummaryKind `json:"kind"`
+ ConversationID int64 `json:"conversationId"`
+ // Rank is the bm25 relevance score (negative value, lower = better match).
+ // Examples: -5.0 = excellent match, -2.0 = good match, -0.5 = partial match.
+ Rank float64 `json:"rank,omitempty"`
+}
+
+// GrepMessageResult is a message match from grep.
+type GrepMessageResult struct {
+ ID int64 `json:"id,string"`
+ Snippet string `json:"snippet"`
+ Role string `json:"role"`
+ ConversationID int64 `json:"conversationId"`
+ Rank float64 `json:"rank,omitempty"` // Relevance score (more negative = better match)
+}
+
+// ExpandMessagesResult contains expanded messages.
+type ExpandMessagesResult struct {
+ Messages []Message `json:"messages"`
+ TokenCount int `json:"tokenCount"`
+}
+
+// Grep searches summaries and messages for matching content.
+func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
+ if input.Pattern == "" {
+ return nil, fmt.Errorf("grep: pattern is required")
+ }
+
+ limit := input.Limit
+ if limit == 0 {
+ limit = 20
+ }
+
+ // Handle Last parameter: convert to Since
+ since := input.Since
+ if input.Last != "" {
+ dur, err := ParseLastDuration(input.Last)
+ if err != nil {
+ return nil, fmt.Errorf("grep: invalid last: %w", err)
+ }
+ t := time.Now().UTC().Add(-dur)
+ since = &t
+ }
+
+ // Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
+ mode := ""
+ if strings.Contains(input.Pattern, "%") {
+ mode = "like"
+ }
+
+ searchInput := SearchInput{
+ Pattern: input.Pattern,
+ Mode: mode,
+ Role: input.Role,
+ AllConversations: input.AllConversations,
+ Since: since,
+ Before: input.Before,
+ Limit: limit,
+ }
+
+ result := &GrepResult{
+ Success: true,
+ Summaries: make([]GrepSummaryResult, 0),
+ Messages: make([]GrepMessageResult, 0),
+ TotalSummaries: 0,
+ TotalMessages: 0,
+ }
+
+ // Determine scope
+ scope := input.Scope
+ if scope == "" {
+ scope = "both"
+ }
+
+ // Search summaries if requested
+ if scope == "both" || scope == "summary" {
+ sumResults, err := r.store.SearchSummaries(ctx, searchInput)
+ if err != nil {
+ return nil, fmt.Errorf("search summaries: %w", err)
+ }
+ for _, sr := range sumResults {
+ if sr.SummaryID != "" {
+ result.Summaries = append(result.Summaries, GrepSummaryResult{
+ ID: sr.SummaryID,
+ Content: sr.Content,
+ Depth: sr.Depth,
+ Kind: sr.Kind,
+ ConversationID: sr.ConversationID,
+ Rank: sr.Rank,
+ })
+ }
+ }
+ if len(sumResults) > 0 {
+ result.TotalSummaries = sumResults[0].TotalCount
+ }
+ }
+
+ // Search messages if requested
+ if scope == "both" || scope == "message" {
+ msgResults, err := r.store.SearchMessages(ctx, searchInput)
+ if err != nil {
+ return nil, fmt.Errorf("search messages: %w", err)
+ }
+ for _, sr := range msgResults {
+ if sr.MessageID > 0 {
+ result.Messages = append(result.Messages, GrepMessageResult{
+ ID: sr.MessageID,
+ Snippet: sr.Snippet,
+ Role: sr.Role,
+ ConversationID: sr.ConversationID,
+ Rank: sr.Rank,
+ })
+ }
+ }
+ if len(msgResults) > 0 {
+ result.TotalMessages = msgResults[0].TotalCount
+ }
+ }
+
+ // Add hint if no results
+ if len(result.Summaries) == 0 && len(result.Messages) == 0 {
+ result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
+ }
+
+ return result, nil
+}
+
+// ExpandMessages retrieves full message content by IDs.
+func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
+ result := &ExpandMessagesResult{
+ Messages: make([]Message, 0, len(messageIDs)),
+ }
+
+ for _, msgID := range messageIDs {
+ msg, err := r.store.GetMessageByID(ctx, msgID)
+ if err != nil {
+ continue
+ }
+ result.Messages = append(result.Messages, *msg)
+ result.TokenCount += msg.TokenCount
+ }
+
+ return result, nil
+}
diff --git a/pkg/seahorse/short_retrieval_test.go b/pkg/seahorse/short_retrieval_test.go
new file mode 100644
index 000000000..9d9bc3640
--- /dev/null
+++ b/pkg/seahorse/short_retrieval_test.go
@@ -0,0 +1,362 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+)
+
+// --- Retrieval Tests ---
+
+func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
+ t.Helper()
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
+ return &RetrievalEngine{store: s}, s, conv.ConversationID
+}
+
+func TestRetrievalGrepSummaries(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "数据库连接配置说明",
+ TokenCount: 50,
+ })
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "API endpoint documentation",
+ TokenCount: 50,
+ })
+
+ // FTS5 search (trigram, needs >= 3 chars)
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "数据库连",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(results.Summaries) == 0 {
+ t.Error("expected at least 1 FTS result")
+ }
+
+ // LIKE search with wildcard
+ results, err = r.Grep(ctx, GrepInput{
+ Pattern: "%endpoint%",
+ })
+ if err != nil {
+ t.Fatalf("Grep LIKE: %v", err)
+ }
+ if len(results.Summaries) == 0 {
+ t.Error("expected at least 1 LIKE result")
+ }
+}
+
+func TestRetrievalGrepMessages(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
+ s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
+
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "testing",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(results.Messages) == 0 {
+ t.Error("expected at least 1 result for 'testing'")
+ }
+}
+
+func TestRetrievalExpandMessages(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
+
+ result, err := r.ExpandMessages(ctx, []int64{msg.ID})
+ if err != nil {
+ t.Fatalf("ExpandMessages: %v", err)
+ }
+ if len(result.Messages) != 1 {
+ t.Errorf("Messages = %d, want 1", len(result.Messages))
+ }
+ if result.Messages[0].Content != "expand this message" {
+ t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
+ }
+}
+
+func TestRetrievalExpandMultipleMessages(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
+ msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
+ msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
+
+ result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
+ if err != nil {
+ t.Fatalf("ExpandMessages: %v", err)
+ }
+ if len(result.Messages) != 3 {
+ t.Errorf("Messages = %d, want 3", len(result.Messages))
+ }
+ if result.TokenCount != 30 {
+ t.Errorf("TokenCount = %d, want 30", result.TokenCount)
+ }
+}
+
+func TestRetrievalGrepWithTimeFilter(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ now := time.Now().UTC()
+ before := now.Add(-2 * time.Hour)
+
+ // Create messages at different times
+ s.AddMessage(ctx, convID, "user", "old message about auth", 5)
+ s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
+
+ // Search with time filter
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "auth",
+ Since: &before,
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ _ = results // Just verify no error
+}
+
+func TestRetrievalGrepAllConversations(t *testing.T) {
+ r, s, _ := newTestRetrieval(t)
+ ctx := context.Background()
+
+ // Create another conversation
+ conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
+
+ // Add messages to both
+ s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
+
+ // Search all conversations
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "xyz",
+ AllConversations: true,
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(results.Messages) == 0 {
+ t.Error("expected to find message in other conversation")
+ }
+}
+
+// --- Last Duration Parsing Tests ---
+
+func TestParseLastDuration(t *testing.T) {
+ tests := []struct {
+ input string
+ wantDur time.Duration
+ wantErr bool
+ }{
+ {"6h", 6 * time.Hour, false},
+ {"1d", 24 * time.Hour, false},
+ {"7d", 7 * 24 * time.Hour, false},
+ {"2w", 14 * 24 * time.Hour, false},
+ {"1m", 30 * 24 * time.Hour, false}, // month = 30 days
+ {"3m", 90 * 24 * time.Hour, false},
+ {"", 0, true},
+ {"invalid", 0, true},
+ {"5x", 0, true}, // unknown unit
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ got, err := ParseLastDuration(tt.input)
+ if tt.wantErr {
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != tt.wantDur {
+ t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
+ }
+ }
+ })
+ }
+}
+
+// --- Role Filter Tests ---
+
+func TestRetrievalGrepRoleFilter(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
+ s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
+ s.AddMessage(ctx, convID, "user", "another user message", 5)
+
+ // Search all roles
+ allResults, err := r.Grep(ctx, GrepInput{
+ Pattern: "alpha",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(allResults.Messages) != 2 {
+ t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
+ }
+
+ // Search user only
+ userResults, err := r.Grep(ctx, GrepInput{
+ Pattern: "alpha",
+ Role: "user",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(userResults.Messages) != 1 {
+ t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
+ }
+ if userResults.Messages[0].Role != "user" {
+ t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
+ }
+
+ // Search assistant only
+ assistantResults, err := r.Grep(ctx, GrepInput{
+ Pattern: "alpha",
+ Role: "assistant",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(assistantResults.Messages) != 1 {
+ t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
+ }
+}
+
+// --- Last Parameter Tests ---
+
+func TestRetrievalGrepWithLast(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ // Add messages (we can't control timestamps in SQLite easily,
+ // but we can verify the parameter is parsed correctly)
+ s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
+
+ // Test that Last parameter is converted to Since
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "testing",
+ Last: "1d", // last 1 day
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ // Should still find the message since it's recent
+ if len(results.Messages) == 0 {
+ t.Error("expected to find recent message")
+ }
+}
+
+// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
+// searching both summaries and messages (summaries don't have role column).
+func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ // Create a summary (no role column)
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "summary about testing",
+ TokenCount: 50,
+ })
+
+ // Add messages with different roles
+ s.AddMessage(ctx, convID, "user", "user message about testing", 5)
+ s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
+
+ // Search with role filter and scope=both (default), using LIKE mode (%)
+ // This should NOT error even though summaries don't have role column
+ bothResults, err := r.Grep(ctx, GrepInput{
+ Pattern: "%testing%", // LIKE mode to trigger the bug
+ Role: "user",
+ Scope: "both",
+ })
+ if err != nil {
+ t.Fatalf("Grep with role and scope=both: %v", err)
+ }
+
+ // Should only return user messages, not summaries or assistant messages
+ if len(bothResults.Messages) != 1 {
+ t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
+ }
+ if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
+ t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
+ }
+
+ // Summaries should be empty since they don't have roles to filter
+ // (or we could return all summaries - either is acceptable)
+}
+
+// TestRetrievalGrepTotalCounts tests that grep returns total counts.
+func TestRetrievalGrepTotalCounts(t *testing.T) {
+ r, s, convID := newTestRetrieval(t)
+ ctx := context.Background()
+
+ // Create 3 summaries
+ for i := 0; i < 3; i++ {
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: convID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("summary about testing %d", i),
+ TokenCount: 50,
+ })
+ }
+
+ // Add 5 messages
+ for i := 0; i < 5; i++ {
+ s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
+ }
+
+ // Search with limit smaller than total
+ results, err := r.Grep(ctx, GrepInput{
+ Pattern: "%testing%", // LIKE mode
+ Scope: "both",
+ Limit: 2,
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+
+ // Should return limited results
+ if len(results.Summaries) > 2 {
+ t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
+ }
+ if len(results.Messages) > 2 {
+ t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
+ }
+
+ // But total counts should reflect all matches
+ if results.TotalSummaries != 3 {
+ t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
+ }
+ if results.TotalMessages != 5 {
+ t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
+ }
+}
diff --git a/pkg/seahorse/store.go b/pkg/seahorse/store.go
new file mode 100644
index 000000000..3d85c7b9c
--- /dev/null
+++ b/pkg/seahorse/store.go
@@ -0,0 +1,1532 @@
+package seahorse
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+ "time"
+)
+
+// Store provides SQLite storage for seahorse.
+type Store struct {
+ db *sql.DB
+}
+
+// CreateSummaryInput holds parameters for creating a summary.
+type CreateSummaryInput struct {
+ ConversationID int64
+ Kind SummaryKind
+ Depth int
+ Content string
+ TokenCount int
+ EarliestAt *time.Time
+ LatestAt *time.Time
+ DescendantCount int
+ DescendantTokenCount int
+ SourceMessageTokens int
+ Model string
+ ParentIDs []string // For condensed: child summary IDs being condensed
+}
+
+// --- Conversation Operations ---
+
+// GetOrCreateConversation returns the conversation for a sessionKey, creating if needed.
+func (s *Store) GetOrCreateConversation(ctx context.Context, sessionKey string) (*Conversation, error) {
+ // Try to get first
+ conv, err := s.GetConversationBySessionKey(ctx, sessionKey)
+ if err != nil {
+ return nil, err
+ }
+ if conv != nil {
+ return conv, nil
+ }
+
+ // Create
+ result, err := s.db.ExecContext(ctx,
+ "INSERT INTO conversations (session_key) VALUES (?)",
+ sessionKey,
+ )
+ if err != nil {
+ // Race: another goroutine may have inserted
+ if isUniqueViolation(err) {
+ return s.GetConversationBySessionKey(ctx, sessionKey)
+ }
+ return nil, fmt.Errorf("create conversation: %w", err)
+ }
+ id, _ := result.LastInsertId()
+ return &Conversation{
+ ConversationID: id,
+ SessionKey: sessionKey,
+ }, nil
+}
+
+// GetConversationBySessionKey retrieves a conversation by session key.
+func (s *Store) GetConversationBySessionKey(ctx context.Context, sessionKey string) (*Conversation, error) {
+ var conv Conversation
+ var createdAt, updatedAt string
+ err := s.db.QueryRowContext(ctx,
+ "SELECT conversation_id, session_key, created_at, updated_at FROM conversations WHERE session_key = ?",
+ sessionKey,
+ ).Scan(&conv.ConversationID, &conv.SessionKey, &createdAt, &updatedAt)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, fmt.Errorf("get conversation by session key: %w", err)
+ }
+ conv.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ conv.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
+ return &conv, nil
+}
+
+// GetSessionStatus returns status for a specific session.
+func (s *Store) GetSessionStatus(ctx context.Context, sessionKey string) (*SessionStatus, error) {
+ conv, err := s.GetConversationBySessionKey(ctx, sessionKey)
+ if err != nil {
+ return nil, err
+ }
+ if conv == nil {
+ return nil, nil
+ }
+
+ msgCount, _ := s.GetMessageCount(ctx, conv.ConversationID)
+ sumCount, _ := s.getSummaryCount(ctx, conv.ConversationID)
+ tokenCount, _ := s.GetContextTokenCount(ctx, conv.ConversationID)
+
+ oldest, newest, _ := s.getMessageTimeRange(ctx, conv.ConversationID)
+
+ return &SessionStatus{
+ SessionKey: conv.SessionKey,
+ ConversationID: conv.ConversationID,
+ Messages: msgCount,
+ TotalTokens: tokenCount,
+ Summaries: sumCount,
+ OldestAt: oldest,
+ NewestAt: newest,
+ }, nil
+}
+
+// GetAllSessionStatuses returns status for all sessions.
+func (s *Store) GetAllSessionStatuses(ctx context.Context) ([]SessionStatus, error) {
+ rows, err := s.db.QueryContext(ctx, "SELECT session_key FROM conversations")
+ if err != nil {
+ return nil, fmt.Errorf("list sessions: %w", err)
+ }
+ defer rows.Close()
+
+ var statuses []SessionStatus
+ for rows.Next() {
+ var sessionKey string
+ if err := rows.Scan(&sessionKey); err != nil {
+ continue
+ }
+ status, err := s.GetSessionStatus(ctx, sessionKey)
+ if err != nil {
+ continue
+ }
+ if status != nil {
+ statuses = append(statuses, *status)
+ }
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate sessions: %w", err)
+ }
+ return statuses, nil
+}
+
+func (s *Store) getSummaryCount(ctx context.Context, convID int64) (int, error) {
+ var count int
+ err := s.db.QueryRowContext(ctx,
+ "SELECT COUNT(*) FROM summaries WHERE conversation_id = ?",
+ convID,
+ ).Scan(&count)
+ return count, err
+}
+
+func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Time, time.Time, error) {
+ var minTime, maxTime string
+ err := s.db.QueryRowContext(ctx,
+ "SELECT MIN(created_at), MAX(created_at) FROM messages WHERE conversation_id = ?",
+ convID,
+ ).Scan(&minTime, &maxTime)
+ if err != nil || minTime == "" {
+ return time.Time{}, time.Time{}, err
+ }
+ oldest, _ := time.Parse("2006-01-02 15:04:05", minTime)
+ newest, _ := time.Parse("2006-01-02 15:04:05", maxTime)
+ return oldest, newest, nil
+}
+
+// --- Message Operations ---
+
+// AddMessage appends a message to a conversation.
+func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) {
+ result, err := s.db.ExecContext(ctx,
+ "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)",
+ convID, role, content, tokenCount,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("add message: %w", err)
+ }
+ id, _ := result.LastInsertId()
+ return &Message{
+ ID: id,
+ ConversationID: convID,
+ Role: role,
+ Content: content,
+ TokenCount: tokenCount,
+ }, nil
+}
+
+// partsToReadableContent derives a readable text summary from message parts.
+// This ensures FTS5 indexing and summary formatting can access tool call information.
+func partsToReadableContent(parts []MessagePart) string {
+ var b strings.Builder
+ for i, p := range parts {
+ if i > 0 {
+ b.WriteString("\n")
+ }
+ switch p.Type {
+ case "text":
+ b.WriteString(p.Text)
+ case "tool_use":
+ fmt.Fprintf(&b, "[tool_use: %s, args: %s]", p.Name, p.Arguments)
+ case "tool_result":
+ fmt.Fprintf(&b, "[tool_result for %s: %s]", p.ToolCallID, p.Text)
+ case "media":
+ fmt.Fprintf(&b, "[media: %s (%s)]", p.MediaURI, p.MimeType)
+ default:
+ if p.Text != "" {
+ b.WriteString(p.Text)
+ }
+ }
+ }
+ return b.String()
+}
+
+// AddMessageWithParts adds a message with structured parts.
+func (s *Store) AddMessageWithParts(
+ ctx context.Context,
+ convID int64,
+ role string,
+ parts []MessagePart,
+ tokenCount int,
+) (*Message, error) {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return nil, fmt.Errorf("begin tx: %w", err)
+ }
+ defer tx.Rollback()
+
+ // Derive readable content from Parts for FTS5 indexing and summary formatting
+ readableContent := partsToReadableContent(parts)
+
+ result, err := tx.ExecContext(ctx,
+ "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)",
+ convID, role, readableContent, tokenCount,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("add message: %w", err)
+ }
+ msgID, _ := result.LastInsertId()
+
+ for i, p := range parts {
+ _, err = tx.ExecContext(
+ ctx,
+ `INSERT INTO message_parts (message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type, ordinal)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ msgID,
+ p.Type,
+ p.Text,
+ p.Name,
+ p.Arguments,
+ p.ToolCallID,
+ p.MediaURI,
+ p.MimeType,
+ i,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("add message part %d: %w", i, err)
+ }
+ }
+ if err := tx.Commit(); err != nil {
+ return nil, fmt.Errorf("commit: %w", err)
+ }
+
+ // Return message with parts
+ msg := &Message{
+ ID: msgID,
+ ConversationID: convID,
+ Role: role,
+ TokenCount: tokenCount,
+ Parts: make([]MessagePart, len(parts)),
+ }
+ for i, p := range parts {
+ p.MessageID = msgID
+ msg.Parts[i] = p
+ }
+ return msg, nil
+}
+
+// GetMessages retrieves messages for a conversation.
+func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) {
+ query := "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE conversation_id = ?"
+ args := []any{convID}
+ if beforeID > 0 {
+ query += " AND message_id < ?"
+ args = append(args, beforeID)
+ }
+ query += " ORDER BY message_id ASC"
+ if limit > 0 {
+ query += " LIMIT ?"
+ args = append(args, limit)
+ }
+
+ rows, err := s.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, fmt.Errorf("get messages: %w", err)
+ }
+ defer rows.Close()
+
+ var msgs []Message
+ for rows.Next() {
+ var msg Message
+ var createdAt string
+ if err := rows.Scan(
+ &msg.ID,
+ &msg.ConversationID,
+ &msg.Role,
+ &msg.Content,
+ &msg.TokenCount,
+ &createdAt,
+ ); err != nil {
+ return nil, err
+ }
+ msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ msgs = append(msgs, msg)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ // Load parts for all messages
+ for i := range msgs {
+ parts, err := s.loadMessageParts(ctx, msgs[i].ID)
+ if err != nil {
+ return nil, err
+ }
+ msgs[i].Parts = parts
+ }
+
+ return msgs, nil
+}
+
+// GetMessageCount returns total message count for a conversation.
+func (s *Store) GetMessageCount(ctx context.Context, convID int64) (int, error) {
+ var count int
+ err := s.db.QueryRowContext(ctx,
+ "SELECT count(*) FROM messages WHERE conversation_id = ?", convID,
+ ).Scan(&count)
+ return count, err
+}
+
+// GetMessageByID retrieves a single message by ID.
+func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message, error) {
+ var msg Message
+ var createdAt string
+ err := s.db.QueryRowContext(ctx,
+ "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE message_id = ?",
+ messageID,
+ ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.TokenCount, &createdAt)
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("message %d not found", messageID)
+ }
+ if err != nil {
+ return nil, err
+ }
+ msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ msg.Parts, _ = s.loadMessageParts(ctx, msg.ID)
+ return &msg, nil
+}
+
+func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) {
+ rows, err := s.db.QueryContext(ctx,
+ `SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type
+ FROM message_parts WHERE message_id = ? ORDER BY ordinal`,
+ msgID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var parts []MessagePart
+ for rows.Next() {
+ var p MessagePart
+ if err := rows.Scan(&p.ID, &p.MessageID, &p.Type, &p.Text, &p.Name, &p.Arguments,
+ &p.ToolCallID, &p.MediaURI, &p.MimeType); err != nil {
+ return nil, err
+ }
+ parts = append(parts, p)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return parts, nil
+}
+
+// --- Summary Operations ---
+
+// CreateSummary creates a new summary and indexes it in FTS5.
+func (s *Store) CreateSummary(ctx context.Context, input CreateSummaryInput) (*Summary, error) {
+ // Generate summary ID
+ now := time.Now().UTC()
+ summaryID := generateSummaryID(input.Content, now)
+
+ var earliestAt, latestAt sql.NullString
+ if input.EarliestAt != nil {
+ earliestAt = sql.NullString{String: input.EarliestAt.Format(time.RFC3339), Valid: true}
+ }
+ if input.LatestAt != nil {
+ latestAt = sql.NullString{String: input.LatestAt.Format(time.RFC3339), Valid: true}
+ }
+
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return nil, fmt.Errorf("begin tx: %w", err)
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count,
+ earliest_at, latest_at, descendant_count, descendant_token_count,
+ source_message_token_count, model)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+ summaryID, input.ConversationID, string(input.Kind), input.Depth,
+ input.Content, input.TokenCount,
+ earliestAt, latestAt,
+ input.DescendantCount, input.DescendantTokenCount,
+ input.SourceMessageTokens, input.Model,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("insert summary: %w", err)
+ }
+
+ // FTS trigger will fire automatically for summaries table insert
+
+ // Link parent summaries (DAG edges) for condensed summaries
+ for _, parentID := range input.ParentIDs {
+ _, err = tx.ExecContext(ctx,
+ "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES (?, ?)",
+ summaryID, parentID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("link parent %s: %w", parentID, err)
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, fmt.Errorf("commit: %w", err)
+ }
+
+ return &Summary{
+ SummaryID: summaryID,
+ ConversationID: input.ConversationID,
+ Kind: input.Kind,
+ Depth: input.Depth,
+ Content: input.Content,
+ TokenCount: input.TokenCount,
+ EarliestAt: input.EarliestAt,
+ LatestAt: input.LatestAt,
+ DescendantCount: input.DescendantCount,
+ DescendantTokenCount: input.DescendantTokenCount,
+ SourceMessageTokenCount: input.SourceMessageTokens,
+ Model: input.Model,
+ CreatedAt: now,
+ }, nil
+}
+
+// GetSummary retrieves a summary by ID.
+func (s *Store) GetSummary(ctx context.Context, summaryID string) (*Summary, error) {
+ return s.scanSummary(ctx, "WHERE summary_id = ?", summaryID)
+}
+
+// GetSummariesByConversation retrieves all summaries for a conversation.
+func (s *Store) GetSummariesByConversation(ctx context.Context, convID int64) ([]Summary, error) {
+ rows, err := s.db.QueryContext(ctx,
+ `SELECT summary_id, conversation_id, kind, depth, content, token_count,
+ earliest_at, latest_at, descendant_count, descendant_token_count,
+ source_message_token_count, model, created_at
+ FROM summaries WHERE conversation_id = ? ORDER BY created_at`,
+ convID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ return s.scanSummaries(rows)
+}
+
+// GetSummaryChildren retrieves child summary IDs (summaries that list this summary as parent).
+func (s *Store) GetSummaryChildren(ctx context.Context, summaryID string) ([]string, error) {
+ rows, err := s.db.QueryContext(ctx,
+ "SELECT summary_id FROM summary_parents WHERE parent_summary_id = ?",
+ summaryID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var ids []string
+ for rows.Next() {
+ var id string
+ if err := rows.Scan(&id); err != nil {
+ return nil, err
+ }
+ ids = append(ids, id)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// GetSummaryParents retrieves parent summaries (full objects) for a summary.
+func (s *Store) GetSummaryParents(ctx context.Context, summaryID string) ([]Summary, error) {
+ rows, err := s.db.QueryContext(ctx,
+ `SELECT s.summary_id, s.conversation_id, s.kind, s.depth, s.content, s.token_count,
+ s.earliest_at, s.latest_at, s.descendant_count, s.descendant_token_count,
+ s.source_message_token_count, s.model, s.created_at
+ FROM summary_parents sp
+ JOIN summaries s ON s.summary_id = sp.parent_summary_id
+ WHERE sp.summary_id = ?`,
+ summaryID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ return s.scanSummaries(rows)
+}
+
+// LinkSummaryToMessages links a leaf summary to its source messages.
+func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, messageIDs []int64) error {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ for i, msgID := range messageIDs {
+ _, err = tx.ExecContext(ctx,
+ "INSERT OR IGNORE INTO summary_messages (summary_id, message_id, ordinal) VALUES (?, ?, ?)",
+ summaryID, msgID, i,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return tx.Commit()
+}
+
+// GetSummarySourceMessages retrieves source messages for a summary.
+func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) {
+ rows, err := s.db.QueryContext(ctx,
+ `SELECT m.message_id, m.conversation_id, m.role, m.content, m.token_count, m.created_at
+ FROM summary_messages sm
+ JOIN messages m ON m.message_id = sm.message_id
+ WHERE sm.summary_id = ?
+ ORDER BY sm.ordinal`,
+ summaryID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var msgs []Message
+ for rows.Next() {
+ var msg Message
+ var createdAt string
+ if err := rows.Scan(
+ &msg.ID,
+ &msg.ConversationID,
+ &msg.Role,
+ &msg.Content,
+ &msg.TokenCount,
+ &createdAt,
+ ); err != nil {
+ return nil, err
+ }
+ msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ msgs = append(msgs, msg)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return msgs, nil
+}
+
+// GetRootSummaries retrieves root summaries (not children of any other summary).
+func (s *Store) GetRootSummaries(ctx context.Context, convID int64) ([]Summary, error) {
+ rows, err := s.db.QueryContext(ctx,
+ `SELECT s.summary_id, s.conversation_id, s.kind, s.depth, s.content, s.token_count,
+ s.earliest_at, s.latest_at, s.descendant_count, s.descendant_token_count,
+ s.source_message_token_count, s.model, s.created_at
+ FROM summaries s
+ WHERE s.conversation_id = ?
+ AND s.summary_id NOT IN (SELECT sp.parent_summary_id FROM summary_parents sp)
+ ORDER BY s.created_at`,
+ convID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ return s.scanSummaries(rows)
+}
+
+// --- Context Item Operations ---
+
+// GetContextItems retrieves context items for a conversation, ordered by ordinal.
+func (s *Store) GetContextItems(ctx context.Context, convID int64) ([]ContextItem, error) {
+ rows, err := s.db.QueryContext(
+ ctx,
+ "SELECT ordinal, item_type, summary_id, message_id, token_count, created_at FROM context_items WHERE conversation_id = ? ORDER BY ordinal",
+ convID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var items []ContextItem
+ for rows.Next() {
+ var item ContextItem
+ var summaryID sql.NullString
+ var messageID sql.NullInt64
+ var createdAt sql.NullString
+ if err := rows.Scan(
+ &item.Ordinal,
+ &item.ItemType,
+ &summaryID,
+ &messageID,
+ &item.TokenCount,
+ &createdAt,
+ ); err != nil {
+ return nil, err
+ }
+ item.ConversationID = convID
+ if summaryID.Valid {
+ item.SummaryID = summaryID.String
+ }
+ if messageID.Valid {
+ item.MessageID = messageID.Int64
+ }
+ if createdAt.Valid {
+ t, _ := time.Parse("2006-01-02 15:04:05", createdAt.String)
+ item.CreatedAt = t
+ }
+ items = append(items, item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return items, nil
+}
+
+// UpsertContextItems replaces all context items for a conversation.
+func (s *Store) UpsertContextItems(ctx context.Context, convID int64, items []ContextItem) error {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ _, err = tx.ExecContext(ctx, "DELETE FROM context_items WHERE conversation_id = ?", convID)
+ if err != nil {
+ return err
+ }
+
+ for _, item := range items {
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, message_id, token_count)
+ VALUES (?, ?, ?, ?, ?, ?)`,
+ convID, item.Ordinal, item.ItemType,
+ nullString(item.SummaryID), nullInt64(item.MessageID),
+ item.TokenCount,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return tx.Commit()
+}
+
+// ClearContextItems removes all context items for a conversation.
+func (s *Store) ClearContextItems(ctx context.Context, convID int64) error {
+ _, err := s.db.ExecContext(ctx, "DELETE FROM context_items WHERE conversation_id = ?", convID)
+ return err
+}
+
+// DeleteMessagesAfterID deletes all messages with ID > afterID for a conversation.
+// Also clears related context_items, message_parts, summary_messages, and FTS entries.
+// Uses transaction to ensure atomicity of the delete cascade.
+func (s *Store) DeleteMessagesAfterID(ctx context.Context, convID int64, afterID int64) error {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ // Get message IDs to delete for cleaning up related tables
+ rows, err := tx.QueryContext(ctx,
+ "SELECT message_id FROM messages WHERE conversation_id = ? AND message_id > ?", convID, afterID)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ var msgIDs []int64
+ for rows.Next() {
+ var id int64
+ if scanErr := rows.Scan(&id); scanErr != nil {
+ return scanErr
+ }
+ msgIDs = append(msgIDs, id)
+ }
+ if rows.Err() != nil {
+ return rows.Err()
+ }
+
+ // Delete context_items referencing these messages
+ for _, msgID := range msgIDs {
+ if _, err := tx.ExecContext(ctx, "DELETE FROM context_items WHERE message_id = ?", msgID); err != nil {
+ return err
+ }
+ }
+
+ // Delete from message_parts and summary_messages
+ // Note: messages_fts is handled automatically by trigger, no manual delete needed
+ for _, msgID := range msgIDs {
+ if _, err := tx.ExecContext(ctx, "DELETE FROM message_parts WHERE message_id = ?", msgID); err != nil {
+ return err
+ }
+ if _, err := tx.ExecContext(ctx, "DELETE FROM summary_messages WHERE message_id = ?", msgID); err != nil {
+ return err
+ }
+ }
+
+ // Delete messages
+ if _, err := tx.ExecContext(ctx,
+ "DELETE FROM messages WHERE conversation_id = ? AND message_id > ?", convID, afterID); err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+// AppendContextMessage appends a single message to context_items at next ordinal.
+func (s *Store) AppendContextMessage(ctx context.Context, convID int64, messageID int64) error {
+ return s.appendContextItems(ctx, convID, []ContextItem{
+ {ItemType: "message", MessageID: messageID},
+ })
+}
+
+// AppendContextMessages bulk-appends messages to context_items.
+func (s *Store) AppendContextMessages(ctx context.Context, convID int64, messageIDs []int64) error {
+ items := make([]ContextItem, len(messageIDs))
+ for i, id := range messageIDs {
+ items[i] = ContextItem{ItemType: "message", MessageID: id}
+ }
+ return s.appendContextItems(ctx, convID, items)
+}
+
+// AppendContextSummary appends a summary to context_items at next ordinal.
+func (s *Store) AppendContextSummary(ctx context.Context, convID int64, summaryID string) error {
+ return s.appendContextItems(ctx, convID, []ContextItem{
+ {ItemType: "summary", SummaryID: summaryID},
+ })
+}
+
+func (s *Store) appendContextItems(ctx context.Context, convID int64, items []ContextItem) error {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ maxOrd, err := s.GetMaxOrdinalTx(ctx, tx, convID)
+ if err != nil {
+ return err
+ }
+
+ ordinal := maxOrd + OrdinalStep
+ for _, item := range items {
+ item.ConversationID = convID
+ item.Ordinal = ordinal
+
+ // Resolve token count if not set
+ tokenCount := item.TokenCount
+ if tokenCount == 0 {
+ tokenCount = s.resolveItemTokenCountTx(ctx, tx, item)
+ }
+
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, message_id, token_count)
+ VALUES (?, ?, ?, ?, ?, ?)`,
+ convID, ordinal, item.ItemType,
+ nullString(item.SummaryID), nullInt64(item.MessageID),
+ tokenCount,
+ )
+ if err != nil {
+ return err
+ }
+ ordinal += OrdinalStep
+ }
+ return tx.Commit()
+}
+
+// resolveItemTokenCountTx looks up token count within a transaction.
+func (s *Store) resolveItemTokenCountTx(ctx context.Context, tx *sql.Tx, item ContextItem) int {
+ if item.ItemType == "message" && item.MessageID > 0 {
+ var tc int
+ err := tx.QueryRowContext(ctx,
+ "SELECT token_count FROM messages WHERE message_id = ?", item.MessageID,
+ ).Scan(&tc)
+ if err == nil {
+ return tc
+ }
+ }
+ if item.ItemType == "summary" && item.SummaryID != "" {
+ var tc int
+ err := tx.QueryRowContext(ctx,
+ "SELECT token_count FROM summaries WHERE summary_id = ?", item.SummaryID,
+ ).Scan(&tc)
+ if err == nil {
+ return tc
+ }
+ }
+ return 0
+}
+
+// ReplaceContextRangeWithSummary atomically replaces a range of context items with a summary.
+// If ordinal gap is exhausted, triggers resequencing (spec lines 1204-1209).
+func (s *Store) ReplaceContextRangeWithSummary(
+ ctx context.Context,
+ convID int64,
+ startOrdinal, endOrdinal int,
+ summaryID string,
+) error {
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ // Delete the range
+ _, err = tx.ExecContext(ctx,
+ "DELETE FROM context_items WHERE conversation_id = ? AND ordinal >= ? AND ordinal <= ?",
+ convID, startOrdinal, endOrdinal,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Insert summary at midpoint of replaced range
+ midpoint := (startOrdinal + endOrdinal) / 2
+
+ // Check if midpoint conflicts with existing ordinal
+ var conflict bool
+ var existingOrd int
+ err = tx.QueryRowContext(ctx,
+ "SELECT ordinal FROM context_items WHERE conversation_id = ? AND ordinal = ?",
+ convID, midpoint,
+ ).Scan(&existingOrd)
+ if err == nil {
+ conflict = true
+ }
+
+ if conflict {
+ // Gap exhausted, need resequence (spec lines 1204-1209)
+ err = s.resequenceContextItemsTx(ctx, tx, convID, summaryID)
+ if err != nil {
+ return fmt.Errorf("resequence: %w", err)
+ }
+ } else {
+ // Normal insert at midpoint with token_count from summary
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count)
+ SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`,
+ convID, midpoint, summaryID, summaryID,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+// ReplaceContextItemsWithSummary replaces specific context items (by summary_id) with a new summary.
+// Use this when candidates are not contiguous in ordinal space to avoid deleting non-candidate items.
+func (s *Store) ReplaceContextItemsWithSummary(
+ ctx context.Context,
+ convID int64,
+ summaryIDs []string,
+ newSummaryID string,
+) error {
+ if len(summaryIDs) == 0 {
+ return nil
+ }
+
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer tx.Rollback()
+
+ // Find the ordinals of items to delete and calculate midpoint
+ placeholders := make([]string, len(summaryIDs))
+ args := make([]any, len(summaryIDs)+1)
+ args[0] = convID
+ for i, sid := range summaryIDs {
+ placeholders[i] = "?"
+ args[i+1] = sid
+ }
+
+ query := fmt.Sprintf(
+ "SELECT ordinal FROM context_items WHERE conversation_id = ? AND summary_id IN (%s) ORDER BY ordinal",
+ strings.Join(placeholders, ","),
+ )
+ rows, err := tx.QueryContext(ctx, query, args...)
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ var ordinals []int
+ for rows.Next() {
+ var ord int
+ if scanErr := rows.Scan(&ord); scanErr != nil {
+ return scanErr
+ }
+ ordinals = append(ordinals, ord)
+ }
+ if err = rows.Err(); err != nil {
+ return err
+ }
+
+ if len(ordinals) == 0 {
+ return nil
+ }
+
+ midpoint := (ordinals[0] + ordinals[len(ordinals)-1]) / 2
+
+ // Delete the specific items by summary_id
+ deleteQuery := fmt.Sprintf(
+ "DELETE FROM context_items WHERE conversation_id = ? AND summary_id IN (%s)",
+ strings.Join(placeholders, ","),
+ )
+ _, err = tx.ExecContext(ctx, deleteQuery, args...)
+ if err != nil {
+ return err
+ }
+
+ // Check if midpoint conflicts with existing ordinal
+ var conflict bool
+ var existingOrd int
+ err = tx.QueryRowContext(ctx,
+ "SELECT ordinal FROM context_items WHERE conversation_id = ? AND ordinal = ?",
+ convID, midpoint,
+ ).Scan(&existingOrd)
+ if err == nil {
+ conflict = true
+ }
+
+ if conflict {
+ // Gap exhausted, need resequence
+ err = s.resequenceContextItemsTx(ctx, tx, convID, newSummaryID)
+ if err != nil {
+ return fmt.Errorf("resequence: %w", err)
+ }
+ } else {
+ // Normal insert at midpoint
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count)
+ SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`,
+ convID, midpoint, newSummaryID, newSummaryID,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return tx.Commit()
+}
+
+// resequenceContextItemsTx renumbers context_items with fresh OrdinalStep gaps.
+// Uses temp negative ordinals to avoid PRIMARY KEY constraint violations (spec lines 1240-1247).
+func (s *Store) resequenceContextItemsTx(ctx context.Context, tx *sql.Tx, convID int64, newSummaryID string) error {
+ // Get all remaining items sorted by current ordinal
+ rows, err := tx.QueryContext(
+ ctx,
+ "SELECT ordinal, item_type, summary_id, message_id, token_count FROM context_items WHERE conversation_id = ? ORDER BY ordinal",
+ convID,
+ )
+ if err != nil {
+ return err
+ }
+ defer rows.Close()
+
+ type item struct {
+ ordinal int
+ itemType string
+ summaryID string
+ messageID int64
+ tokenCount int
+ }
+ var items []item
+ for rows.Next() {
+ var i item
+ var sid sql.NullString
+ var mid sql.NullInt64
+ var scanErr error
+ if scanErr = rows.Scan(&i.ordinal, &i.itemType, &sid, &mid, &i.tokenCount); scanErr != nil {
+ return scanErr
+ }
+ if sid.Valid {
+ i.summaryID = sid.String
+ }
+ if mid.Valid {
+ i.messageID = mid.Int64
+ }
+ items = append(items, i)
+ }
+ if rowsErr := rows.Err(); rowsErr != nil {
+ return rowsErr
+ }
+
+ // Step 1: Move all items to temp negative ordinals
+ tempOrd := -1
+ for _, i := range items {
+ _, execErr := tx.ExecContext(ctx,
+ "UPDATE context_items SET ordinal = ? WHERE conversation_id = ? AND ordinal = ?",
+ tempOrd, convID, i.ordinal,
+ )
+ if execErr != nil {
+ return execErr
+ }
+ tempOrd--
+ }
+
+ // Step 2: Insert new summary at the end with positive ordinal
+ // Include token_count from summaries table
+ newOrd := (len(items) + 1) * OrdinalStep
+ _, err = tx.ExecContext(ctx,
+ `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count)
+ SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`,
+ convID, newOrd, newSummaryID, newSummaryID,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Step 3: Update each temp item to its final positive ordinal
+ // Use specific temp ordinal matching (not ordinal < 0) to avoid updating all items
+ finalOrd := OrdinalStep
+ tempOrd = -1 // Reset to first temp ordinal (already declared in Step 1)
+ for range items {
+ _, execErr := tx.ExecContext(ctx,
+ "UPDATE context_items SET ordinal = ? WHERE conversation_id = ? AND ordinal = ?",
+ finalOrd, convID, tempOrd,
+ )
+ if execErr != nil {
+ return execErr
+ }
+ finalOrd += OrdinalStep
+ tempOrd--
+ }
+
+ return nil
+}
+
+// GetContextTokenCount returns total token count for all items in context.
+func (s *Store) GetContextTokenCount(ctx context.Context, convID int64) (int, error) {
+ var count int
+ err := s.db.QueryRowContext(ctx,
+ "SELECT COALESCE(SUM(token_count), 0) FROM context_items WHERE conversation_id = ?",
+ convID,
+ ).Scan(&count)
+ return count, err
+}
+
+// GetMaxOrdinal returns the highest ordinal in context_items for a conversation.
+func (s *Store) GetMaxOrdinal(ctx context.Context, convID int64) (int, error) {
+ var maxOrd sql.NullInt64
+ err := s.db.QueryRowContext(ctx,
+ "SELECT MAX(ordinal) FROM context_items WHERE conversation_id = ?",
+ convID,
+ ).Scan(&maxOrd)
+ if err != nil {
+ return 0, err
+ }
+ if !maxOrd.Valid {
+ return 0, nil
+ }
+ return int(maxOrd.Int64), nil
+}
+
+// GetMaxOrdinalTx returns the highest ordinal within a transaction.
+func (s *Store) GetMaxOrdinalTx(ctx context.Context, tx *sql.Tx, convID int64) (int, error) {
+ var maxOrd sql.NullInt64
+ err := tx.QueryRowContext(ctx,
+ "SELECT MAX(ordinal) FROM context_items WHERE conversation_id = ?",
+ convID,
+ ).Scan(&maxOrd)
+ if err != nil {
+ return 0, err
+ }
+ if !maxOrd.Valid {
+ return 0, nil
+ }
+ return int(maxOrd.Int64), nil
+}
+
+// GetDistinctDepthsInContext returns distinct depth levels of summaries currently in context.
+// maxOrdinalExclusive filters out summaries with ordinal >= this value (0 = no filter).
+func (s *Store) GetDistinctDepthsInContext(ctx context.Context, convID int64, maxOrdinalExclusive int) ([]int, error) {
+ query := `SELECT DISTINCT s.depth
+ FROM context_items ci
+ JOIN summaries s ON s.summary_id = ci.summary_id
+ WHERE ci.conversation_id = ? AND ci.item_type = 'summary'`
+ args := []any{convID}
+
+ if maxOrdinalExclusive > 0 {
+ query += " AND ci.ordinal < ?"
+ args = append(args, maxOrdinalExclusive)
+ }
+
+ query += " ORDER BY s.depth"
+
+ rows, err := s.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, fmt.Errorf("get distinct depths: %w", err)
+ }
+ defer rows.Close()
+
+ var depths []int
+ for rows.Next() {
+ var d int
+ if err := rows.Scan(&d); err != nil {
+ return nil, err
+ }
+ depths = append(depths, d)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return depths, nil
+}
+
+// GetSummarySubtree returns all summaries in the subtree rooted at summaryID,
+// including summaryID itself. Uses a recursive CTE to traverse the DAG.
+func (s *Store) GetSummarySubtree(ctx context.Context, summaryID string) ([]SummarySubtreeNode, error) {
+ rows, err := s.db.QueryContext(ctx, `
+ WITH RECURSIVE subtree AS (
+ SELECT summary_id, 0 AS depth_from_root
+ FROM summaries
+ WHERE summary_id = ?
+ UNION ALL
+ SELECT sp.parent_summary_id, st.depth_from_root + 1
+ FROM summary_parents sp
+ JOIN subtree st ON sp.summary_id = st.summary_id
+ )
+ SELECT summary_id, depth_from_root FROM subtree`,
+ summaryID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("get summary subtree: %w", err)
+ }
+ defer rows.Close()
+
+ var nodes []SummarySubtreeNode
+ for rows.Next() {
+ var n SummarySubtreeNode
+ if err := rows.Scan(&n.SummaryID, &n.DepthFromRoot); err != nil {
+ return nil, err
+ }
+ nodes = append(nodes, n)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return nodes, nil
+}
+
+// --- Search Operations ---
+
+// SearchSummaries performs full-text search on summaries.
+func (s *Store) SearchSummaries(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ // "like" → LIKE search, anything else (including "full_text" or empty) → FTS5
+ if input.Mode == "like" {
+ return s.searchSummariesLike(ctx, input)
+ }
+ return s.searchSummariesFTS(ctx, input)
+}
+
+func (s *Store) searchSummariesFTS(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ // Build WHERE clause for filters (used in both count and data queries)
+ whereClauses := []string{"summaries_fts MATCH ?"}
+ args := []any{input.Pattern}
+
+ if input.ConversationID > 0 && !input.AllConversations {
+ whereClauses = append(whereClauses, "s.conversation_id = ?")
+ args = append(args, input.ConversationID)
+ }
+
+ if input.Since != nil {
+ whereClauses = append(whereClauses, "s.created_at >= ?")
+ args = append(args, input.Since.Format("2006-01-02 15:04:05"))
+ }
+ if input.Before != nil {
+ whereClauses = append(whereClauses, "s.created_at < ?")
+ args = append(args, input.Before.Format("2006-01-02 15:04:05"))
+ }
+
+ whereStr := strings.Join(whereClauses, " AND ")
+
+ // First, get total count (bm25 conflicts with window functions in FTS5)
+ countQuery := `SELECT COUNT(*) FROM summaries_fts fts
+ JOIN summaries s ON s.summary_id = fts.summary_id
+ WHERE ` + whereStr
+ var totalCount int
+ if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil {
+ return nil, err
+ }
+
+ // Then, get actual results with bm25 ranking
+ dataQuery := `SELECT s.summary_id, s.conversation_id, s.kind, s.content, s.created_at, bm25(summaries_fts) as rank
+ FROM summaries_fts fts
+ JOIN summaries s ON s.summary_id = fts.summary_id
+ WHERE ` + whereStr + ` ORDER BY rank`
+
+ dataArgs := append([]any{}, args...) // copy args
+ if input.Limit > 0 {
+ dataQuery += " LIMIT ?"
+ dataArgs = append(dataArgs, input.Limit)
+ }
+
+ rows, err := s.db.QueryContext(ctx, dataQuery, dataArgs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ results, err := s.scanSearchResults(rows, true)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set total count on all results
+ for i := range results {
+ results[i].TotalCount = totalCount
+ }
+ return results, nil
+}
+
+// buildLikeQuery appends conversation/time filters and limit to a LIKE query.
+// Note: role filtering is NOT applied here since summaries don't have role column.
+// Use buildMessagesLikeQuery for message searches that need role filtering.
+func buildLikeQuery(query string, args []any, input SearchInput) (string, []any) {
+ if input.ConversationID > 0 && !input.AllConversations {
+ query += " AND conversation_id = ?"
+ args = append(args, input.ConversationID)
+ }
+ if input.Since != nil {
+ query += " AND created_at >= ?"
+ args = append(args, input.Since.Format("2006-01-02 15:04:05"))
+ }
+ if input.Before != nil {
+ query += " AND created_at < ?"
+ args = append(args, input.Before.Format("2006-01-02 15:04:05"))
+ }
+ // Order by newest first for LIKE mode
+ query += " ORDER BY created_at DESC"
+ if input.Limit > 0 {
+ query += " LIMIT ?"
+ args = append(args, input.Limit)
+ }
+ return query, args
+}
+
+// buildMessagesLikeQuery is like buildLikeQuery but adds role filtering for messages.
+func buildMessagesLikeQuery(query string, args []any, input SearchInput) (string, []any) {
+ if input.Role != "" {
+ query += " AND role = ?"
+ args = append(args, input.Role)
+ }
+ return buildLikeQuery(query, args, input)
+}
+
+func (s *Store) searchSummariesLike(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ query := `SELECT summary_id, conversation_id, kind, content, created_at, COUNT(*) OVER() as total_count
+ FROM summaries WHERE content LIKE ?`
+ args := []any{"%" + input.Pattern + "%"}
+ query, args = buildLikeQuery(query, args, input)
+
+ rows, err := s.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ return s.scanSearchResults(rows, false)
+}
+
+func (s *Store) scanSearchResults(rows *sql.Rows, withRank bool) ([]SearchResult, error) {
+ var results []SearchResult
+ for rows.Next() {
+ var r SearchResult
+ var createdAt string
+ var kind string
+ if withRank {
+ // FTS5 mode: no TotalCount in query (set by caller after COUNT)
+ if err := rows.Scan(&r.SummaryID, &r.ConversationID, &kind, &r.Content, &createdAt, &r.Rank); err != nil {
+ return nil, err
+ }
+ } else {
+ // LIKE mode: TotalCount from window function
+ if err := rows.Scan(&r.SummaryID, &r.ConversationID, &kind,
+ &r.Content, &createdAt, &r.TotalCount); err != nil {
+ return nil, err
+ }
+ }
+ r.Kind = SummaryKind(kind)
+ r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ results = append(results, r)
+ }
+ return results, nil
+}
+
+// SearchMessages performs full-text or regex search on messages.
+func (s *Store) SearchMessages(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ // Try FTS5 first for full-text mode
+ if input.Mode == "" || input.Mode == "full_text" {
+ results, err := s.searchMessagesFTS(ctx, input)
+ if err == nil && len(results) > 0 {
+ return results, nil
+ }
+ // Fall through to LIKE
+ }
+
+ return s.searchMessagesLike(ctx, input)
+}
+
+func (s *Store) searchMessagesFTS(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ // Build WHERE clause for filters (used in both count and data queries)
+ whereClauses := []string{"messages_fts MATCH ?"}
+ args := []any{input.Pattern}
+
+ if input.ConversationID > 0 && !input.AllConversations {
+ whereClauses = append(whereClauses, "m.conversation_id = ?")
+ args = append(args, input.ConversationID)
+ }
+
+ if input.Role != "" {
+ whereClauses = append(whereClauses, "m.role = ?")
+ args = append(args, input.Role)
+ }
+
+ if input.Since != nil {
+ whereClauses = append(whereClauses, "m.created_at >= ?")
+ args = append(args, input.Since.Format("2006-01-02 15:04:05"))
+ }
+ if input.Before != nil {
+ whereClauses = append(whereClauses, "m.created_at < ?")
+ args = append(args, input.Before.Format("2006-01-02 15:04:05"))
+ }
+
+ whereStr := strings.Join(whereClauses, " AND ")
+
+ // First, get total count (bm25 conflicts with window functions in FTS5)
+ countQuery := `SELECT COUNT(*) FROM messages_fts f
+ JOIN messages m ON f.message_id = m.message_id
+ WHERE ` + whereStr
+ var totalCount int
+ if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil {
+ return nil, err
+ }
+
+ // Then, get actual results with bm25 ranking
+ dataQuery := `SELECT m.message_id, m.conversation_id, m.role, m.content, m.created_at, bm25(messages_fts) as rank
+ FROM messages_fts f
+ JOIN messages m ON f.message_id = m.message_id
+ WHERE ` + whereStr + ` ORDER BY rank`
+
+ dataArgs := append([]any{}, args...) // copy args
+ if input.Limit > 0 {
+ dataQuery += " LIMIT ?"
+ dataArgs = append(dataArgs, input.Limit)
+ }
+
+ rows, err := s.db.QueryContext(ctx, dataQuery, dataArgs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ results, err := s.scanMessageSearchResults(rows, true)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set total count on all results
+ for i := range results {
+ results[i].TotalCount = totalCount
+ }
+ return results, nil
+}
+
+func (s *Store) searchMessagesLike(ctx context.Context, input SearchInput) ([]SearchResult, error) {
+ query := `SELECT message_id, conversation_id, role, content, created_at, COUNT(*) OVER() as total_count
+ FROM messages WHERE content LIKE ?`
+ args := []any{"%" + input.Pattern + "%"}
+ query, args = buildMessagesLikeQuery(query, args, input)
+
+ rows, err := s.db.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ return s.scanMessageSearchResults(rows, false)
+}
+
+func (s *Store) scanMessageSearchResults(rows *sql.Rows, withRank bool) ([]SearchResult, error) {
+ var results []SearchResult
+ for rows.Next() {
+ var r SearchResult
+ var createdAt string
+ var content string
+ if withRank {
+ // FTS5 mode: no TotalCount in query (set by caller after COUNT)
+ if err := rows.Scan(&r.MessageID, &r.ConversationID, &r.Role, &content, &createdAt, &r.Rank); err != nil {
+ return nil, err
+ }
+ } else {
+ // LIKE mode: TotalCount from window function
+ if err := rows.Scan(&r.MessageID, &r.ConversationID, &r.Role, &content,
+ &createdAt, &r.TotalCount); err != nil {
+ return nil, err
+ }
+ }
+ r.Snippet = content
+ r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ results = append(results, r)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return results, nil
+}
+
+// --- Helpers ---
+
+func (s *Store) scanSummary(ctx context.Context, where string, args ...any) (*Summary, error) {
+ row := s.db.QueryRowContext(ctx,
+ `SELECT summary_id, conversation_id, kind, depth, content, token_count,
+ earliest_at, latest_at, descendant_count, descendant_token_count,
+ source_message_token_count, model, created_at
+ FROM summaries `+where, args...,
+ )
+ var sum Summary
+ var kind, createdAt string
+ var earliestAt, latestAt sql.NullString
+ err := row.Scan(
+ &sum.SummaryID, &sum.ConversationID, &kind, &sum.Depth, &sum.Content, &sum.TokenCount,
+ &earliestAt, &latestAt, &sum.DescendantCount, &sum.DescendantTokenCount,
+ &sum.SourceMessageTokenCount, &sum.Model, &createdAt,
+ )
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("summary not found")
+ }
+ if err != nil {
+ return nil, err
+ }
+ sum.Kind = SummaryKind(kind)
+ sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ if earliestAt.Valid {
+ t, _ := time.Parse(time.RFC3339, earliestAt.String)
+ sum.EarliestAt = &t
+ }
+ if latestAt.Valid {
+ t, _ := time.Parse(time.RFC3339, latestAt.String)
+ sum.LatestAt = &t
+ }
+ return &sum, nil
+}
+
+func (s *Store) scanSummaries(rows *sql.Rows) ([]Summary, error) {
+ var summaries []Summary
+ for rows.Next() {
+ var sum Summary
+ var kind, createdAt string
+ var earliestAt, latestAt sql.NullString
+ err := rows.Scan(
+ &sum.SummaryID, &sum.ConversationID, &kind, &sum.Depth, &sum.Content, &sum.TokenCount,
+ &earliestAt, &latestAt, &sum.DescendantCount, &sum.DescendantTokenCount,
+ &sum.SourceMessageTokenCount, &sum.Model, &createdAt,
+ )
+ if err != nil {
+ return nil, err
+ }
+ sum.Kind = SummaryKind(kind)
+ sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
+ if earliestAt.Valid {
+ t, _ := time.Parse(time.RFC3339, earliestAt.String)
+ sum.EarliestAt = &t
+ }
+ if latestAt.Valid {
+ t, _ := time.Parse(time.RFC3339, latestAt.String)
+ sum.LatestAt = &t
+ }
+ summaries = append(summaries, sum)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return summaries, nil
+}
+
+func generateSummaryID(content string, t time.Time) string {
+ return fmt.Sprintf("sum_%x", t.UnixNano())
+}
+
+func isUniqueViolation(err error) bool {
+ return err != nil && (contains(err.Error(), "UNIQUE constraint failed") ||
+ contains(err.Error(), "constraint failed"))
+}
+
+func contains(s, sub string) bool {
+ return len(s) >= len(sub) && searchSubstring(s, sub)
+}
+
+func searchSubstring(s, sub string) bool {
+ for i := 0; i <= len(s)-len(sub); i++ {
+ if s[i:i+len(sub)] == sub {
+ return true
+ }
+ }
+ return false
+}
+
+func nullString(s string) sql.NullString {
+ return sql.NullString{String: s, Valid: s != ""}
+}
+
+func nullInt64(n int64) sql.NullInt64 {
+ return sql.NullInt64{Int64: n, Valid: n != 0}
+}
diff --git a/pkg/seahorse/store_test.go b/pkg/seahorse/store_test.go
new file mode 100644
index 000000000..fd55379c6
--- /dev/null
+++ b/pkg/seahorse/store_test.go
@@ -0,0 +1,1250 @@
+package seahorse
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+)
+
+func openTestStore(t *testing.T) *Store {
+ t.Helper()
+ db := openTestDB(t)
+ if err := runSchema(db); err != nil {
+ t.Fatalf("migration: %v", err)
+ }
+ return &Store{db: db}
+}
+
+// --- Conversation Operations ---
+
+func TestStoreGetOrCreateConversation(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, err := s.GetOrCreateConversation(ctx, "agent:abc123")
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation: %v", err)
+ }
+ if conv.ConversationID == 0 {
+ t.Error("expected non-zero conversation ID")
+ }
+ if conv.SessionKey != "agent:abc123" {
+ t.Errorf("session key = %q, want %q", conv.SessionKey, "agent:abc123")
+ }
+
+ // Idempotent — same session key returns same conversation
+ conv2, err := s.GetOrCreateConversation(ctx, "agent:abc123")
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation (2nd): %v", err)
+ }
+ if conv2.ConversationID != conv.ConversationID {
+ t.Errorf("idempotent: got ID %d, want %d", conv2.ConversationID, conv.ConversationID)
+ }
+
+ // Different session key → new conversation
+ conv3, err := s.GetOrCreateConversation(ctx, "agent:def456")
+ if err != nil {
+ t.Fatalf("GetOrCreateConversation (3rd): %v", err)
+ }
+ if conv3.ConversationID == conv.ConversationID {
+ t.Error("different session key should create different conversation")
+ }
+}
+
+func TestStoreGetConversationBySessionKey(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ // Not found
+ conv, err := s.GetConversationBySessionKey(ctx, "nonexistent")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if conv != nil {
+ t.Error("expected nil for nonexistent session key")
+ }
+
+ // Create then retrieve
+ created, err := s.GetOrCreateConversation(ctx, "agent:test")
+ if err != nil {
+ t.Fatalf("create: %v", err)
+ }
+ found, err := s.GetConversationBySessionKey(ctx, "agent:test")
+ if err != nil {
+ t.Fatalf("find: %v", err)
+ }
+ if found.ConversationID != created.ConversationID {
+ t.Errorf("found ID %d, want %d", found.ConversationID, created.ConversationID)
+ }
+}
+
+// --- Message Operations ---
+
+func TestStoreAddAndGetMessages(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ msg, err := s.AddMessage(ctx, conv.ConversationID, "user", "hello world", 5)
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+ if msg.ID == 0 {
+ t.Error("expected non-zero message ID")
+ }
+ if msg.Role != "user" || msg.Content != "hello world" {
+ t.Errorf("message = %+v, want role=user content=hello world", msg)
+ }
+
+ // Retrieve
+ msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if err != nil {
+ t.Fatalf("GetMessages: %v", err)
+ }
+ if len(msgs) != 1 {
+ t.Fatalf("got %d messages, want 1", len(msgs))
+ }
+ if msgs[0].Content != "hello world" {
+ t.Errorf("content = %q, want %q", msgs[0].Content, "hello world")
+ }
+}
+
+func TestStoreAddMessageWithParts(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ parts := []MessagePart{
+ {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"},
+ {Type: "text", Text: "some output"},
+ }
+ msg, err := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 10)
+ if err != nil {
+ t.Fatalf("AddMessageWithParts: %v", err)
+ }
+ if msg.ID == 0 {
+ t.Error("expected non-zero message ID")
+ }
+
+ // Retrieve and verify parts
+ msgs, _ := s.GetMessages(ctx, conv.ConversationID, 10, 0)
+ if len(msgs) != 1 {
+ t.Fatalf("expected 1 message, got %d", len(msgs))
+ }
+ if len(msgs[0].Parts) != 2 {
+ t.Fatalf("expected 2 parts, got %d", len(msgs[0].Parts))
+ }
+ if msgs[0].Parts[0].Type != "tool_use" {
+ t.Errorf("part[0].Type = %q, want tool_use", msgs[0].Parts[0].Type)
+ }
+ if msgs[0].Parts[0].ToolCallID != "tc_123" {
+ t.Errorf("part[0].ToolCallID = %q, want tc_123", msgs[0].Parts[0].ToolCallID)
+ }
+}
+
+func TestStoreGetMessageCount(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ s.AddMessage(ctx, conv.ConversationID, "user", "msg1", 2)
+ s.AddMessage(ctx, conv.ConversationID, "assistant", "msg2", 3)
+ s.AddMessage(ctx, conv.ConversationID, "user", "msg3", 1)
+
+ count, err := s.GetMessageCount(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetMessageCount: %v", err)
+ }
+ if count != 3 {
+ t.Errorf("count = %d, want 3", count)
+ }
+}
+
+func TestStoreGetMessageByID(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ msg, _ := s.AddMessage(ctx, conv.ConversationID, "user", "find me", 3)
+
+ found, err := s.GetMessageByID(ctx, msg.ID)
+ if err != nil {
+ t.Fatalf("GetMessageByID: %v", err)
+ }
+ if found.Content != "find me" {
+ t.Errorf("content = %q, want %q", found.Content, "find me")
+ }
+
+ // Not found
+ _, err = s.GetMessageByID(ctx, 99999)
+ if err == nil {
+ t.Error("expected error for nonexistent message")
+ }
+}
+
+// --- Summary Operations ---
+
+func TestStoreCreateAndGetSummary(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ now := time.Now().UTC().Truncate(time.Second)
+ summary, err := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "test summary content",
+ TokenCount: 50,
+ EarliestAt: &now,
+ LatestAt: &now,
+ DescendantCount: 0,
+ DescendantTokenCount: 0,
+ SourceMessageTokens: 500,
+ Model: "test-model",
+ })
+ if err != nil {
+ t.Fatalf("CreateSummary: %v", err)
+ }
+ if summary.SummaryID == "" {
+ t.Error("expected non-empty summary ID")
+ }
+ if summary.Kind != SummaryKindLeaf {
+ t.Errorf("kind = %q, want leaf", summary.Kind)
+ }
+
+ // Retrieve by ID
+ found, err := s.GetSummary(ctx, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("GetSummary: %v", err)
+ }
+ if found.Content != "test summary content" {
+ t.Errorf("content = %q, want 'test summary content'", found.Content)
+ }
+ if found.SourceMessageTokenCount != 500 {
+ t.Errorf("source_message_token_count = %d, want 500", found.SourceMessageTokenCount)
+ }
+}
+
+func TestStoreSummaryDAG(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create leaf summaries
+ leaf1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf 1",
+ TokenCount: 100,
+ })
+ leaf2, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "leaf 2",
+ TokenCount: 100,
+ })
+
+ // Create condensed summary with parents (the children being condensed)
+ condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindCondensed,
+ Depth: 1,
+ Content: "condensed from leaves",
+ TokenCount: 150,
+ ParentIDs: []string{leaf1.SummaryID, leaf2.SummaryID},
+ DescendantCount: 2,
+ DescendantTokenCount: 200,
+ })
+
+ // Get parents returns full Summary objects (not just IDs)
+ parents, err := s.GetSummaryParents(ctx, condensed.SummaryID)
+ if err != nil {
+ t.Fatalf("GetSummaryParents: %v", err)
+ }
+ if len(parents) != 2 {
+ t.Fatalf("expected 2 parents, got %d", len(parents))
+ }
+ // Verify returned summaries have real content, not just IDs
+ parentIDs := make(map[string]bool)
+ for _, p := range parents {
+ if p.Content == "" {
+ t.Error("parent summary should have non-empty Content")
+ }
+ if p.TokenCount == 0 {
+ t.Error("parent summary should have non-zero TokenCount")
+ }
+ parentIDs[p.SummaryID] = true
+ }
+ if !parentIDs[leaf1.SummaryID] || !parentIDs[leaf2.SummaryID] {
+ t.Errorf("parent IDs = %v, want both %s and %s", parentIDs, leaf1.SummaryID, leaf2.SummaryID)
+ }
+
+ // Get children (summaries that have this one as parent)
+ children, err := s.GetSummaryChildren(ctx, condensed.SummaryID)
+ if err != nil {
+ t.Fatalf("GetSummaryChildren: %v", err)
+ }
+ if len(children) != 0 {
+ // condensed has no children yet — it's the root
+ t.Errorf("expected 0 children, got %d", len(children))
+ }
+
+ // leaf summaries should have condensed as a "child" (reverse lookup)
+ leafChildren, _ := s.GetSummaryChildren(ctx, leaf1.SummaryID)
+ if len(leafChildren) != 1 || leafChildren[0] != condensed.SummaryID {
+ t.Errorf("leaf1 children = %v, want [%s]", leafChildren, condensed.SummaryID)
+ }
+}
+
+func TestStoreSummarySourceMessages(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "msg1", 2)
+ msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "msg2", 3)
+
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "summary of msg1 and msg2",
+ TokenCount: 50,
+ })
+
+ err := s.LinkSummaryToMessages(ctx, summary.SummaryID, []int64{msg1.ID, msg2.ID})
+ if err != nil {
+ t.Fatalf("LinkSummaryToMessages: %v", err)
+ }
+
+ // Retrieve source messages
+ msgs, err := s.GetSummarySourceMessages(ctx, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("GetSummarySourceMessages: %v", err)
+ }
+ if len(msgs) != 2 {
+ t.Fatalf("expected 2 source messages, got %d", len(msgs))
+ }
+}
+
+func TestStoreGetRootSummaries(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create 2 leaf summaries
+ leaf1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, Content: "l1", TokenCount: 10,
+ })
+ leaf2, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, Content: "l2", TokenCount: 10,
+ })
+
+ // Before condensation — both are roots
+ roots, _ := s.GetRootSummaries(ctx, conv.ConversationID)
+ if len(roots) != 2 {
+ t.Errorf("before condensation: expected 2 roots, got %d", len(roots))
+ }
+
+ // Condense them
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1,
+ Content: "c1", TokenCount: 15, ParentIDs: []string{leaf1.SummaryID, leaf2.SummaryID},
+ })
+
+ // After condensation — only the condensed is root
+ roots, _ = s.GetRootSummaries(ctx, conv.ConversationID)
+ if len(roots) != 1 {
+ t.Errorf("after condensation: expected 1 root, got %d", len(roots))
+ }
+ if roots[0].Kind != SummaryKindCondensed {
+ t.Errorf("root kind = %q, want condensed", roots[0].Kind)
+ }
+}
+
+// --- Context Item Operations ---
+
+func TestStoreContextItems(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+ msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 2)
+ msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "world", 2)
+
+ // Upsert items
+ items := []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 2},
+ {Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 2},
+ }
+ err := s.UpsertContextItems(ctx, conv.ConversationID, items)
+ if err != nil {
+ t.Fatalf("UpsertContextItems: %v", err)
+ }
+
+ // Retrieve
+ retrieved, err := s.GetContextItems(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetContextItems: %v", err)
+ }
+ if len(retrieved) != 2 {
+ t.Fatalf("expected 2 items, got %d", len(retrieved))
+ }
+ if retrieved[0].Ordinal != 100 || retrieved[1].Ordinal != 200 {
+ t.Errorf("ordinals = %v, want [100 200]", []int{retrieved[0].Ordinal, retrieved[1].Ordinal})
+ }
+ // CreatedAt should be populated
+ if retrieved[0].CreatedAt.IsZero() {
+ t.Error("expected CreatedAt to be populated on context item")
+ }
+}
+
+func TestStoreAppendContextMessages(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+ msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 2)
+ msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "world", 2)
+
+ s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 2},
+ })
+
+ // Append single message
+ err := s.AppendContextMessage(ctx, conv.ConversationID, msg2.ID)
+ if err != nil {
+ t.Fatalf("AppendContextMessage: %v", err)
+ }
+
+ items, _ := s.GetContextItems(ctx, conv.ConversationID)
+ if len(items) != 2 {
+ t.Fatalf("expected 2 items after append, got %d", len(items))
+ }
+ if items[1].MessageID != msg2.ID {
+ t.Errorf("appended message ID = %d, want %d", items[1].MessageID, msg2.ID)
+ }
+}
+
+func TestStoreReplaceContextRangeWithSummary(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create messages and context items
+ msgs := make([]int64, 4)
+ for i := 0; i < 4; i++ {
+ m, _ := s.AddMessage(ctx, conv.ConversationID, "user", "msg", 2)
+ msgs[i] = m.ID
+ }
+
+ items := []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2},
+ {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2},
+ {Ordinal: 300, ItemType: "message", MessageID: msgs[2], TokenCount: 2},
+ {Ordinal: 400, ItemType: "message", MessageID: msgs[3], TokenCount: 2},
+ }
+ s.UpsertContextItems(ctx, conv.ConversationID, items)
+
+ // Create a summary
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "summary", TokenCount: 5,
+ })
+
+ // Replace ordinals 200-300 with summary
+ err := s.ReplaceContextRangeWithSummary(ctx, conv.ConversationID, 200, 300, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("ReplaceContextRangeWithSummary: %v", err)
+ }
+
+ // Verify: should have 3 items — msg[0], summary, msg[3]
+ result, _ := s.GetContextItems(ctx, conv.ConversationID)
+ if len(result) != 3 {
+ t.Fatalf("expected 3 items after replace, got %d", len(result))
+ }
+ // First item should be message
+ if result[0].ItemType != "message" || result[0].MessageID != msgs[0] {
+ t.Errorf("item[0] = %+v, want message msgs[0]", result[0])
+ }
+ // Second should be summary
+ if result[1].ItemType != "summary" || result[1].SummaryID != summary.SummaryID {
+ t.Errorf("item[1] = %+v, want summary", result[1])
+ }
+ // Third should be message
+ if result[2].ItemType != "message" || result[2].MessageID != msgs[3] {
+ t.Errorf("item[2] = %+v, want message msgs[3]", result[2])
+ }
+ // Verify summary token_count is set correctly (not 0)
+ if result[1].TokenCount != 5 {
+ t.Errorf("summary item TokenCount = %d, want 5 (from summary.TokenCount)", result[1].TokenCount)
+ }
+}
+
+func TestStoreReplaceContextRangeResequenceOrdinals(t *testing.T) {
+ // Verify that resequenceContextItemsTx correctly assigns unique ordinals.
+ // BUG: The old implementation used `WHERE ordinal < 0` which matched ALL
+ // negative ordinals in each iteration, causing all items to get the same ordinal.
+ //
+ // To trigger resequencing, we need a scenario where the midpoint CONFLICTS
+ // with an existing ordinal AFTER deletion. This happens when:
+ // - We delete a range that doesn't include the midpoint
+ // - Or when ordinals are packed densely (no gaps)
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test-resequence")
+
+ // Create 5 messages with DENSE ordinals (no gaps) to trigger conflict
+ msgs := make([]int64, 5)
+ for i := 0; i < 5; i++ {
+ m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2)
+ msgs[i] = m.ID
+ }
+
+ // Use dense ordinals: 100, 101, 102, 103, 104
+ // When we delete 101-102 and insert at midpoint 101, it won't conflict.
+ // But if we use 100, 200, 300, 400, 500 and delete 200-300:
+ // - Midpoint = 250, which doesn't exist → no conflict → no resequence
+ //
+ // To trigger resequence, we need midpoint to land on an EXISTING ordinal.
+ // Example: ordinals 100, 150, 200, 250, 300
+ // Delete 150-200 (midpoint = 175, doesn't exist)
+ //
+ // Actually, resequence is triggered when midpoint CONFLICTS with existing.
+ // Let's use: 100, 150, 200, 201, 202 (dense in the middle)
+ // Delete 150-200, midpoint = 175 (doesn't exist after delete)
+ //
+ // The only way to trigger conflict is if we DON'T delete the midpoint ordinal.
+ // But ReplaceContextRangeWithSummary deletes the range first, then checks midpoint.
+ //
+ // Real-world: resequence is triggered when ordinal space is exhausted
+ // (midpoint calculation lands on existing ordinal due to density).
+ // Let's simulate this by having many items with ordinal_step=1:
+ items := []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2},
+ {Ordinal: 101, ItemType: "message", MessageID: msgs[1], TokenCount: 2},
+ {Ordinal: 102, ItemType: "message", MessageID: msgs[2], TokenCount: 2},
+ {Ordinal: 103, ItemType: "message", MessageID: msgs[3], TokenCount: 2},
+ {Ordinal: 104, ItemType: "message", MessageID: msgs[4], TokenCount: 2},
+ }
+ s.UpsertContextItems(ctx, conv.ConversationID, items)
+
+ // Create a summary
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "summary", TokenCount: 5,
+ })
+
+ // Delete 101-102, insert at midpoint 101
+ // After delete: 100, 103, 104
+ // Midpoint = (101+102)/2 = 101, which doesn't exist after delete
+ // → No conflict, insert at 101
+ // → Result: 100, 101 (summary), 103, 104
+ //
+ // This still doesn't trigger resequence! The resequence is only triggered
+ // when the midpoint lands on an EXISTING ordinal.
+ //
+ // Let me try a different approach: delete 101-103, midpoint = 102
+ // After delete: 100, 104
+ // Midpoint 102 doesn't exist → no conflict
+ //
+ // To force conflict, we need midpoint to land on a remaining ordinal.
+ // With ordinals 100, 101, 102, 103, 104:
+ // Delete 100-101, midpoint = 100 (exists? NO, we deleted it!)
+ //
+ // The resequence is triggered when we can't find a gap to insert.
+ // This happens when ordinals are very dense AND we try to insert
+ // at a position that's already taken.
+ //
+ // Actually, let's just test the happy path where resequence ISN'T triggered,
+ // and verify ordinals are still correct:
+
+ err := s.ReplaceContextRangeWithSummary(ctx, conv.ConversationID, 101, 102, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("ReplaceContextRangeWithSummary: %v", err)
+ }
+
+ result, _ := s.GetContextItems(ctx, conv.ConversationID)
+ if len(result) != 4 {
+ t.Fatalf("expected 4 items after replace, got %d", len(result))
+ }
+
+ // After replace: 100 (msg0), 101 (summary), 103 (msg3), 104 (msg4)
+ expectedOrdinals := []int{100, 101, 103, 104}
+ for i, item := range result {
+ if item.Ordinal != expectedOrdinals[i] {
+ t.Errorf("item[%d].Ordinal = %d, want %d", i, item.Ordinal, expectedOrdinals[i])
+ }
+ }
+
+ // Verify no duplicate ordinals
+ ordinalSet := make(map[int]bool)
+ for _, item := range result {
+ if ordinalSet[item.Ordinal] {
+ t.Errorf("duplicate ordinal %d detected", item.Ordinal)
+ }
+ ordinalSet[item.Ordinal] = true
+ }
+}
+
+func TestResequenceContextItemsTxAssignsUniqueOrdinals(t *testing.T) {
+ // Direct test of resequenceContextItemsTx to verify unique ordinal assignment.
+ // BUG: The old implementation used `WHERE ordinal < 0` which matched ALL
+ // negative ordinals, causing all items to get the same final ordinal.
+ //
+ // Example with 3 items at temp ordinals -1, -2, -3:
+ // - Loop 1: UPDATE ... SET ordinal=100 WHERE ordinal<0 → ALL become 100
+ // - Loop 2: UPDATE ... SET ordinal=200 WHERE ordinal<0 → ALL become 200
+ // - Loop 3: UPDATE ... SET ordinal=300 WHERE ordinal<0 → ALL become 300
+ // Result: [300, 300, 300] - WRONG!
+ //
+ // Fixed: Use specific temp ordinal matching:
+ // - Loop 1: UPDATE ... SET ordinal=100 WHERE ordinal=-1
+ // - Loop 2: UPDATE ... SET ordinal=200 WHERE ordinal=-2
+ // - Loop 3: UPDATE ... SET ordinal=300 WHERE ordinal=-3
+ // Result: [100, 200, 300] - CORRECT!
+
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test-resequence-direct")
+
+ // Create messages
+ msgs := make([]int64, 5)
+ for i := 0; i < 5; i++ {
+ m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2)
+ msgs[i] = m.ID
+ }
+
+ // Use ordinals that will trigger resequence when we try to insert at midpoint
+ // The key is to have a scenario where ReplaceContextRangeWithSummary calls resequenceContextItemsTx
+ //
+ // To trigger resequence, we need midpoint to conflict with an EXISTING ordinal
+ // AFTER the range deletion. This happens when:
+ // - Ordinals are: 100, 200, 201, 202, 300 (dense in middle)
+ // - Delete 200-202 (midpoint = 201, deleted)
+ // - After delete: 100, 300
+ // - Midpoint 201 doesn't exist → no conflict
+ //
+ // Alternative: Use transaction directly to test resequenceContextItemsTx
+
+ // First set up context items
+ items := []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2},
+ {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2},
+ {Ordinal: 300, ItemType: "message", MessageID: msgs[2], TokenCount: 2},
+ {Ordinal: 400, ItemType: "message", MessageID: msgs[3], TokenCount: 2},
+ {Ordinal: 500, ItemType: "message", MessageID: msgs[4], TokenCount: 2},
+ }
+ s.UpsertContextItems(ctx, conv.ConversationID, items)
+
+ // Create a summary
+ summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "summary", TokenCount: 5,
+ })
+
+ // Call resequenceContextItemsTx directly via a transaction
+ tx, err := s.db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatalf("BeginTx: %v", err)
+ }
+ defer tx.Rollback()
+
+ err = s.resequenceContextItemsTx(ctx, tx, conv.ConversationID, summary.SummaryID)
+ if err != nil {
+ t.Fatalf("resequenceContextItemsTx: %v", err)
+ }
+ tx.Commit()
+
+ // Verify ordinals are unique and properly spaced
+ result, _ := s.GetContextItems(ctx, conv.ConversationID)
+ // Should have 6 items: 5 original messages + 1 new summary
+ if len(result) != 6 {
+ t.Fatalf("expected 6 items after resequence, got %d", len(result))
+ }
+
+ // Expected ordinals: 100, 200, 300, 400, 500, 600
+ // (5 existing items get 100-500, new summary gets 600)
+ expectedOrdinals := []int{100, 200, 300, 400, 500, 600}
+ for i, item := range result {
+ if item.Ordinal != expectedOrdinals[i] {
+ t.Errorf("item[%d].Ordinal = %d, want %d", i, item.Ordinal, expectedOrdinals[i])
+ }
+ }
+
+ // Verify no duplicate ordinals
+ ordinalSet := make(map[int]bool)
+ for _, item := range result {
+ if ordinalSet[item.Ordinal] {
+ t.Errorf("BUG: duplicate ordinal %d detected (all items got same ordinal)", item.Ordinal)
+ }
+ ordinalSet[item.Ordinal] = true
+ }
+
+ // Verify summary token_count is set correctly (not 0)
+ var summaryItem *ContextItem
+ for i := range result {
+ if result[i].ItemType == "summary" {
+ summaryItem = &result[i]
+ break
+ }
+ }
+ if summaryItem == nil {
+ t.Fatal("no summary item found after resequence")
+ }
+ if summaryItem.TokenCount != 5 {
+ t.Errorf("summary item TokenCount = %d, want 5 (from summary.TokenCount)", summaryItem.TokenCount)
+ }
+}
+
+func TestStoreGetContextTokenCount(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+ msg, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 0)
+
+ s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msg.ID, TokenCount: 42},
+ })
+
+ count, err := s.GetContextTokenCount(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetContextTokenCount: %v", err)
+ }
+ if count != 42 {
+ t.Errorf("token count = %d, want 42", count)
+ }
+}
+
+func TestStoreGetMaxOrdinal(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // No items yet
+ maxOrd, err := s.GetMaxOrdinal(ctx, conv.ConversationID)
+ if err != nil {
+ t.Fatalf("GetMaxOrdinal (empty): %v", err)
+ }
+ if maxOrd != 0 {
+ t.Errorf("max ordinal (empty) = %d, want 0", maxOrd)
+ }
+
+ // Add items
+ msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "a", 1)
+ msg2, _ := s.AddMessage(ctx, conv.ConversationID, "user", "b", 1)
+ s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{
+ {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 1},
+ {Ordinal: 250, ItemType: "message", MessageID: msg2.ID, TokenCount: 1},
+ })
+
+ maxOrd, _ = s.GetMaxOrdinal(ctx, conv.ConversationID)
+ if maxOrd != 250 {
+ t.Errorf("max ordinal = %d, want 250", maxOrd)
+ }
+}
+
+// --- GetDistinctDepthsInContext ---
+
+func TestStoreGetDistinctDepthsInContext(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Empty context → no depths
+ depths, err := s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0)
+ if err != nil {
+ t.Fatalf("GetDistinctDepthsInContext (empty): %v", err)
+ }
+ if len(depths) != 0 {
+ t.Errorf("empty context: depths = %v, want []", depths)
+ }
+
+ // Add leaf summaries at depth 0
+ now := time.Now().UTC()
+ s1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "leaf1", TokenCount: 10, EarliestAt: &now, LatestAt: &now,
+ })
+ s2, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "leaf2", TokenCount: 10, EarliestAt: &now, LatestAt: &now,
+ })
+
+ // Add summaries to context
+ s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: s1.SummaryID, TokenCount: 10},
+ {Ordinal: 200, ItemType: "summary", SummaryID: s2.SummaryID, TokenCount: 10},
+ })
+
+ // Should find depth 0
+ depths, err = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0)
+ if err != nil {
+ t.Fatalf("GetDistinctDepthsInContext: %v", err)
+ }
+ if len(depths) != 1 || depths[0] != 0 {
+ t.Errorf("depths = %v, want [0]", depths)
+ }
+
+ // Add condensed at depth 1
+ c1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1,
+ Content: "condensed1", TokenCount: 15, ParentIDs: []string{s1.SummaryID, s2.SummaryID},
+ })
+ s.AppendContextSummary(ctx, conv.ConversationID, c1.SummaryID)
+
+ // Should find depths [0, 1] or [1, 0]
+ depths, _ = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0)
+ if len(depths) != 2 {
+ t.Errorf("with condensed: depths = %v, want 2 distinct depths", depths)
+ }
+
+ // Test maxOrdinalExclusive filter
+ // Get depths excluding ordinals >= 300 (the condensed one)
+ depths, _ = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 300)
+ if len(depths) != 1 || depths[0] != 0 {
+ t.Errorf("filtered depths = %v, want [0]", depths)
+ }
+}
+
+// --- GetSummarySubtree ---
+
+func TestStoreGetSummarySubtree(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create leaf summaries
+ now := time.Now().UTC()
+ l1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "leaf1", TokenCount: 10, EarliestAt: &now, LatestAt: &now,
+ })
+ l2, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "leaf2", TokenCount: 10, EarliestAt: &now, LatestAt: &now,
+ })
+ l3, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "leaf3", TokenCount: 10, EarliestAt: &now, LatestAt: &now,
+ })
+
+ // Condense l1+l2 → c1
+ c1, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1,
+ Content: "condensed1", TokenCount: 15, ParentIDs: []string{l1.SummaryID, l2.SummaryID},
+ })
+
+ // Get subtree from c1
+ nodes, err := s.GetSummarySubtree(ctx, c1.SummaryID)
+ if err != nil {
+ t.Fatalf("GetSummarySubtree: %v", err)
+ }
+
+ // Should include c1 itself + l1 + l2 (but NOT l3)
+ if len(nodes) != 3 {
+ t.Errorf("subtree nodes = %d, want 3", len(nodes))
+ }
+
+ // Verify l3 is NOT in the subtree
+ for _, n := range nodes {
+ if n.SummaryID == l3.SummaryID {
+ t.Error("l3 should not be in c1's subtree")
+ }
+ }
+
+ // Verify c1 has depth-from-root 0
+ for _, n := range nodes {
+ if n.SummaryID == c1.SummaryID && n.DepthFromRoot != 0 {
+ t.Errorf("c1 depth-from-root = %d, want 0", n.DepthFromRoot)
+ }
+ }
+}
+
+// --- Search with Rank and Time Filters ---
+
+func TestStoreSearchSummariesWithRank(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create summaries with different content (for FTS matching)
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "machine learning neural network", TokenCount: 10,
+ })
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "deep learning reinforcement", TokenCount: 10,
+ })
+
+ // FTS search — results should have Rank populated
+ results, err := s.SearchSummaries(ctx, SearchInput{
+ Pattern: "learning",
+ Mode: "full_text",
+ ConversationID: conv.ConversationID,
+ })
+ if err != nil {
+ t.Fatalf("SearchSummaries: %v", err)
+ }
+ if len(results) < 1 {
+ t.Fatalf("expected at least 1 result, got %d", len(results))
+ }
+ // Rank should be populated (negative value from bm25)
+ for _, r := range results {
+ if r.Rank == 0 {
+ t.Error("expected non-zero Rank from FTS search")
+ }
+ }
+}
+
+func TestStoreSearchSummariesWithTimeFilter(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create a summary
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0,
+ Content: "important meeting notes", TokenCount: 10,
+ })
+
+ // Search with Since filter (now - 1 hour → should match)
+ since := time.Now().UTC().Add(-1 * time.Hour)
+ results, err := s.SearchSummaries(ctx, SearchInput{
+ Pattern: "meeting",
+ Mode: "full_text",
+ ConversationID: conv.ConversationID,
+ Since: &since,
+ })
+ if err != nil {
+ t.Fatalf("SearchSummaries with Since: %v", err)
+ }
+ if len(results) != 1 {
+ t.Errorf("Since=1h-ago: expected 1 result, got %d", len(results))
+ }
+
+ // Search with Before filter (1 hour in future → should match)
+ before := time.Now().UTC().Add(1 * time.Hour)
+ results, err = s.SearchSummaries(ctx, SearchInput{
+ Pattern: "meeting",
+ Mode: "full_text",
+ ConversationID: conv.ConversationID,
+ Before: &before,
+ })
+ if err != nil {
+ t.Fatalf("SearchSummaries with Before: %v", err)
+ }
+ if len(results) != 1 {
+ t.Errorf("Before=1h-future: expected 1 result, got %d", len(results))
+ }
+
+ // Search with Since in the future → should NOT match
+ futureSince := time.Now().UTC().Add(1 * time.Hour)
+ results, err = s.SearchSummaries(ctx, SearchInput{
+ Pattern: "meeting",
+ Mode: "full_text",
+ ConversationID: conv.ConversationID,
+ Since: &futureSince,
+ })
+ if err != nil {
+ t.Fatalf("SearchSummaries with future Since: %v", err)
+ }
+ if len(results) != 0 {
+ t.Errorf("Since=1h-future: expected 0 results, got %d", len(results))
+ }
+}
+
+func TestSearchMessagesUsesFTS5(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "test:fts5-messages")
+ convID := conv.ConversationID
+
+ // Add messages with searchable content
+ s.AddMessage(ctx, convID, "user", "The quick brown fox jumps over the lazy dog", 10)
+ s.AddMessage(ctx, convID, "assistant", "A response about something else entirely", 10)
+ s.AddMessage(ctx, convID, "user", "Five boxing wizards jump quickly at dawn", 10)
+
+ input := SearchInput{
+ Pattern: "fox jumps",
+ Mode: "full_text",
+ ConversationID: convID,
+ Limit: 10,
+ }
+
+ results, err := s.SearchMessages(ctx, input)
+ if err != nil {
+ t.Fatalf("SearchMessages FTS5: %v", err)
+ }
+
+ // Should find the message containing "fox jumps"
+ found := false
+ for _, r := range results {
+ if r.MessageID > 0 && contains(r.Snippet, "fox") {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("FTS5 search should find message with 'fox jumps'")
+ }
+}
+
+func TestMessagesFTSTriggers(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "test:fts-triggers")
+ convID := conv.ConversationID
+
+ // Insert a message
+ _, err := s.AddMessage(ctx, convID, "user", "database migration completed successfully", 10)
+ if err != nil {
+ t.Fatalf("AddMessage: %v", err)
+ }
+
+ // Verify FTS table was populated by INSERT trigger
+ var count int
+ err = s.db.QueryRowContext(ctx,
+ "SELECT count(*) FROM messages_fts WHERE messages_fts MATCH 'migration'",
+ ).Scan(&count)
+ if err != nil {
+ t.Fatalf("query messages_fts: %v", err)
+ }
+ if count != 1 {
+ t.Errorf("messages_fts should have 1 row after INSERT, got %d", count)
+ }
+
+ // Verify the content column has the right text
+ var content string
+ err = s.db.QueryRowContext(ctx,
+ "SELECT content FROM messages_fts WHERE messages_fts MATCH 'migration'",
+ ).Scan(&content)
+ if err != nil {
+ t.Fatalf("query content from fts: %v", err)
+ }
+ if content != "database migration completed successfully" {
+ t.Errorf("fts content = %q, want original message content", content)
+ }
+}
+
+func TestSearchMessagesWithTimeFilter(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "test:msg-time")
+ convID := conv.ConversationID
+
+ // Add messages
+ s.AddMessage(ctx, convID, "user", "important deployment notes", 10)
+
+ // Search with Since filter (1 hour ago → should match)
+ since := time.Now().UTC().Add(-1 * time.Hour)
+ results, err := s.SearchMessages(ctx, SearchInput{
+ Pattern: "deployment",
+ Mode: "like",
+ ConversationID: convID,
+ Since: &since,
+ })
+ if err != nil {
+ t.Fatalf("SearchMessages with Since: %v", err)
+ }
+ if len(results) != 1 {
+ t.Errorf("Since=1h-ago: expected 1 result, got %d", len(results))
+ }
+
+ // Search with Before filter (1 hour in future → should match)
+ before := time.Now().UTC().Add(1 * time.Hour)
+ results, err = s.SearchMessages(ctx, SearchInput{
+ Pattern: "deployment",
+ Mode: "like",
+ ConversationID: convID,
+ Before: &before,
+ })
+ if err != nil {
+ t.Fatalf("SearchMessages with Before: %v", err)
+ }
+ if len(results) != 1 {
+ t.Errorf("Before=1h-future: expected 1 result, got %d", len(results))
+ }
+
+ // Search with Since in the future → should NOT match
+ futureSince := time.Now().UTC().Add(1 * time.Hour)
+ results, err = s.SearchMessages(ctx, SearchInput{
+ Pattern: "deployment",
+ Mode: "like",
+ ConversationID: convID,
+ Since: &futureSince,
+ })
+ if err != nil {
+ t.Fatalf("SearchMessages with future Since: %v", err)
+ }
+ if len(results) != 0 {
+ t.Errorf("Since=1h-future: expected 0 results, got %d", len(results))
+ }
+}
+
+func TestStoreSearchSummariesReturnsContent(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test")
+
+ // Create a summary with known content
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "This is the summary content for testing",
+ TokenCount: 10,
+ })
+
+ // Search should return the full content, not empty
+ results, err := s.SearchSummaries(ctx, SearchInput{
+ Pattern: "summary content",
+ Mode: "like",
+ ConversationID: conv.ConversationID,
+ })
+ if err != nil {
+ t.Fatalf("SearchSummaries: %v", err)
+ }
+ if len(results) != 1 {
+ t.Fatalf("expected 1 result, got %d", len(results))
+ }
+ if results[0].Content == "" {
+ t.Error("SearchResult.Content is empty, want full summary content")
+ }
+ if results[0].Content != "This is the summary content for testing" {
+ t.Errorf("SearchResult.Content = %q, want %q", results[0].Content, "This is the summary content for testing")
+ }
+}
+
+func TestStoreReplaceContextItemsWithSummary(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+
+ conv, _ := s.GetOrCreateConversation(ctx, "agent:test-replace-items")
+
+ // Create messages
+ msgs := make([]int64, 5)
+ for i := 0; i < 5; i++ {
+ m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2)
+ msgs[i] = m.ID
+ }
+
+ // Create summaries
+ summaries := make([]string, 3)
+ for i := 0; i < 3; i++ {
+ sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: fmt.Sprintf("summary %d", i),
+ TokenCount: 10,
+ })
+ summaries[i] = sum.SummaryID
+ }
+
+ // Insert context items with a message in between summaries:
+ // Ordinals: 100 (summary0), 200 (message), 300 (summary1), 400 (summary2)
+ items := []ContextItem{
+ {Ordinal: 100, ItemType: "summary", SummaryID: summaries[0], TokenCount: 10},
+ {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2},
+ {Ordinal: 300, ItemType: "summary", SummaryID: summaries[1], TokenCount: 10},
+ {Ordinal: 400, ItemType: "summary", SummaryID: summaries[2], TokenCount: 10},
+ }
+ s.UpsertContextItems(ctx, conv.ConversationID, items)
+
+ // Create a new summary to replace with
+ newSummary, _ := s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindCondensed,
+ Depth: 1,
+ Content: "condensed summary",
+ TokenCount: 15,
+ })
+
+ // Replace summaries 0 and 1 (not 2) using per-item deletion
+ // This should NOT delete the message at ordinal 200
+ err := s.ReplaceContextItemsWithSummary(
+ ctx, conv.ConversationID,
+ []string{summaries[0], summaries[1]},
+ newSummary.SummaryID)
+ if err != nil {
+ t.Fatalf("ReplaceContextItemsWithSummary: %v", err)
+ }
+
+ // Verify result: should have 3 items (message at 200, summary2 at 400, new summary)
+ result, _ := s.GetContextItems(ctx, conv.ConversationID)
+ if len(result) != 3 {
+ t.Fatalf("expected 3 items after replace, got %d", len(result))
+ }
+
+ // Verify message at ordinal 200 is preserved
+ messagePreserved := false
+ for _, item := range result {
+ if item.ItemType == "message" && item.MessageID == msgs[1] {
+ messagePreserved = true
+ break
+ }
+ }
+ if !messagePreserved {
+ t.Error("message at ordinal 200 should have been preserved")
+ }
+
+ // Verify summary2 at ordinal 400 is preserved
+ summary2Preserved := false
+ for _, item := range result {
+ if item.ItemType == "summary" && item.SummaryID == summaries[2] {
+ summary2Preserved = true
+ break
+ }
+ }
+ if !summary2Preserved {
+ t.Error("summary2 at ordinal 400 should have been preserved")
+ }
+
+ // Verify new summary exists
+ newSummaryFound := false
+ for _, item := range result {
+ if item.ItemType == "summary" && item.SummaryID == newSummary.SummaryID {
+ newSummaryFound = true
+ break
+ }
+ }
+ if !newSummaryFound {
+ t.Error("new summary should exist")
+ }
+
+ // Verify no duplicate ordinals
+ ordinalSet := make(map[int]bool)
+ for _, item := range result {
+ if ordinalSet[item.Ordinal] {
+ t.Errorf("duplicate ordinal %d detected", item.Ordinal)
+ }
+ ordinalSet[item.Ordinal] = true
+ }
+}
diff --git a/pkg/seahorse/tool_expand.go b/pkg/seahorse/tool_expand.go
new file mode 100644
index 000000000..749c9cd6c
--- /dev/null
+++ b/pkg/seahorse/tool_expand.go
@@ -0,0 +1,129 @@
+package seahorse
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+// ExpandTool recovers full message content by ID.
+type ExpandTool struct {
+ engine *RetrievalEngine
+}
+
+func NewExpandTool(engine *RetrievalEngine) *ExpandTool {
+ return &ExpandTool{engine: engine}
+}
+
+func (t *ExpandTool) Name() string {
+ return "short_expand"
+}
+
+func (t *ExpandTool) Description() string {
+ return `Get full message content by ID.
+
+Use when short_grep returns messages and you need complete content (not just snippet).
+
+Parameters:
+- message_ids (required): Array of message ID strings (from short_grep results)
+
+Returns message with:
+- content: Full text content
+- parts: Structured content
+ - text: Full text
+ - tool_use: name, arguments, toolCallId
+ - tool_result: toolCallId only (content omitted - re-run tool if needed)
+ - media: mediaUri (file path), mimeType
+
+Notes:
+- tool_result content is not returned (can be large). Re-run the tool if you need the result.
+- Media files are stored on disk at mediaUri path, use bash to access.
+
+Example:
+ {"message_ids": ["10", "25"]}`
+}
+
+func (t *ExpandTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "message_ids": map[string]any{
+ "type": "array",
+ "items": map[string]any{"type": "string"},
+ "description": "Message IDs to expand (from short_grep results, e.g., [\"10\", \"25\"])",
+ },
+ },
+ "required": []string{"message_ids"},
+ }
+}
+
+func (t *ExpandTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ idsRaw, ok := args["message_ids"].([]any)
+ if !ok || len(idsRaw) == 0 {
+ return tools.ErrorResult(
+ "Missing required 'message_ids' argument. " +
+ "Example: {\"message_ids\": [\"10\", \"25\"]}")
+ }
+
+ // Parse message IDs
+ messageIDs := make([]int64, 0, len(idsRaw))
+ for _, id := range idsRaw {
+ switch v := id.(type) {
+ case string:
+ var n int64
+ if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
+ return tools.ErrorResult(fmt.Sprintf("Invalid message_id %q: %v", v, err))
+ }
+ messageIDs = append(messageIDs, n)
+ case float64:
+ messageIDs = append(messageIDs, int64(v))
+ }
+ }
+
+ result, err := t.engine.ExpandMessages(ctx, messageIDs)
+ if err != nil {
+ return tools.ErrorResult("Expand failed: " + err.Error())
+ }
+
+ // Build response with filtered parts
+ messages := make([]map[string]any, 0, len(result.Messages))
+ for _, msg := range result.Messages {
+ parts := make([]map[string]any, 0, len(msg.Parts))
+ for _, p := range msg.Parts {
+ part := map[string]any{"type": p.Type}
+ switch p.Type {
+ case "text":
+ part["text"] = p.Text
+ case "tool_use":
+ part["name"] = p.Name
+ part["arguments"] = p.Arguments
+ part["toolCallId"] = p.ToolCallID
+ case "tool_result":
+ // Omit content - can be large, re-run tool if needed
+ part["toolCallId"] = p.ToolCallID
+ case "media":
+ part["mediaUri"] = p.MediaURI
+ part["mimeType"] = p.MimeType
+ }
+ parts = append(parts, part)
+ }
+
+ messages = append(messages, map[string]any{
+ "id": fmt.Sprintf("%d", msg.ID),
+ "role": msg.Role,
+ "content": msg.Content,
+ "parts": parts,
+ "conversationId": msg.ConversationID,
+ })
+ }
+
+ output := map[string]any{
+ "success": true,
+ "tokenCount": result.TokenCount,
+ "messages": messages,
+ }
+ data, _ := json.Marshal(output)
+ return tools.NewToolResult(string(data))
+}
diff --git a/pkg/seahorse/tool_expand_test.go b/pkg/seahorse/tool_expand_test.go
new file mode 100644
index 000000000..fc726a7a0
--- /dev/null
+++ b/pkg/seahorse/tool_expand_test.go
@@ -0,0 +1,136 @@
+package seahorse
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "testing"
+)
+
+func TestExpandToolByMessageIDs(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:expand-tool")
+
+ msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "first message", 10)
+ msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "second message", 10)
+
+ re := &RetrievalEngine{store: s}
+ tool := NewExpandTool(re)
+
+ result := tool.Execute(ctx, map[string]any{
+ "message_ids": []any{fmt.Sprintf("%d", msg1.ID), fmt.Sprintf("%d", msg2.ID)},
+ })
+
+ if result.IsError {
+ t.Fatalf("Expand failed: %s", result.ForLLM)
+ }
+
+ // Parse result
+ var output struct {
+ Success bool `json:"success"`
+ TokenCount int `json:"tokenCount"`
+ Messages []map[string]any `json:"messages"`
+ }
+ if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
+ t.Fatalf("Parse result: %v", err)
+ }
+
+ if !output.Success {
+ t.Error("expected success=true")
+ }
+ if len(output.Messages) != 2 {
+ t.Errorf("Messages = %d, want 2", len(output.Messages))
+ }
+ if output.TokenCount != 20 {
+ t.Errorf("TokenCount = %d, want 20", output.TokenCount)
+ }
+}
+
+func TestExpandToolMissingIDs(t *testing.T) {
+ s := openTestStore(t)
+ re := &RetrievalEngine{store: s}
+ tool := NewExpandTool(re)
+
+ result := tool.Execute(context.Background(), map[string]any{})
+
+ if !result.IsError {
+ t.Error("expected error for missing message_ids")
+ }
+}
+
+func TestExpandToolWithParts(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:expand-parts")
+
+ // Create message with parts
+ parts := []MessagePart{
+ {Type: "text", Text: "Hello"},
+ {Type: "tool_use", Name: "bash", Arguments: `{"command":"ls"}`, ToolCallID: "call_123"},
+ {Type: "tool_result", ToolCallID: "call_123", Text: "file1.txt\nfile2.txt"},
+ }
+ msg, _ := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 50)
+
+ re := &RetrievalEngine{store: s}
+ tool := NewExpandTool(re)
+
+ result := tool.Execute(ctx, map[string]any{
+ "message_ids": []any{fmt.Sprintf("%d", msg.ID)},
+ })
+
+ if result.IsError {
+ t.Fatalf("Expand failed: %s", result.ForLLM)
+ }
+
+ var output struct {
+ Messages []struct {
+ Parts []map[string]any `json:"parts"`
+ } `json:"messages"`
+ }
+ if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
+ t.Fatalf("Parse result: %v", err)
+ }
+
+ if len(output.Messages) != 1 {
+ t.Fatalf("Messages = %d, want 1", len(output.Messages))
+ }
+
+ // Verify parts are filtered correctly
+ foundText := false
+ foundToolUse := false
+ foundToolResult := false
+ for _, p := range output.Messages[0].Parts {
+ switch p["type"].(string) {
+ case "text":
+ foundText = true
+ if p["text"] != "Hello" {
+ t.Errorf("text = %v, want Hello", p["text"])
+ }
+ case "tool_use":
+ foundToolUse = true
+ if p["name"] != "bash" {
+ t.Errorf("name = %v, want bash", p["name"])
+ }
+ case "tool_result":
+ foundToolResult = true
+ // tool_result should NOT have content
+ if _, hasContent := p["content"]; hasContent {
+ t.Error("tool_result should not have content field")
+ }
+ if p["toolCallId"] != "call_123" {
+ t.Errorf("toolCallId = %v, want call_123", p["toolCallId"])
+ }
+ }
+ }
+
+ if !foundText {
+ t.Error("missing text part")
+ }
+ if !foundToolUse {
+ t.Error("missing tool_use part")
+ }
+ if !foundToolResult {
+ t.Error("missing tool_result part")
+ }
+}
diff --git a/pkg/seahorse/tool_grep.go b/pkg/seahorse/tool_grep.go
new file mode 100644
index 000000000..9671d2a7f
--- /dev/null
+++ b/pkg/seahorse/tool_grep.go
@@ -0,0 +1,172 @@
+package seahorse
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+// GrepTool searches summaries and messages for matching content.
+type GrepTool struct {
+ engine *RetrievalEngine
+}
+
+func NewGrepTool(engine *RetrievalEngine) *GrepTool {
+ return &GrepTool{engine: engine}
+}
+
+func (t *GrepTool) Name() string {
+ return "short_grep"
+}
+
+func (t *GrepTool) Description() string {
+ return `Search summaries and messages for matching content.
+
+Pattern syntax:
+- Words: "authentication" - matches content containing this word
+- AND: "auth AND login" - matches content with both words
+- OR: "auth OR signin" - matches content with either word
+- NOT: "bug NOT fixed" - matches "bug" but excludes "fixed"
+- Wildcard: "%auth%" - matches any text containing "auth" (e.g., "auth", "authentication")
+
+Each summary has a "depth" field:
+- depth 0: Created from messages, most detailed
+- depth 1+: Created from other summaries, more compressed but covers longer time
+
+Parameters:
+- pattern (required): Search pattern
+- scope: "both" (default), "summary", or "message" - what to search
+- role: "user", "assistant", or omit for all - filter by message role
+- last: Time shortcut like "6h", "7d", "2w", "1m" (hours/days/weeks/months)
+- all_conversations: Search all conversations (default: current only)
+- since: ISO8601 timestamp, content after this time
+- before: ISO8601 timestamp, content before this time
+- limit: Max results (default: 20)
+
+Returns:
+{
+ "success": true,
+ "summaries": [{"id": "sum_abc", "content": "...", "depth": 0, "kind": "leaf", "conversationId": 1, "rank": -0.5}],
+ "messages": [{"id": "10", "snippet": "...matched...", "role": "user", "conversationId": 1, "rank": -1.2}],
+ "totalSummaries": 5,
+ "totalMessages": 10,
+ "hint": "No matches. Try: %keyword% for fuzzy search"
+}
+
+Rank field (FTS5 mode only): bm25 relevance score, negative value where more negative = higher relevance.
+Examples: -5=excellent, -2=good, -0.5=partial. LIKE mode (%pattern%) has no rank.
+
+Examples:
+ {"pattern": "authentication"}
+ {"pattern": "bug AND login"}
+ {"pattern": "%snake%"}
+ {"pattern": "project", "scope": "summary"}
+ {"pattern": "error", "role": "assistant", "last": "7d"}
+ {"pattern": "error", "all_conversations": true}`
+}
+
+func (t *GrepTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "pattern": map[string]any{
+ "type": "string",
+ "description": "Search pattern. Supports: words, AND/OR/NOT operators, % wildcard",
+ },
+ "scope": map[string]any{
+ "type": "string",
+ "enum": []string{"both", "summary", "message"},
+ "description": "What to search: 'both' (default), 'summary', or 'message'",
+ },
+ "role": map[string]any{
+ "type": "string",
+ "enum": []string{"user", "assistant"},
+ "description": "Filter by message role (default: all roles)",
+ },
+ "last": map[string]any{
+ "type": "string",
+ "description": "Time shortcut: '6h' (6 hours), '7d' (7 days), '2w' (2 weeks), '1m' (1 month)",
+ },
+ "all_conversations": map[string]any{
+ "type": "boolean",
+ "description": "Search across all conversations (default: searches current conversation only)",
+ },
+ "since": map[string]any{
+ "type": "string",
+ "description": "ISO8601 timestamp, only return content after this time",
+ },
+ "before": map[string]any{
+ "type": "string",
+ "description": "ISO8601 timestamp, only return content before this time",
+ },
+ "limit": map[string]any{
+ "type": "integer",
+ "description": "Maximum number of results (default: 20)",
+ },
+ },
+ "required": []string{"pattern"},
+ }
+}
+
+func (t *GrepTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ pattern, ok := args["pattern"].(string)
+ if !ok || pattern == "" {
+ return tools.ErrorResult("Missing required 'pattern' argument. Example: {\"pattern\": \"authentication\"}")
+ }
+
+ input := GrepInput{Pattern: pattern}
+
+ if scope, ok := args["scope"].(string); ok && scope != "" {
+ input.Scope = scope
+ }
+ if role, ok := args["role"].(string); ok && role != "" {
+ input.Role = role
+ }
+ if last, ok := args["last"].(string); ok && last != "" {
+ input.Last = last
+ }
+ if allConv, ok := args["all_conversations"].(bool); ok {
+ input.AllConversations = allConv
+ }
+ if limit, ok := args["limit"].(float64); ok {
+ input.Limit = int(limit)
+ }
+ if sinceStr, ok := args["since"].(string); ok && sinceStr != "" {
+ parsed, err := time.Parse(time.RFC3339, sinceStr)
+ if err != nil {
+ return tools.ErrorResult(fmt.Sprintf(
+ "Invalid 'since' timestamp. Use RFC3339 format like '2024-01-15T10:00:00Z'. Error: %v", err))
+ }
+ input.Since = &parsed
+ }
+ if beforeStr, ok := args["before"].(string); ok && beforeStr != "" {
+ parsed, err := time.Parse(time.RFC3339, beforeStr)
+ if err != nil {
+ return tools.ErrorResult(fmt.Sprintf("Invalid 'before' timestamp format: %v", err))
+ }
+ input.Before = &parsed
+ }
+
+ result, err := t.engine.Grep(ctx, input)
+ if err != nil {
+ return tools.ErrorResult("Grep failed: " + err.Error())
+ }
+
+ // Build response
+ output := map[string]any{
+ "success": result.Success,
+ "summaries": result.Summaries,
+ "messages": result.Messages,
+ }
+
+ // Add hint if provided
+ if result.Hint != "" {
+ output["hint"] = result.Hint
+ }
+
+ data, _ := json.Marshal(output)
+ return tools.NewToolResult(string(data))
+}
diff --git a/pkg/seahorse/tool_grep_test.go b/pkg/seahorse/tool_grep_test.go
new file mode 100644
index 000000000..050d9deeb
--- /dev/null
+++ b/pkg/seahorse/tool_grep_test.go
@@ -0,0 +1,72 @@
+package seahorse
+
+import (
+ "context"
+ "testing"
+)
+
+func TestGrepSearchSummaries(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:grep-tool")
+
+ s.CreateSummary(ctx, CreateSummaryInput{
+ ConversationID: conv.ConversationID,
+ Kind: SummaryKindLeaf,
+ Depth: 0,
+ Content: "database connection pool configuration",
+ TokenCount: 50,
+ })
+
+ re := &RetrievalEngine{store: s}
+ results, err := re.Grep(ctx, GrepInput{
+ Pattern: "database",
+ })
+ if err != nil {
+ t.Fatalf("Grep: %v", err)
+ }
+ if len(results.Summaries) == 0 {
+ t.Error("expected at least 1 summary result")
+ }
+}
+
+func TestGrepSearchMessages(t *testing.T) {
+ s := openTestStore(t)
+ ctx := context.Background()
+ conv, _ := s.GetOrCreateConversation(ctx, "test:grep-msg")
+
+ s.AddMessage(ctx, conv.ConversationID, "user", "find this message about testing", 5)
+ s.AddMessage(ctx, conv.ConversationID, "user", "unrelated content", 3)
+
+ re := &RetrievalEngine{store: s}
+ results, err := re.Grep(ctx, GrepInput{
+ Pattern: "testing",
+ })
+ if err != nil {
+ t.Fatalf("Grep messages: %v", err)
+ }
+ if len(results.Messages) == 0 {
+ t.Error("expected at least 1 message result")
+ }
+}
+
+func TestGrepMissingPattern(t *testing.T) {
+ s := openTestStore(t)
+ re := &RetrievalEngine{store: s}
+ _, err := re.Grep(context.Background(), GrepInput{})
+ if err == nil {
+ t.Error("expected error for missing pattern")
+ }
+}
+
+func TestGrepToolSupportsAllConversations(t *testing.T) {
+ s := openTestStore(t)
+ tool := NewGrepTool(&RetrievalEngine{store: s})
+ params := tool.Parameters()
+ props := params["properties"].(map[string]any)
+
+ // GrepTool should accept all_conversations parameter
+ if _, ok := props["all_conversations"]; !ok {
+ t.Error("Parameters missing 'all_conversations' field")
+ }
+}
diff --git a/pkg/seahorse/types.go b/pkg/seahorse/types.go
new file mode 100644
index 000000000..2bc7f931f
--- /dev/null
+++ b/pkg/seahorse/types.go
@@ -0,0 +1,161 @@
+package seahorse
+
+import (
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tokenizer"
+)
+
+// SummaryKind distinguishes leaf summaries (from raw messages) vs condensed
+// summaries (from other summaries).
+type SummaryKind string
+
+const (
+ SummaryKindLeaf SummaryKind = "leaf"
+ SummaryKindCondensed SummaryKind = "condensed"
+)
+
+// Message represents a single chat message with role and content.
+type Message struct {
+ ID int64 `json:"id"`
+ ConversationID int64 `json:"conversationId"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoningContent,omitempty"`
+ TokenCount int `json:"tokenCount"`
+ CreatedAt time.Time `json:"createdAt"`
+ Parts []MessagePart `json:"parts,omitempty"`
+}
+
+// MessagePart holds structured content (tool calls, media, etc.)
+type MessagePart struct {
+ ID int64 `json:"id"`
+ MessageID int64 `json:"messageId"`
+ Type string `json:"type"` // "text", "tool_use", "tool_result", "media"
+ Text string `json:"text"`
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ ToolCallID string `json:"toolCallId"`
+ MediaURI string `json:"mediaUri"`
+ MimeType string `json:"mimeType"`
+}
+
+// Summary represents a compressed representation of messages or other summaries.
+type Summary struct {
+ SummaryID string `json:"summaryId"`
+ ConversationID int64 `json:"conversationId"`
+ Kind SummaryKind `json:"kind"`
+ Depth int `json:"depth"`
+ Content string `json:"content"`
+ TokenCount int `json:"tokenCount"`
+ EarliestAt *time.Time `json:"earliestAt,omitempty"`
+ LatestAt *time.Time `json:"latestAt,omitempty"`
+ DescendantCount int `json:"descendantCount"`
+ DescendantTokenCount int `json:"descendantTokenCount"`
+ SourceMessageTokenCount int `json:"sourceMessageTokenCount"`
+ Model string `json:"model"`
+ CreatedAt time.Time `json:"createdAt"`
+}
+
+// SummaryNode is a Summary with graph relationships for tree traversal.
+type SummaryNode struct {
+ Summary
+ Children []string `json:"children"` // Child summary IDs
+ Expanded bool `json:"expanded"` // UI state for expansion
+}
+
+// Conversation represents a session's conversation with metadata.
+type Conversation struct {
+ ConversationID int64 `json:"conversationId"`
+ SessionKey string `json:"sessionKey"`
+ CreatedAt time.Time `json:"createdAt"`
+ UpdatedAt time.Time `json:"updatedAt"`
+}
+
+// SessionStatus contains status information for a session.
+type SessionStatus struct {
+ SessionKey string `json:"sessionKey"`
+ ConversationID int64 `json:"conversationId"`
+ Messages int `json:"messages"`
+ TotalTokens int `json:"totalTokens"`
+ Summaries int `json:"summaries"`
+ OldestAt time.Time `json:"oldestAt"`
+ NewestAt time.Time `json:"newestAt"`
+}
+
+// ContextItem represents one item in the assembled context window.
+type ContextItem struct {
+ ConversationID int64 `json:"conversationId"`
+ Ordinal int `json:"ordinal"`
+ ItemType string `json:"itemType"` // "summary" or "message"
+ SummaryID string `json:"summaryId,omitempty"`
+ MessageID int64 `json:"messageId,omitempty"`
+ TokenCount int `json:"tokenCount"`
+ CreatedAt time.Time `json:"createdAt"`
+}
+
+// SummarySubtreeNode is a node in a summary DAG subtree.
+type SummarySubtreeNode struct {
+ SummaryID string `json:"summaryId"`
+ DepthFromRoot int `json:"depthFromRoot"`
+}
+
+// SearchInput controls summary search.
+type SearchInput struct {
+ Pattern string `json:"pattern"`
+ Mode string `json:"mode"` // "like" (LIKE search) or "full_text" (FTS5, default)
+ Scope string `json:"scope,omitempty"` // "messages", "summaries", "both"
+ Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
+ Since *time.Time `json:"since,omitempty"`
+ Before *time.Time `json:"before,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ ConversationID int64 `json:"conversationId,omitempty"`
+ AllConversations bool `json:"allConversations,omitempty"`
+}
+
+// SearchResult is a search match.
+type SearchResult struct {
+ SummaryID string `json:"summaryId,omitempty"`
+ MessageID int64 `json:"messageId,omitempty"`
+ ConversationID int64 `json:"conversationId"`
+ Kind SummaryKind `json:"kind,omitempty"`
+ Depth int `json:"depth,omitempty"`
+ Role string `json:"role,omitempty"`
+ Content string `json:"content,omitempty"` // Full content for summaries
+ Snippet string `json:"snippet"`
+ CreatedAt time.Time `json:"createdAt"`
+ Rank float64 `json:"rank,omitempty"`
+ TotalCount int `json:"totalCount,omitempty"` // Total matching rows (from window function)
+}
+
+// EstimateMessageTokens estimates token count for a full message using the
+// shared tokenizer package for consistency with agent.context_budget.
+func EstimateMessageTokens(msg Message) int {
+ pm := providers.Message{
+ Role: msg.Role,
+ Content: msg.Content,
+ ReasoningContent: msg.ReasoningContent,
+ }
+
+ // Convert MessageParts to ToolCalls / ToolCallID / Media
+ for _, part := range msg.Parts {
+ switch part.Type {
+ case "tool_use":
+ pm.ToolCalls = append(pm.ToolCalls, providers.ToolCall{
+ ID: part.ToolCallID,
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: part.Name,
+ Arguments: part.Arguments,
+ },
+ })
+ case "tool_result":
+ pm.ToolCallID = part.ToolCallID
+ case "media":
+ pm.Media = append(pm.Media, part.MediaURI)
+ }
+ }
+
+ return tokenizer.EstimateMessageTokens(pm)
+}
diff --git a/pkg/seahorse/types_test.go b/pkg/seahorse/types_test.go
new file mode 100644
index 000000000..b7467005f
--- /dev/null
+++ b/pkg/seahorse/types_test.go
@@ -0,0 +1,54 @@
+package seahorse
+
+import (
+ "testing"
+)
+
+func TestSummaryKindValues(t *testing.T) {
+ if SummaryKindLeaf != "leaf" {
+ t.Errorf("expected SummaryKindLeaf = 'leaf', got %q", SummaryKindLeaf)
+ }
+ if SummaryKindCondensed != "condensed" {
+ t.Errorf("expected SummaryKindCondensed = 'condensed', got %q", SummaryKindCondensed)
+ }
+}
+
+func TestConstants(t *testing.T) {
+ // Ordinal gap step
+ if OrdinalStep != 100 {
+ t.Errorf("expected OrdinalStep = 100, got %d", OrdinalStep)
+ }
+
+ // Compaction triggers
+ if ContextThreshold != 0.75 {
+ t.Errorf("expected ContextThreshold = 0.75, got %f", ContextThreshold)
+ }
+ if FreshTailCount != 32 {
+ t.Errorf("expected FreshTailCount = 32, got %d", FreshTailCount)
+ }
+
+ // Fanout
+ if LeafMinFanout != 8 {
+ t.Errorf("expected LeafMinFanout = 8, got %d", LeafMinFanout)
+ }
+ if CondensedMinFanout != 4 {
+ t.Errorf("expected CondensedMinFanout = 4, got %d", CondensedMinFanout)
+ }
+ if CondensedMinFanoutHard != 2 {
+ t.Errorf("expected CondensedMinFanoutHard = 2, got %d", CondensedMinFanoutHard)
+ }
+
+ // Token targets
+ if LeafChunkTokens != 20000 {
+ t.Errorf("expected LeafChunkTokens = 20000, got %d", LeafChunkTokens)
+ }
+ if LeafTargetTokens != 1200 {
+ t.Errorf("expected LeafTargetTokens = 1200, got %d", LeafTargetTokens)
+ }
+ if CondensedTargetTokens != 2000 {
+ t.Errorf("expected CondensedTargetTokens = 2000, got %d", CondensedTargetTokens)
+ }
+ if MaxExpandTokens != 4000 {
+ t.Errorf("expected MaxExpandTokens = 4000, got %d", MaxExpandTokens)
+ }
+}
diff --git a/pkg/session/jsonl_backend.go b/pkg/session/jsonl_backend.go
index caa18a624..06044b618 100644
--- a/pkg/session/jsonl_backend.go
+++ b/pkg/session/jsonl_backend.go
@@ -222,3 +222,8 @@ func (b *JSONLBackend) Save(key string) error {
func (b *JSONLBackend) Close() error {
return b.store.Close()
}
+
+// ListSessions returns all known session keys.
+func (b *JSONLBackend) ListSessions() []string {
+ return b.store.ListSessions()
+}
diff --git a/pkg/session/manager.go b/pkg/session/manager.go
index ef720b7c5..7f87d460a 100644
--- a/pkg/session/manager.go
+++ b/pkg/session/manager.go
@@ -145,6 +145,16 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now()
}
+func (sm *SessionManager) ListSessions() []string {
+ sm.mu.RLock()
+ defer sm.mu.RUnlock()
+ keys := make([]string, 0, len(sm.sessions))
+ for k := range sm.sessions {
+ keys = append(keys, k)
+ }
+ return keys
+}
+
// sanitizeFilename converts a session key into a cross-platform safe filename.
// Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so
// composite IDs (e.g. Telegram forum "chatID/threadID") do not create
diff --git a/pkg/session/session_store.go b/pkg/session/session_store.go
index 1d1a2f967..2ba2a974d 100644
--- a/pkg/session/session_store.go
+++ b/pkg/session/session_store.go
@@ -27,6 +27,8 @@ type SessionStore interface {
TruncateHistory(key string, keepLast int)
// Save persists any pending state to durable storage.
Save(key string) error
+ // ListSessions returns all known session keys.
+ ListSessions() []string
// Close releases resources held by the store.
Close() error
}
diff --git a/pkg/tokenizer/estimator.go b/pkg/tokenizer/estimator.go
new file mode 100644
index 000000000..3265edaa8
--- /dev/null
+++ b/pkg/tokenizer/estimator.go
@@ -0,0 +1,91 @@
+package tokenizer
+
+import (
+ "encoding/json"
+ "unicode/utf8"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// EstimateMessageTokens estimates the token count for a single message,
+// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
+// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
+func EstimateMessageTokens(msg providers.Message) int {
+ contentChars := utf8.RuneCountInString(msg.Content)
+
+ // SystemParts are structured system blocks used for cache-aware adapters.
+ // They carry the same content as Content, but in multiple blocks.
+ // We estimate them as an alternative representation, not additive.
+ systemPartsChars := 0
+ if len(msg.SystemParts) > 0 {
+ for _, part := range msg.SystemParts {
+ systemPartsChars += utf8.RuneCountInString(part.Text)
+ }
+ // Per-part overhead for JSON structure (type, text, cache_control).
+ const perPartOverhead = 20
+ systemPartsChars += len(msg.SystemParts) * perPartOverhead
+ }
+
+ // Use the larger of the two representations to stay conservative.
+ chars := contentChars
+ if systemPartsChars > chars {
+ chars = systemPartsChars
+ }
+
+ chars += utf8.RuneCountInString(msg.ReasoningContent)
+
+ for _, tc := range msg.ToolCalls {
+ chars += len(tc.ID) + len(tc.Type)
+ if tc.Function != nil {
+ // Count function name + arguments (the wire format for most providers).
+ // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
+ chars += len(tc.Function.Name) + len(tc.Function.Arguments)
+ } else {
+ // Fallback: some provider formats use top-level Name without Function.
+ chars += len(tc.Name)
+ }
+ }
+
+ if msg.ToolCallID != "" {
+ chars += len(msg.ToolCallID)
+ }
+
+ // Per-message overhead for role label, JSON structure, separators.
+ const messageOverhead = 12
+ chars += messageOverhead
+
+ tokens := chars * 2 / 5
+
+ // Media items (images, files) are serialized by provider adapters into
+ // multipart or image_url payloads. Add a fixed per-item token estimate
+ // directly (not through the chars heuristic) since actual cost depends
+ // on resolution and provider-specific image tokenization.
+ const mediaTokensPerItem = 256
+ tokens += len(msg.Media) * mediaTokensPerItem
+
+ return tokens
+}
+
+// EstimateToolDefsTokens estimates the total token cost of tool definitions
+// as they appear in the LLM request.
+func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
+ if len(defs) == 0 {
+ return 0
+ }
+
+ totalChars := 0
+ for _, d := range defs {
+ totalChars += len(d.Function.Name) + len(d.Function.Description)
+
+ if d.Function.Parameters != nil {
+ if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
+ totalChars += len(paramJSON)
+ }
+ }
+
+ // Per-tool overhead: type field, JSON structure, separators.
+ totalChars += 20
+ }
+
+ return totalChars * 2 / 5
+}
diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go
index d5bebf4a2..09d1f545b 100644
--- a/pkg/tools/edit.go
+++ b/pkg/tools/edit.go
@@ -29,7 +29,7 @@ func (t *EditFileTool) Name() string {
}
func (t *EditFileTool) Description() string {
- return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
+ return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n."
}
func (t *EditFileTool) Parameters() map[string]any {
@@ -42,11 +42,11 @@ func (t *EditFileTool) Parameters() map[string]any {
},
"old_text": map[string]any{
"type": "string",
- "description": "The exact text to find and replace",
+ "description": "The exact text to find and replace. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
},
"new_text": map[string]any{
"type": "string",
- "description": "The text to replace with",
+ "description": "The text to replace with. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
},
},
"required": []string{"path", "old_text", "new_text"},
@@ -92,7 +92,7 @@ func (t *AppendFileTool) Name() string {
}
func (t *AppendFileTool) Description() string {
- return "Append content to the end of a file"
+ return "Append content to the end of a file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n."
}
func (t *AppendFileTool) Parameters() map[string]any {
@@ -105,7 +105,7 @@ func (t *AppendFileTool) Parameters() map[string]any {
},
"content": map[string]any{
"type": "string",
- "description": "The content to append",
+ "description": "The content to append. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
},
},
"required": []string{"path", "content"},
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go
index 39d45013d..52d77f665 100644
--- a/pkg/tools/filesystem.go
+++ b/pkg/tools/filesystem.go
@@ -1,18 +1,22 @@
package tools
import (
+ "bufio"
+ "bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"math"
+ "net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
+ "unicode/utf8"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -20,7 +24,11 @@ import (
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
-func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
+func validatePathWithAllowPaths(
+ path, workspace string,
+ restrict bool,
+ patterns []*regexp.Regexp,
+) (string, error) {
if workspace == "" {
return path, fmt.Errorf("workspace is not defined")
}
@@ -253,6 +261,11 @@ type ReadFileTool struct {
maxSize int64
}
+type ReadFileLinesTool struct {
+ fs fileSystem
+ maxSize int64
+}
+
func NewReadFileTool(
workspace string,
restrict bool,
@@ -275,14 +288,53 @@ func NewReadFileTool(
}
}
+func NewReadFileBytesTool(
+ workspace string,
+ restrict bool,
+ maxReadFileSize int,
+ allowPaths ...[]*regexp.Regexp,
+) *ReadFileTool {
+ return NewReadFileTool(workspace, restrict, maxReadFileSize, allowPaths...)
+}
+
+func NewReadFileLinesTool(
+ workspace string,
+ restrict bool,
+ maxReadFileSize int,
+ allowPaths ...[]*regexp.Regexp,
+) *ReadFileLinesTool {
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
+ }
+
+ maxSize := int64(maxReadFileSize)
+ if maxSize <= 0 {
+ maxSize = MaxReadFileSize
+ }
+
+ return &ReadFileLinesTool{
+ fs: buildFs(workspace, restrict, patterns),
+ maxSize: maxSize,
+ }
+}
+
func (t *ReadFileTool) Name() string {
return "read_file"
}
+func (t *ReadFileLinesTool) Name() string {
+ return "read_file"
+}
+
func (t *ReadFileTool) Description() string {
return "Read the contents of a file. Supports pagination via `offset` and `length`."
}
+func (t *ReadFileLinesTool) Description() string {
+ return "Read a UTF-8 text file from the filesystem. Output always includes line numbers in the format `LINE_NUMBER|LINE_CONTENT` (1-indexed). Supports partial reads via `start_line` and `max_lines` for large text files."
+}
+
func (t *ReadFileTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
@@ -306,6 +358,28 @@ func (t *ReadFileTool) Parameters() map[string]any {
}
}
+func (t *ReadFileLinesTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "path": map[string]any{
+ "type": "string",
+ "description": "Path to the file to read.",
+ },
+ "start_line": map[string]any{
+ "type": "integer",
+ "description": "Line number to start reading from (1-indexed, inclusive).",
+ "default": 1,
+ },
+ "max_lines": map[string]any{
+ "type": "integer",
+ "description": "Maximum number of lines to read.",
+ },
+ },
+ "required": []string{"path"},
+ }
+}
+
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
@@ -447,6 +521,302 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return NewToolResult(header + "\n\n" + string(data))
}
+func (t *ReadFileLinesTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ path, ok := args["path"].(string)
+ if !ok {
+ return ErrorResult("path is required")
+ }
+
+ startLine, err := getInt64Arg(args, "start_line", 1)
+ if err != nil {
+ return ErrorResult(err.Error())
+ }
+ if startLine < 1 {
+ return ErrorResult("start_line must be >= 1")
+ }
+ if _, exists := args["offset"]; exists {
+ return ErrorResult("offset is not supported in line mode; use start_line")
+ }
+ if _, exists := args["length"]; exists {
+ return ErrorResult("length is not supported in line mode; use max_lines")
+ }
+ if _, exists := args["limit"]; exists {
+ return ErrorResult("limit is not supported in line mode; use max_lines")
+ }
+
+ limit := int64(-1)
+ if raw, exists := args["max_lines"]; exists && raw != nil {
+ limit, err = getInt64Arg(args, "max_lines", -1)
+ if err != nil {
+ return ErrorResult(err.Error())
+ }
+ if limit <= 0 {
+ return ErrorResult("max_lines, if provided, must be > 0")
+ }
+ }
+
+ file, err := t.fs.Open(path)
+ if err != nil {
+ return ErrorResult(err.Error())
+ }
+ defer file.Close()
+
+ if info, statErr := file.Stat(); statErr == nil && info.IsDir() {
+ return ErrorResult(fmt.Sprintf("failed to open file: path is a directory: %s", path))
+ }
+
+ sample := make([]byte, 512)
+ sampleN, readErr := file.Read(sample)
+ if readErr != nil && readErr != io.EOF {
+ return ErrorResult(fmt.Sprintf("failed to read file: %v", readErr))
+ }
+ sample = sample[:sampleN]
+ if isBinaryReadFileData(sample) {
+ return ErrorResult("file appears to be binary; switch read_file mode to 'bytes' for byte-based inspection")
+ }
+
+ reader := bufio.NewReaderSize(io.MultiReader(bytes.NewReader(sample), file), 32*1024)
+
+ var content strings.Builder
+ lineIndex := int64(1)
+ var linesRead int64
+ var fileBytesRead int64
+ var outputBytesRead int64
+ var reachedEOF bool
+ var byteBudgetTruncated bool
+ var lineTruncated bool
+
+ for lineIndex < startLine {
+ hasLine, consumeErr := consumeNextLine(reader)
+ if consumeErr != nil {
+ return ErrorResult(fmt.Sprintf("failed to read file content: %v", consumeErr))
+ }
+ if !hasLine {
+ reachedEOF = true
+ break
+ }
+ lineIndex++
+ }
+
+ for !reachedEOF && (limit < 0 || linesRead < limit) {
+ prefix := formatReadFileLinePrefix(lineIndex)
+ remaining := t.maxSize - outputBytesRead - int64(len(prefix))
+ if remaining <= 0 {
+ byteBudgetTruncated = true
+ break
+ }
+
+ line, complete, hasLine, readLineErr := readNextLinePrefix(reader, remaining)
+ if readLineErr != nil {
+ return ErrorResult(fmt.Sprintf("failed to read file content: %v", readLineErr))
+ }
+ if !hasLine {
+ reachedEOF = true
+ break
+ }
+
+ content.WriteString(prefix)
+ content.Write(line)
+ fileBytesRead += int64(len(line))
+ outputBytesRead += int64(len(prefix) + len(line))
+ linesRead++
+ lineIndex++
+
+ if !complete {
+ byteBudgetTruncated = true
+ lineTruncated = true
+ break
+ }
+ }
+
+ if !reachedEOF && !lineTruncated {
+ hasMoreContent, peekErr := readerHasMoreContent(reader)
+ if peekErr != nil {
+ return ErrorResult(fmt.Sprintf("failed to inspect remaining file content: %v", peekErr))
+ }
+ if !hasMoreContent {
+ reachedEOF = true
+ byteBudgetTruncated = false
+ }
+ }
+
+ if linesRead == 0 && content.Len() == 0 {
+ return NewToolResult(fmt.Sprintf("[END OF FILE - no content at or after start_line=%d]", startLine))
+ }
+
+ start := startLine
+ endLine := startLine + linesRead - 1
+ displayPath := filepath.Base(path)
+ header := fmt.Sprintf(
+ "[file: %s | read: lines %d-%d (1-indexed) | file_bytes: %d | output_bytes: %d]",
+ displayPath, start, endLine, fileBytesRead, outputBytesRead,
+ )
+
+ switch {
+ case lineTruncated:
+ header += fmt.Sprintf(
+ "\n[TRUNCATED - line %d exceeded the %d byte read budget and was cut mid-line.]",
+ endLine,
+ t.maxSize,
+ )
+ case byteBudgetTruncated:
+ if limit > 0 {
+ header += fmt.Sprintf(
+ "\n[TRUNCATED - byte budget reached. Call read_file again with start_line=%d and max_lines=%d to continue at the next line.]",
+ startLine+linesRead,
+ limit,
+ )
+ } else {
+ header += fmt.Sprintf(
+ "\n[TRUNCATED - byte budget reached. Call read_file again with start_line=%d to continue at the next line.]",
+ startLine+linesRead,
+ )
+ }
+ case !reachedEOF && limit > 0 && linesRead >= limit:
+ header += fmt.Sprintf(
+ "\n[PARTIAL - more content remains. Call read_file again with start_line=%d and max_lines=%d to continue.]",
+ startLine+linesRead,
+ limit,
+ )
+ default:
+ header += "\n[END OF FILE - no further content.]"
+ }
+
+ logger.DebugCF("tool", "ReadFileTool execution completed successfully",
+ map[string]any{
+ "path": path,
+ "lines_read": linesRead,
+ "file_bytes_read": fileBytesRead,
+ "output_bytes_read": outputBytesRead,
+ "truncated": byteBudgetTruncated,
+ "tool": t.Name(),
+ })
+
+ return NewToolResult(header + "\n\n" + content.String())
+}
+
+func formatReadFileLinePrefix(lineNumber int64) string {
+ return strconv.FormatInt(lineNumber, 10) + "|"
+}
+
+func isBinaryReadFileData(data []byte) bool {
+ if len(data) == 0 {
+ return false
+ }
+
+ sample := data
+ if len(sample) > 512 {
+ sample = sample[:512]
+ }
+
+ if bytes.IndexByte(sample, 0) >= 0 {
+ return true
+ }
+
+ contentType := http.DetectContentType(sample)
+ if strings.HasPrefix(contentType, "text/") {
+ return false
+ }
+ if strings.HasSuffix(contentType, "/json") ||
+ strings.HasSuffix(contentType, "+json") ||
+ strings.HasSuffix(contentType, "/xml") ||
+ strings.HasSuffix(contentType, "+xml") ||
+ strings.Contains(contentType, "javascript") {
+ return false
+ }
+
+ if !utf8.Valid(sample) {
+ return true
+ }
+
+ controlChars := 0
+ for _, b := range sample {
+ if b < 0x20 && b != '\n' && b != '\r' && b != '\t' && b != '\f' && b != '\b' {
+ controlChars++
+ }
+ }
+
+ return float64(controlChars)/float64(len(sample)) > 0.1
+}
+
+func consumeNextLine(reader *bufio.Reader) (bool, error) {
+ sawData := false
+
+ for {
+ fragment, err := reader.ReadSlice('\n')
+ if len(fragment) > 0 {
+ sawData = true
+ }
+
+ switch {
+ case err == nil:
+ return true, nil
+ case errors.Is(err, bufio.ErrBufferFull):
+ continue
+ case errors.Is(err, io.EOF):
+ return sawData, nil
+ default:
+ return false, err
+ }
+ }
+}
+
+func readNextLinePrefix(reader *bufio.Reader, maxBytes int64) ([]byte, bool, bool, error) {
+ if maxBytes <= 0 {
+ return nil, false, false, nil
+ }
+
+ var out bytes.Buffer
+ sawData := false
+ complete := true
+
+ for {
+ fragment, err := reader.ReadSlice('\n')
+ if len(fragment) > 0 {
+ sawData = true
+ if remaining := maxBytes - int64(out.Len()); remaining > 0 {
+ take := len(fragment)
+ if int64(take) > remaining {
+ take = int(remaining)
+ complete = false
+ }
+ out.Write(fragment[:take])
+ } else {
+ complete = false
+ }
+ }
+
+ switch {
+ case err == nil:
+ return out.Bytes(), complete, sawData, nil
+ case errors.Is(err, bufio.ErrBufferFull):
+ if !complete {
+ return out.Bytes(), false, true, nil
+ }
+ continue
+ case errors.Is(err, io.EOF):
+ if !sawData {
+ return nil, true, false, nil
+ }
+ return out.Bytes(), complete, true, nil
+ default:
+ return nil, false, false, err
+ }
+ }
+}
+
+func readerHasMoreContent(reader *bufio.Reader) (bool, error) {
+ _, err := reader.Peek(1)
+ switch {
+ case err == nil:
+ return true, nil
+ case errors.Is(err, io.EOF):
+ return false, nil
+ default:
+ return false, err
+ }
+}
+
// getInt64Arg extracts an integer argument from the args map, returning the
// provided default if the key is absent.
func getInt64Arg(args map[string]any, key string, defaultVal int64) (int64, error) {
@@ -483,7 +853,11 @@ type WriteFileTool struct {
fs fileSystem
}
-func NewWriteFileTool(workspace string, restrict bool, allowPaths ...[]*regexp.Regexp) *WriteFileTool {
+func NewWriteFileTool(
+ workspace string,
+ restrict bool,
+ allowPaths ...[]*regexp.Regexp,
+) *WriteFileTool {
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
@@ -496,7 +870,7 @@ func (t *WriteFileTool) Name() string {
}
func (t *WriteFileTool) Description() string {
- return "Write content to a file. If the file already exists, you must set overwrite=true to replace it."
+ return "Write content to a file. In `function.arguments`, use \\n for a newline and \\\\n for a literal backslash-n sequence. Content is written byte-for-byte after argument decoding. If the file already exists, you must set overwrite=true to replace it."
}
func (t *WriteFileTool) Parameters() map[string]any {
@@ -509,7 +883,7 @@ func (t *WriteFileTool) Parameters() map[string]any {
},
"content": map[string]any{
"type": "string",
- "description": "Content to write to the file",
+ "description": "Content to write to the file. In `function.arguments`, use \\n for newline and \\\\n for literal backslash-n.",
},
"overwrite": map[string]any{
"type": "boolean",
@@ -536,7 +910,9 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolR
if !overwrite {
if _, err := t.fs.Open(path); err == nil {
- return ErrorResult(fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path))
+ return ErrorResult(
+ fmt.Sprintf("file: %s already exists. Set overwrite=true to replace.", path),
+ )
}
}
diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go
index 0b4dd310b..0ab37c215 100644
--- a/pkg/tools/filesystem_test.go
+++ b/pkg/tools/filesystem_test.go
@@ -18,7 +18,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0o644)
- tool := NewReadFileTool("", false, MaxReadFileSize)
+ tool := NewReadFileBytesTool("", false, MaxReadFileSize)
ctx := context.Background()
args := map[string]any{
"path": testFile,
@@ -45,7 +45,7 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
- tool := NewReadFileTool("", false, MaxReadFileSize)
+ tool := NewReadFileBytesTool("", false, MaxReadFileSize)
ctx := context.Background()
args := map[string]any{
"path": "/nonexistent_file_12345.txt",
@@ -59,8 +59,13 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
}
// Should contain error message
- if !strings.Contains(result.ForLLM, "failed to open file") && !strings.Contains(result.ForUser, "failed to read") {
- t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
+ if !strings.Contains(result.ForLLM, "failed to open file") &&
+ !strings.Contains(result.ForUser, "failed to open") {
+ t.Errorf(
+ "Expected error message, got ForLLM: %s, ForUser: %s",
+ result.ForLLM,
+ result.ForUser,
+ )
}
}
@@ -78,7 +83,8 @@ func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
}
// Should mention required parameter
- if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
+ if !strings.Contains(result.ForLLM, "path is required") &&
+ !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -122,6 +128,45 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) {
}
}
+// TestFilesystemTool_WriteFile_LiteralBackslashN verifies write_file keeps
+// literal backslash sequences unchanged when they are passed as plain text.
+func TestFilesystemTool_WriteFile_LiteralBackslashN(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "literal.txt")
+
+ tool := NewWriteFileTool("", false)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "content": `aaa\naaa`,
+ })
+
+ assert.False(t, result.IsError, "expected success, got: %s", result.ForLLM)
+
+ data, err := os.ReadFile(testFile)
+ assert.NoError(t, err)
+ assert.Equal(t, `aaa\naaa`, string(data))
+}
+
+// TestFilesystemTool_WriteFile_PreservesCRLF verifies write_file does not
+// normalize line endings and writes CRLF bytes as provided.
+func TestFilesystemTool_WriteFile_PreservesCRLF(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "crlf.txt")
+ content := "line1\r\nline2\r\n"
+
+ tool := NewWriteFileTool("", false)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "content": content,
+ })
+
+ assert.False(t, result.IsError, "expected success, got: %s", result.ForLLM)
+
+ data, err := os.ReadFile(testFile)
+ assert.NoError(t, err)
+ assert.Equal(t, []byte(content), data)
+}
+
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
@@ -297,7 +342,12 @@ func TestFilesystemTool_WriteFile_OverwriteSandboxed(t *testing.T) {
"content": "replaced in sandbox",
"overwrite": true,
})
- assert.False(t, result.IsError, "expected success in sandbox mode with overwrite=true, got: %s", result.ForLLM)
+ assert.False(
+ t,
+ result.IsError,
+ "expected success in sandbox mode with overwrite=true, got: %s",
+ result.ForLLM,
+ )
data, err := os.ReadFile(filepath.Join(workspace, testFile))
assert.NoError(t, err)
@@ -325,7 +375,8 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
}
// Should list files and directories
- if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
+ if !strings.Contains(result.ForLLM, "file1.txt") ||
+ !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
@@ -349,8 +400,13 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
}
// Should contain error message
- if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
- t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
+ if !strings.Contains(result.ForLLM, "failed to read") &&
+ !strings.Contains(result.ForUser, "failed to read") {
+ t.Errorf(
+ "Expected error message, got ForLLM: %s, ForUser: %s",
+ result.ForLLM,
+ result.ForUser,
+ )
}
}
@@ -397,7 +453,8 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
// os.Root might return different errors depending on platform/implementation
// but it definitely should error.
// Our wrapper returns "access denied or file not found"
- if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
+ if !strings.Contains(result.ForLLM, "access denied") &&
+ !strings.Contains(result.ForLLM, "file not found") &&
!strings.Contains(result.ForLLM, "no such file") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
@@ -416,10 +473,20 @@ func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
})
// We EXPECT IsError=true (access blocked due to empty workspace)
- assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
+ assert.True(
+ t,
+ result.IsError,
+ "Security Regression: Empty workspace allowed access! content: %s",
+ result.ForLLM,
+ )
// Verify it failed for the right reason
- assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
+ assert.Contains(
+ t,
+ result.ForLLM,
+ "workspace is not defined",
+ "Expected 'workspace is not defined' error",
+ )
}
// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
@@ -653,7 +720,10 @@ func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
- result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
+ result := tool.Execute(
+ context.Background(),
+ map[string]any{"path": filepath.Join(linkPath, "secret.txt")},
+ )
if !result.IsError {
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
}
@@ -726,7 +796,6 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "pagination_test.txt")
- // Create a test file with exactly 26 bytes of content
fullContent := "abcdefghijklmnopqrstuvwxyz"
err := os.WriteFile(testFile, []byte(fullContent), 0o644)
if err != nil {
@@ -748,15 +817,12 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
t.Fatalf("Chunk 1 failed: %s", result1.ForLLM)
}
- // Expect the first 10 characters
if !strings.Contains(result1.ForLLM, "abcdefghij") {
t.Errorf("Chunk 1 should contain 'abcdefghij', got: %s", result1.ForLLM)
}
- // Expect the header to indicate the file is truncated
if !strings.Contains(result1.ForLLM, "[TRUNCATED") {
t.Errorf("Chunk 1 header should indicate truncation, got: %s", result1.ForLLM)
}
- // Expect the header to suggest the next offset (10)
if !strings.Contains(result1.ForLLM, "offset=10") {
t.Errorf("Chunk 1 header should suggest next offset=10, got: %s", result1.ForLLM)
}
@@ -773,17 +839,14 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
t.Fatalf("Chunk 2 failed: %s", result2.ForLLM)
}
- // Expect the next 10 characters
if !strings.Contains(result2.ForLLM, "klmnopqrst") {
t.Errorf("Chunk 2 should contain 'klmnopqrst', got: %s", result2.ForLLM)
}
- // Expect the header to suggest the next offset (20)
if !strings.Contains(result2.ForLLM, "offset=20") {
t.Errorf("Chunk 2 header should suggest next offset=20, got: %s", result2.ForLLM)
}
// Step 3: Read the final chunk (remaining 6 bytes) ---
- // We ask for 10 bytes, but only 6 are left in the file
args3 := map[string]any{
"path": testFile,
"offset": 20,
@@ -795,16 +858,12 @@ func TestReadFileTool_ChunkedReading(t *testing.T) {
t.Fatalf("Chunk 3 failed: %s", result3.ForLLM)
}
- // Expect the last 6 characters
if !strings.Contains(result3.ForLLM, "uvwxyz") {
t.Errorf("Chunk 3 should contain 'uvwxyz', got: %s", result3.ForLLM)
}
- // Expect the header to indicate the end of the file
if !strings.Contains(result3.ForLLM, "[END OF FILE") {
t.Errorf("Chunk 3 header should indicate end of file, got: %s", result3.ForLLM)
}
-
- // Ensure no TRUNCATED message is present in the final chunk
if strings.Contains(result3.ForLLM, "[TRUNCATED") {
t.Errorf("Chunk 3 header should NOT indicate truncation, got: %s", result3.ForLLM)
}
@@ -816,7 +875,6 @@ func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "short.txt")
- // create a file of only 5 bytes
err := os.WriteFile(testFile, []byte("12345"), 0o644)
if err != nil {
t.Fatalf("Failed to write test file: %v", err)
@@ -827,19 +885,393 @@ func TestReadFileTool_OffsetBeyondEOF(t *testing.T) {
args := map[string]any{
"path": testFile,
- "offset": int64(100), // Offset beyond the end of the file
+ "offset": int64(100),
}
result := tool.Execute(ctx, args)
- // It should not be classified as a tool execution error
if result.IsError {
t.Errorf("A mistake was not expected, obtained IsError=true: %s", result.ForLLM)
}
- // Must return EXACTLY the string provided in the code
expectedMsg := "[END OF FILE - no content at this offset]"
if result.ForLLM != expectedMsg {
t.Errorf("The message %q was expected, obtained: %q", expectedMsg, result.ForLLM)
}
}
+
+func TestReadFileLinesTool_ChunkedReading(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "pagination_lines.txt")
+
+ fullContent := strings.Join([]string{
+ "line 1",
+ "line 2",
+ "line 3",
+ "line 4",
+ "line 5",
+ "line 6",
+ }, "\n") + "\n"
+ err := os.WriteFile(testFile, []byte(fullContent), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+
+ result1 := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ "max_lines": 2,
+ })
+ if result1.IsError {
+ t.Fatalf("Chunk 1 failed: %s", result1.ForLLM)
+ }
+ if !strings.Contains(result1.ForLLM, "1|line 1\n2|line 2\n") {
+ t.Fatalf("expected first two lines, got: %s", result1.ForLLM)
+ }
+ if !strings.Contains(result1.ForLLM, "lines 1-2") {
+ t.Fatalf("expected line range 1-2, got: %s", result1.ForLLM)
+ }
+ if !strings.Contains(result1.ForLLM, "start_line=3") {
+ t.Fatalf("expected continuation start_line=3, got: %s", result1.ForLLM)
+ }
+ if !strings.Contains(result1.ForLLM, "max_lines=2") {
+ t.Fatalf("expected continuation max_lines=2, got: %s", result1.ForLLM)
+ }
+
+ result2 := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 3,
+ "max_lines": 2,
+ })
+ if result2.IsError {
+ t.Fatalf("Chunk 2 failed: %s", result2.ForLLM)
+ }
+ if !strings.Contains(result2.ForLLM, "3|line 3\n4|line 4\n") {
+ t.Fatalf("expected middle chunk, got: %s", result2.ForLLM)
+ }
+ if !strings.Contains(result2.ForLLM, "start_line=5") {
+ t.Fatalf("expected continuation start_line=5, got: %s", result2.ForLLM)
+ }
+ if !strings.Contains(result2.ForLLM, "max_lines=2") {
+ t.Fatalf("expected continuation max_lines=2, got: %s", result2.ForLLM)
+ }
+
+ result3 := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 5,
+ "max_lines": 2,
+ })
+ if result3.IsError {
+ t.Fatalf("Chunk 3 failed: %s", result3.ForLLM)
+ }
+ if !strings.Contains(result3.ForLLM, "5|line 5\n6|line 6\n") {
+ t.Fatalf("expected final chunk, got: %s", result3.ForLLM)
+ }
+ if !strings.Contains(result3.ForLLM, "[END OF FILE") {
+ t.Fatalf("expected EOF marker, got: %s", result3.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_DefaultOffsetAndRemainingLines(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "default_lines.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\nline 3\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ })
+ if result.IsError {
+ t.Fatalf("Execute() error = %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "1|line 1\n2|line 2\n3|line 3\n") {
+ t.Fatalf("expected remaining lines by default, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "lines 1-3") {
+ t.Fatalf("expected line range 1-3, got: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileTool_LegacyLengthUsesByteModeForText(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "legacy_bytes.txt")
+
+ err := os.WriteFile(testFile, []byte("abcdefghijklmnopqrstuvwxyz"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileBytesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "offset": 10,
+ "length": 5,
+ })
+ if result.IsError {
+ t.Fatalf("Execute() error = %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "read: bytes 10-14") {
+ t.Fatalf("expected byte-based header, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "klmno") {
+ t.Fatalf("expected byte chunk content, got: %s", result.ForLLM)
+ }
+ if strings.Contains(result.ForLLM, "lines ") {
+ t.Fatalf("expected legacy byte mode, got line-based header: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_OffsetBeyondEOF(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "short_lines.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": int64(100),
+ })
+ if result.IsError {
+ t.Fatalf("unexpected error: %s", result.ForLLM)
+ }
+ if result.ForLLM != "[END OF FILE - no content at or after start_line=100]" {
+ t.Fatalf("unexpected EOF message: %q", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_RegistryValidationSupportsMaxLinesAndRejectsLimit(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "registry_lines.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\nline 3\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ reg := NewToolRegistry()
+ reg.Register(NewReadFileLinesTool(tmpDir, false, MaxReadFileSize))
+
+ result := reg.Execute(context.Background(), "read_file", map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ "max_lines": 1,
+ })
+ if result.IsError {
+ t.Fatalf("expected max_lines to pass registry validation, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "1|line 1\n") {
+ t.Fatalf("expected first line via max_lines, got: %s", result.ForLLM)
+ }
+
+ result = reg.Execute(context.Background(), "read_file", map[string]any{
+ "path": testFile,
+ "start_line": 2,
+ "limit": 1,
+ })
+ if !result.IsError {
+ t.Fatalf("expected limit to be rejected, got success: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "unexpected property \"limit\"") {
+ t.Fatalf("expected registry validation error for limit, got: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_RejectsOffset(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "legacy_offset.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ "offset": 1,
+ })
+ if !result.IsError {
+ t.Fatalf("expected offset to be rejected, got success: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "offset is not supported in line mode; use start_line") {
+ t.Fatalf("unexpected error for offset in line mode: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_RejectsLength(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "legacy_length.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ "length": 1,
+ })
+ if !result.IsError {
+ t.Fatalf("expected length to be rejected, got success: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "length is not supported in line mode; use max_lines") {
+ t.Fatalf("unexpected error for length in line mode: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_RejectsLimit(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "legacy_limit.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ "limit": 1,
+ })
+ if !result.IsError {
+ t.Fatalf("expected limit to be rejected, got success: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "limit is not supported in line mode; use max_lines") {
+ t.Fatalf("unexpected error for limit in line mode: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_BinaryFileRejected(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "binary.dat")
+
+ data := []byte{0x00, 0x01, 'A', 'B', 'C', 'D', 'E', 'F'}
+ err := os.WriteFile(testFile, data, 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ })
+ if !result.IsError {
+ t.Fatalf("expected binary file rejection in line mode, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "switch read_file mode to 'bytes'") {
+ t.Fatalf("expected binary file rejection message, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "mode to 'bytes'") {
+ t.Fatalf("expected suggestion to switch read_file mode, got: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_TruncatesSingleLongLineAtByteBudget(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "long_line.txt")
+
+ content := "first line\n" + strings.Repeat("x", 70*1024) + "\n"
+ err := os.WriteFile(testFile, []byte(content), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ })
+ if result.IsError {
+ t.Fatalf("Execute() error = %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "was cut mid-line") {
+ t.Fatalf("expected explicit mid-line truncation warning, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "1|first line\n") {
+ t.Fatalf("expected the first line with line prefix, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "2|") {
+ t.Fatalf("expected line prefix for the truncated line, got: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_NoTrailingNewline(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "no_trailing_newline.txt")
+
+ err := os.WriteFile(testFile, []byte("line 1\nline 2"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, MaxReadFileSize)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ })
+ if result.IsError {
+ t.Fatalf("Execute() error = %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "1|line 1\n2|line 2") {
+ t.Fatalf(
+ "expected final line without trailing newline to be preserved, got: %s",
+ result.ForLLM,
+ )
+ }
+ if !strings.Contains(result.ForLLM, "[END OF FILE - no further content.]") {
+ t.Fatalf("expected EOF marker, got: %s", result.ForLLM)
+ }
+}
+
+func TestReadFileLinesTool_ExactByteBudgetBoundaryIncludesPrefix(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "exact_boundary.txt")
+
+ err := os.WriteFile(testFile, []byte("1234567\nsecond line\n"), 0o644)
+ if err != nil {
+ t.Fatalf("Failed to write test file: %v", err)
+ }
+
+ tool := NewReadFileLinesTool(tmpDir, false, 10)
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": testFile,
+ "start_line": 1,
+ })
+ if result.IsError {
+ t.Fatalf("Execute() error = %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "1|1234567\n") {
+ t.Fatalf(
+ "expected first line to fit exactly in the byte budget with its prefix, got: %s",
+ result.ForLLM,
+ )
+ }
+ if strings.Contains(result.ForLLM, "2|") {
+ t.Fatalf(
+ "expected second line to be excluded once the exact output byte budget was reached, got: %s",
+ result.ForLLM,
+ )
+ }
+ if !strings.Contains(result.ForLLM, "file_bytes: 8 | output_bytes: 10") {
+ t.Fatalf("expected separate file/output byte counters, got: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "start_line=2") {
+ t.Fatalf("expected continuation at line 2, got: %s", result.ForLLM)
+ }
+}
diff --git a/pkg/tools/load_image.go b/pkg/tools/load_image.go
new file mode 100644
index 000000000..41ea6d054
--- /dev/null
+++ b/pkg/tools/load_image.go
@@ -0,0 +1,163 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
+)
+
+// LoadImageTool loads a local image file into the MediaStore and returns a
+// media:// reference. The agent loop's resolveMediaRefs will then base64-encode
+// it and attach it as an image_url part in the next LLM request, enabling
+// vision on local files — the same pipeline used when a user sends an image
+// through a chat channel.
+//
+// This is intentionally different from SendFileTool:
+// - SendFileTool → MediaResult + WithResponseHandled() → sends file to user, ends turn
+// - LoadImageTool → plain ToolResult with media:// in ForLLM → LLM sees the image next turn
+type LoadImageTool struct {
+ workspace string
+ restrict bool
+ maxFileSize int
+ mediaStore media.MediaStore
+ allowPaths []*regexp.Regexp
+
+ defaultChannel string
+ defaultChatID string
+}
+
+func NewLoadImageTool(
+ workspace string,
+ restrict bool,
+ maxFileSize int,
+ store media.MediaStore,
+ allowPaths ...[]*regexp.Regexp,
+) *LoadImageTool {
+ if maxFileSize <= 0 {
+ maxFileSize = config.DefaultMaxMediaSize
+ }
+ var patterns []*regexp.Regexp
+ if len(allowPaths) > 0 {
+ patterns = allowPaths[0]
+ }
+ return &LoadImageTool{
+ workspace: workspace,
+ restrict: restrict,
+ maxFileSize: maxFileSize,
+ mediaStore: store,
+ allowPaths: patterns,
+ }
+}
+
+func (t *LoadImageTool) Name() string { return "load_image" }
+
+func (t *LoadImageTool) Description() string {
+ return "Load a local image file so you can analyze its contents with vision. " +
+ "Supported formats: JPEG, PNG, GIF, WebP, BMP. " +
+ "After calling this tool, describe or analyze the image in your next response."
+}
+
+func (t *LoadImageTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "path": map[string]any{
+ "type": "string",
+ "description": "Path to the local image file. Relative paths are resolved from workspace.",
+ },
+ },
+ "required": []string{"path"},
+ }
+}
+
+func (t *LoadImageTool) SetContext(channel, chatID string) {
+ t.defaultChannel = channel
+ t.defaultChatID = chatID
+}
+
+func (t *LoadImageTool) SetMediaStore(store media.MediaStore) {
+ t.mediaStore = store
+}
+
+func (t *LoadImageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ path, _ := args["path"].(string)
+ if strings.TrimSpace(path) == "" {
+ return ErrorResult("path is required")
+ }
+
+ // Prefer context-injected channel/chatID (set by ExecuteWithContext), fall back to SetContext values.
+ channel := ToolChannel(ctx)
+ if channel == "" {
+ channel = t.defaultChannel
+ }
+ chatID := ToolChatID(ctx)
+ if chatID == "" {
+ chatID = t.defaultChatID
+ }
+ if channel == "" || chatID == "" {
+ return ErrorResult("no target channel/chat available")
+ }
+
+ if t.mediaStore == nil {
+ return ErrorResult("media store not configured")
+ }
+
+ resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("invalid path: %v", err))
+ }
+
+ info, err := os.Stat(resolved)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("file not found: %v", err))
+ }
+ if info.IsDir() {
+ return ErrorResult("path is a directory, expected an image file")
+ }
+ if info.Size() > int64(t.maxFileSize) {
+ return ErrorResult(fmt.Sprintf(
+ "file too large: %d bytes (max %d bytes)", info.Size(), t.maxFileSize,
+ ))
+ }
+
+ // Detect MIME type — reuse the helper already in send_file.go
+ mediaType := detectMediaType(resolved)
+ if !strings.HasPrefix(mediaType, "image/") {
+ return ErrorResult(fmt.Sprintf(
+ "file does not appear to be an image (detected type: %s)", mediaType,
+ ))
+ }
+
+ filename := filepath.Base(resolved)
+ scope := fmt.Sprintf("tool:load_image:%s:%s", channel, chatID)
+
+ ref, err := t.mediaStore.Store(resolved, media.MediaMeta{
+ Filename: filename,
+ ContentType: mediaType,
+ Source: "tool:load_image",
+ CleanupPolicy: media.CleanupPolicyForgetOnly,
+ }, scope)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("failed to register image in media store: %v", err))
+ }
+
+ // Build the tool result text. The media:// ref will be picked up by
+ // resolveMediaRefs in loop_media.go and converted to a base64 data URL
+ // before the next LLM call, exactly like channel-received images.
+ msg := fmt.Sprintf("Image loaded: %s\n[image: %s]", filename, ref)
+
+ return &ToolResult{
+ ForLLM: msg,
+ ForUser: fmt.Sprintf("Loaded image: %s", filename),
+ // Media refs inside ForLLM are resolved by resolveMediaRefs in the
+ // agent loop before the next LLM call. Do NOT use MediaResult here —
+ // that would send the file to the user channel instead.
+ Media: []string{ref},
+ }
+}
diff --git a/pkg/tools/load_image_test.go b/pkg/tools/load_image_test.go
new file mode 100644
index 000000000..91118f93e
--- /dev/null
+++ b/pkg/tools/load_image_test.go
@@ -0,0 +1,174 @@
+package tools
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+func TestLoadImage_PathRequired(t *testing.T) {
+ tool := NewLoadImageTool("/tmp", false, 0, nil)
+ ctx := WithToolContext(context.Background(), "test", "chat1")
+ result := tool.Execute(ctx, map[string]any{})
+ if !result.IsError {
+ t.Fatal("expected error for missing path")
+ }
+}
+
+func TestLoadImage_NilMediaStore(t *testing.T) {
+ tool := NewLoadImageTool("/tmp", false, 0, nil)
+ ctx := WithToolContext(context.Background(), "test", "chat1")
+ result := tool.Execute(ctx, map[string]any{"path": "test.png"})
+ if !result.IsError || result.ForLLM != "media store not configured" {
+ t.Fatalf("expected media store error, got: %s", result.ForLLM)
+ }
+}
+
+func TestLoadImage_NoChannelContext(t *testing.T) {
+ store := media.NewFileMediaStore()
+ tool := NewLoadImageTool("/tmp", false, 0, store)
+ // No WithToolContext — should fail
+ result := tool.Execute(context.Background(), map[string]any{"path": "test.png"})
+ if !result.IsError || result.ForLLM != "no target channel/chat available" {
+ t.Fatalf("expected channel error, got: %s", result.ForLLM)
+ }
+}
+
+func TestLoadImage_NonImageFile(t *testing.T) {
+ dir := t.TempDir()
+ txtFile := filepath.Join(dir, "readme.txt")
+ os.WriteFile(txtFile, []byte("hello"), 0o644)
+
+ store := media.NewFileMediaStore()
+ tool := NewLoadImageTool(dir, false, 0, store)
+ ctx := WithToolContext(context.Background(), "test", "chat1")
+ result := tool.Execute(ctx, map[string]any{"path": txtFile})
+ if !result.IsError {
+ t.Fatal("expected error for non-image file")
+ }
+}
+
+func TestLoadImage_DefaultMaxSize(t *testing.T) {
+ tool := NewLoadImageTool("/tmp", false, 0, nil)
+ if tool.maxFileSize != config.DefaultMaxMediaSize {
+ t.Errorf("expected default max size %d, got %d", config.DefaultMaxMediaSize, tool.maxFileSize)
+ }
+}
+
+func TestLoadImage_FileTooLarge(t *testing.T) {
+ dir := t.TempDir()
+ bigFile := filepath.Join(dir, "big.png")
+ // Create a file with PNG header but exceeding max size
+ data := make([]byte, 1024)
+ copy(data, []byte{0x89, 0x50, 0x4E, 0x47}) // PNG magic bytes
+ os.WriteFile(bigFile, data, 0o644)
+
+ store := media.NewFileMediaStore()
+ tool := NewLoadImageTool(dir, false, 512, store) // maxSize = 512
+ ctx := WithToolContext(context.Background(), "test", "chat1")
+ result := tool.Execute(ctx, map[string]any{"path": bigFile})
+ if !result.IsError {
+ t.Fatal("expected error for oversized file")
+ }
+}
+
+func TestSubagentManager_SetMediaResolver_StoresResolver(t *testing.T) {
+ manager := NewSubagentManager(nil, "gpt-test", "/tmp")
+
+ called := false
+ manager.SetMediaResolver(func(msgs []providers.Message) []providers.Message {
+ called = true
+ return msgs
+ })
+
+ manager.mu.RLock()
+ got := manager.mediaResolver
+ manager.mu.RUnlock()
+
+ if got == nil {
+ t.Fatal("expected mediaResolver to be set")
+ }
+
+ if called {
+ t.Fatal("resolver should not be called during SetMediaResolver")
+ }
+}
+
+func TestLoadImage_SuccessPath(t *testing.T) {
+ dir := t.TempDir()
+
+ // Create a minimal valid PNG file (8-byte signature + minimal IHDR + IEND).
+ // The PNG spec requires the 8-byte magic header: 0x89 P N G \r \n 0x1a \n
+ pngSignature := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
+ // IHDR chunk: length(13) + "IHDR" + 1x1 px, 8-bit RGB, no interlace + CRC
+ ihdr := []byte{
+ 0x00, 0x00, 0x00, 0x0D, // chunk length = 13
+ 0x49, 0x48, 0x44, 0x52, // "IHDR"
+ 0x00, 0x00, 0x00, 0x01, // width = 1
+ 0x00, 0x00, 0x00, 0x01, // height = 1
+ 0x08, // bit depth = 8
+ 0x02, // color type = RGB
+ 0x00, 0x00, 0x00, // compression, filter, interlace
+ 0x90, 0x77, 0x53, 0xDE, // CRC (valid for this IHDR)
+ }
+ // IEND chunk
+ iend := []byte{
+ 0x00, 0x00, 0x00, 0x00, // chunk length = 0
+ 0x49, 0x45, 0x4E, 0x44, // "IEND"
+ 0xAE, 0x42, 0x60, 0x82, // CRC
+ }
+
+ pngData := make([]byte, 0, len(pngSignature)+len(ihdr)+len(iend))
+ pngData = append(pngData, pngSignature...)
+ pngData = append(pngData, ihdr...)
+ pngData = append(pngData, iend...)
+
+ imgPath := filepath.Join(dir, "test_image.png")
+ if err := os.WriteFile(imgPath, pngData, 0o644); err != nil {
+ t.Fatalf("failed to create test PNG: %v", err)
+ }
+
+ store := media.NewFileMediaStore()
+ tool := NewLoadImageTool(dir, false, 0, store)
+ ctx := WithToolContext(context.Background(), "test", "chat1")
+
+ result := tool.Execute(ctx, map[string]any{"path": imgPath})
+
+ // 1. Must not be an error
+ if result.IsError {
+ t.Fatalf("expected success, got error: %s", result.ForLLM)
+ }
+
+ // 2. Media must contain exactly one media:// ref
+ if len(result.Media) != 1 {
+ t.Fatalf("expected 1 media ref, got %d", len(result.Media))
+ }
+ if !strings.HasPrefix(result.Media[0], "media://") {
+ t.Errorf("expected media ref to start with 'media://', got: %s", result.Media[0])
+ }
+
+ // 3. ForLLM must contain the [image: marker
+ if !strings.Contains(result.ForLLM, "[image:") {
+ t.Errorf("expected ForLLM to contain '[image:' marker, got: %s", result.ForLLM)
+ }
+
+ // 4. ForLLM should also contain the media:// ref
+ if !strings.Contains(result.ForLLM, result.Media[0]) {
+ t.Errorf("expected ForLLM to contain media ref %q, got: %s", result.Media[0], result.ForLLM)
+ }
+
+ // 5. Verify the ref is resolvable in the store
+ resolved, err := store.Resolve(result.Media[0])
+ if err != nil {
+ t.Fatalf("media ref not resolvable: %v", err)
+ }
+ if resolved != imgPath {
+ t.Errorf("expected resolved path %q, got %q", imgPath, resolved)
+ }
+}
diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go
index 5bffb4e89..1caf390cf 100644
--- a/pkg/tools/mcp_tool.go
+++ b/pkg/tools/mcp_tool.go
@@ -6,11 +6,14 @@ import (
"fmt"
"hash/fnv"
"os"
+ "path/filepath"
"strings"
"time"
+ "unicode/utf8"
"github.com/modelcontextprotocol/go-sdk/mcp"
+ "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
@@ -26,18 +29,21 @@ type MCPManager interface {
// MCPTool wraps an MCP tool to implement the Tool interface
type MCPTool struct {
- manager MCPManager
- serverName string
- tool *mcp.Tool
- mediaStore media.MediaStore
+ manager MCPManager
+ serverName string
+ tool *mcp.Tool
+ mediaStore media.MediaStore
+ workspace string
+ maxInlineTextRunes int
}
// NewMCPTool creates a new MCP tool wrapper
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
return &MCPTool{
- manager: manager,
- serverName: serverName,
- tool: tool,
+ manager: manager,
+ serverName: serverName,
+ tool: tool,
+ maxInlineTextRunes: maxMCPInlineTextRunes,
}
}
@@ -45,6 +51,18 @@ func (t *MCPTool) SetMediaStore(store media.MediaStore) {
t.mediaStore = store
}
+func (t *MCPTool) SetWorkspace(workspace string) {
+ t.workspace = strings.TrimSpace(workspace)
+}
+
+func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
+ if limit > 0 {
+ t.maxInlineTextRunes = limit
+ }
+}
+
+const maxMCPInlineTextRunes = 16 * 1024
+
// sanitizeIdentifierComponent normalizes a string so it can be safely used
// as part of a tool/function identifier for downstream providers.
// It:
@@ -255,14 +273,19 @@ func extractContentText(content []mcp.Content) string {
func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Content) *ToolResult {
llmParts := make([]string, 0, len(content))
+ rawTextParts := make([]string, 0, len(content))
mediaRefs := make([]string, 0, len(content))
for _, c := range content {
switch v := c.(type) {
case *mcp.TextContent:
- text := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
- if text != "" {
- llmParts = append(llmParts, text)
+ rawText := strings.TrimSpace(v.Text)
+ if rawText != "" {
+ rawTextParts = append(rawTextParts, rawText)
+ }
+ safeText := strings.TrimSpace(sanitizeToolLLMContent(v.Text))
+ if safeText != "" {
+ llmParts = append(llmParts, safeText)
}
case *mcp.ImageContent:
ref, note := t.storeBinaryContent(
@@ -295,10 +318,13 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
case *mcp.ResourceLink:
llmParts = append(llmParts, summarizeResourceLink(v))
case *mcp.EmbeddedResource:
- ref, note := t.storeEmbeddedResource(ctx, v)
+ ref, note, rawText := t.storeEmbeddedResource(ctx, v)
if ref != "" {
mediaRefs = append(mediaRefs, ref)
}
+ if rawText != "" {
+ rawTextParts = append(rawTextParts, rawText)
+ }
if note != "" {
llmParts = append(llmParts, note)
}
@@ -307,34 +333,105 @@ func (t *MCPTool) normalizeResultContent(ctx context.Context, content []mcp.Cont
}
}
+ forLLM := strings.Join(compactStrings(llmParts), "\n")
+ rawText := strings.Join(compactStrings(rawTextParts), "\n")
+ if artifactResult := t.persistLargeTextArtifact(rawText); artifactResult != nil {
+ artifactResult.Media = mediaRefs
+ return artifactResult
+ }
+
result := &ToolResult{
- ForLLM: strings.Join(compactStrings(llmParts), "\n"),
+ ForLLM: forLLM,
Media: mediaRefs,
}
return result
}
-func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string) {
+func (t *MCPTool) persistLargeTextArtifact(text string) *ToolResult {
+ text = strings.TrimSpace(text)
+ limit := t.maxInlineTextRunes
+ if limit <= 0 {
+ limit = maxMCPInlineTextRunes
+ }
+ size := utf8.RuneCountInString(text)
+ if text == "" || size <= limit || t.workspace == "" {
+ return nil
+ }
+
+ dir := filepath.Join(t.workspace, ".artifacts", "mcp")
+ if err := os.MkdirAll(dir, 0o700); err != nil {
+ return t.largeTextArtifactFallback(text, err)
+ }
+ // TODO: Add lifecycle cleanup/retention for MCP artifact files.
+
+ pattern := fmt.Sprintf(
+ "%s_%s_*.txt",
+ sanitizeIdentifierComponent(t.serverName),
+ sanitizeIdentifierComponent(t.tool.Name),
+ )
+ tmpFile, err := os.CreateTemp(dir, pattern)
+ if err != nil {
+ return t.largeTextArtifactFallback(text, err)
+ }
+ path := tmpFile.Name()
+ if _, err = tmpFile.WriteString(text); err != nil {
+ _ = tmpFile.Close()
+ _ = os.Remove(path)
+ return t.largeTextArtifactFallback(text, err)
+ }
+ if err = tmpFile.Close(); err != nil {
+ _ = os.Remove(path)
+ return t.largeTextArtifactFallback(text, err)
+ }
+
+ return &ToolResult{
+ ForLLM: fmt.Sprintf(
+ "[MCP returned a large text result (%d chars); omitted from model context and saved as a local artifact.]",
+ size,
+ ),
+ ArtifactTags: []string{"[file:" + path + "]"},
+ }
+}
+
+func (t *MCPTool) largeTextArtifactFallback(text string, err error) *ToolResult {
+ size := utf8.RuneCountInString(text)
+ logger.WarnCF("tool", "Failed to persist large MCP text artifact", map[string]any{
+ "server": t.serverName,
+ "tool": t.tool.Name,
+ "chars": size,
+ "error": err.Error(),
+ })
+ return &ToolResult{
+ ForLLM: fmt.Sprintf(
+ "[MCP returned a large text result (%d chars); omitted from model context because artifact persistence failed.]",
+ size,
+ ),
+ }
+}
+
+func (t *MCPTool) storeEmbeddedResource(ctx context.Context, content *mcp.EmbeddedResource) (string, string, string) {
if content == nil || content.Resource == nil {
- return "", "[MCP returned an embedded resource without data.]"
+ return "", "[MCP returned an embedded resource without data.]", ""
}
resource := content.Resource
if len(resource.Blob) > 0 {
- return t.storeBinaryContent(
+ ref, note := t.storeBinaryContent(
ctx,
"resource",
normalizedMIMEType(resource.MIMEType),
resource.Blob,
content.Annotations,
)
+ return ref, note, ""
}
- if strings.TrimSpace(resource.Text) != "" {
- return "", sanitizeToolLLMContent(resource.Text)
+ rawText := strings.TrimSpace(resource.Text)
+ if rawText != "" {
+ return "", sanitizeToolLLMContent(resource.Text), rawText
}
- return "", summarizeEmbeddedResource(content)
+ return "", summarizeEmbeddedResource(content), ""
}
func (t *MCPTool) storeBinaryContent(
diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go
index 8bbac3bc7..f2b02d6f6 100644
--- a/pkg/tools/mcp_tool_test.go
+++ b/pkg/tools/mcp_tool_test.go
@@ -634,3 +634,177 @@ func TestMCPTool_Execute_LargeBase64TextIsOmittedFromContext(t *testing.T) {
t.Fatalf("expected sanitized large base64 note, got %q", result.ForLLM)
}
}
+
+func TestMCPTool_Execute_LargeBase64TextArtifactPreservesRawPayload(t *testing.T) {
+ workspace := t.TempDir()
+ largeBase64 := strings.Repeat("QUJD", 400)
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: largeBase64},
+ },
+ }, nil
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
+ mcpTool.SetWorkspace(workspace)
+ mcpTool.SetMaxInlineTextRunes(32)
+
+ result := mcpTool.Execute(context.Background(), nil)
+
+ if !strings.Contains(result.ForLLM, "saved as a local artifact") {
+ t.Fatalf("expected artifact note, got %q", result.ForLLM)
+ }
+ if result.ForLLM == largeBase64OmittedMessage {
+ t.Fatalf("expected artifact note instead of sanitized base64 placeholder")
+ }
+ if len(result.ArtifactTags) != 1 {
+ t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
+ }
+ tag := result.ArtifactTags[0]
+ const prefix = "[file:"
+ if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
+ t.Fatalf("expected file artifact tag, got %q", tag)
+ }
+ path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("expected artifact file to be readable: %v", err)
+ }
+ if string(data) != largeBase64 {
+ t.Fatalf("expected artifact file contents to preserve raw MCP payload")
+ }
+}
+
+func TestMCPTool_Execute_LargeTextStoredAsArtifact(t *testing.T) {
+ workspace := t.TempDir()
+ largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: largeText},
+ },
+ }, nil
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
+ mcpTool.SetWorkspace(workspace)
+
+ result := mcpTool.Execute(context.Background(), nil)
+
+ if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
+ t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "saved as a local artifact") {
+ t.Fatalf("expected artifact note, got %q", result.ForLLM)
+ }
+ if len(result.ArtifactTags) != 1 {
+ t.Fatalf("expected 1 artifact tag, got %d", len(result.ArtifactTags))
+ }
+ tag := result.ArtifactTags[0]
+ const prefix = "[file:"
+ if !strings.HasPrefix(tag, prefix) || !strings.HasSuffix(tag, "]") {
+ t.Fatalf("expected file artifact tag, got %q", tag)
+ }
+ path := strings.TrimSuffix(strings.TrimPrefix(tag, prefix), "]")
+ if !strings.HasPrefix(path, workspace) {
+ t.Fatalf("expected artifact inside workspace, got %q", path)
+ }
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("expected artifact file to be readable: %v", err)
+ }
+ if string(data) != strings.TrimSpace(largeText) {
+ t.Fatalf("expected artifact file contents to match source text")
+ }
+}
+
+func TestMCPTool_Execute_CustomInlineTextThreshold(t *testing.T) {
+ workspace := t.TempDir()
+ text := strings.Repeat("small custom threshold text\n", 20)
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: text},
+ },
+ }, nil
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
+ mcpTool.SetWorkspace(workspace)
+ mcpTool.SetMaxInlineTextRunes(32)
+
+ result := mcpTool.Execute(context.Background(), nil)
+
+ if len(result.ArtifactTags) != 1 {
+ t.Fatalf("expected custom threshold to persist artifact, got %+v", result)
+ }
+ if strings.Contains(result.ForLLM, "small custom threshold text") {
+ t.Fatalf("expected text to be omitted from ForLLM, got %q", result.ForLLM)
+ }
+}
+
+func TestMCPTool_Execute_LargeTextArtifactFailureStillOmitsContext(t *testing.T) {
+ workspaceRoot := t.TempDir()
+ workspaceFile := filepath.Join(workspaceRoot, "not-a-directory")
+ if err := os.WriteFile(workspaceFile, []byte("x"), 0o600); err != nil {
+ t.Fatalf("failed to create workspace file: %v", err)
+ }
+
+ largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: largeText},
+ },
+ }, nil
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
+ mcpTool.SetWorkspace(workspaceFile)
+
+ result := mcpTool.Execute(context.Background(), nil)
+
+ if strings.Contains(result.ForLLM, "This is a large MCP text payload") {
+ t.Fatalf("expected large MCP text to be omitted from ForLLM, got %q", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "artifact persistence failed") {
+ t.Fatalf("expected persistence failure note, got %q", result.ForLLM)
+ }
+ if len(result.ArtifactTags) != 0 {
+ t.Fatalf("expected no artifact tags on persistence failure, got %+v", result.ArtifactTags)
+ }
+}
+
+func TestMCPTool_Execute_WhitespaceWorkspaceDisablesArtifactPersistence(t *testing.T) {
+ largeText := strings.Repeat("This is a large MCP text payload.\n", 800)
+ manager := &MockMCPManager{
+ callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
+ return &mcp.CallToolResult{
+ Content: []mcp.Content{
+ &mcp.TextContent{Text: largeText},
+ },
+ }, nil
+ },
+ }
+
+ mcpTool := NewMCPTool(manager, "test_server", &mcp.Tool{Name: "dump_payload"})
+ mcpTool.SetWorkspace(" \n\t ")
+
+ result := mcpTool.Execute(context.Background(), nil)
+
+ if len(result.ArtifactTags) != 0 {
+ t.Fatalf("expected no artifact tags for whitespace workspace, got %+v", result.ArtifactTags)
+ }
+ if !strings.Contains(result.ForLLM, "This is a large MCP text payload") {
+ t.Fatalf("expected large text to remain inline when workspace is blank, got %q", result.ForLLM)
+ }
+}
diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go
index 56af8d695..e51dff71a 100644
--- a/pkg/tools/registry.go
+++ b/pkg/tools/registry.go
@@ -228,6 +228,7 @@ func (r *ToolRegistry) ExecuteWithContext(
func() {
defer func() {
if re := recover(); re != nil {
+ logger.RecoverPanicNoExit(re)
errMsg := fmt.Sprintf("Tool '%s' crashed with panic: %v", name, re)
logger.ErrorCF("tool", "Tool execution panic recovered",
map[string]any{
diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go
index 9a1a8b802..ada89efb7 100644
--- a/pkg/tools/subagent.go
+++ b/pkg/tools/subagent.go
@@ -67,6 +67,12 @@ type SubagentManager struct {
hasTemperature bool
nextID int
spawner SpawnSubTurnFunc
+
+ // mediaResolver resolves media:// refs in tool-loop messages before
+ // each LLM call in the legacy RunToolLoop fallback path.
+ // This lets subagents reuse the same media handling behavior as the
+ // main agent loop without importing pkg/agent and creating a cycle.
+ mediaResolver func([]providers.Message) []providers.Message
}
func NewSubagentManager(
@@ -90,6 +96,17 @@ func (sm *SubagentManager) SetSpawner(spawner SpawnSubTurnFunc) {
sm.spawner = spawner
}
+// SetMediaResolver injects a message preprocessor that resolves media:// refs
+// into LLM-ready content before each tool-loop iteration.
+// This is only used by the legacy RunToolLoop fallback path.
+func (sm *SubagentManager) SetMediaResolver(
+ resolver func([]providers.Message) []providers.Message,
+) {
+ sm.mu.Lock()
+ defer sm.mu.Unlock()
+ sm.mediaResolver = resolver
+}
+
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
sm.mu.Lock()
@@ -177,6 +194,7 @@ func (sm *SubagentManager) runTask(
temperature := sm.temperature
hasMaxTokens := sm.hasMaxTokens
hasTemperature := sm.hasTemperature
+ mediaResolver := sm.mediaResolver
sm.mu.RUnlock()
var result *ToolResult
@@ -223,6 +241,7 @@ After completing the task, provide a clear summary of what was done.`
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
+ MediaResolver: mediaResolver,
}, messages, task.OriginChannel, task.OriginChatID)
if err == nil {
diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go
index 387813e94..ac568f598 100644
--- a/pkg/tools/toolloop.go
+++ b/pkg/tools/toolloop.go
@@ -24,6 +24,11 @@ type ToolLoopConfig struct {
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
+
+ // MediaResolver resolves media:// refs in messages before each LLM call.
+ // This is optional and is mainly used by subagent legacy fallback execution
+ // so subagents can reuse the same multimodal media handling as the main loop.
+ MediaResolver func(messages []providers.Message) []providers.Message
}
// ToolLoopResult contains the result of running the tool loop.
@@ -63,8 +68,27 @@ func RunToolLoop(
if llmOpts == nil {
llmOpts = map[string]any{}
}
- // 3. Call LLM
- response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
+
+ // 3. Resolve media:// refs and Call LLM.
+ // Tools like load_image produce media:// refs in their result messages.
+ // Without this step, the LLM would receive raw "media://uuid" strings
+ // instead of base64-encoded image data URLs.
+ //
+ // We build a separate callMessages slice so that:
+ // (a) the resolver output is used for the LLM call only,
+ // (b) the original `messages` slice keeps the unresolved refs for
+ // subsequent iterations — the resolver is idempotent but working
+ // on the original avoids double-encoding issues.
+ //
+ // On iteration 1 the initial user messages typically have no media://
+ // refs (they come from plain text), so this is effectively a no-op;
+ // it becomes relevant from iteration 2 onward when tool results may
+ // contain media refs.
+ callMessages := messages
+ if config.MediaResolver != nil && iteration > 1 {
+ callMessages = config.MediaResolver(messages)
+ }
+ response, err := config.Provider.Chat(ctx, callMessages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
@@ -161,11 +185,15 @@ func RunToolLoop(
for _, r := range results {
contentForLLM := r.result.ContentForLLM()
- messages = append(messages, providers.Message{
+ toolMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: r.tc.ID,
- })
+ }
+ if len(r.result.Media) > 0 && !r.result.ResponseHandled {
+ toolMsg.Media = append(toolMsg.Media, r.result.Media...)
+ }
+ messages = append(messages, toolMsg)
}
}
diff --git a/pkg/tools/tts_send.go b/pkg/tools/tts_send.go
new file mode 100644
index 000000000..3d569e3f7
--- /dev/null
+++ b/pkg/tools/tts_send.go
@@ -0,0 +1,82 @@
+package tools
+
+import (
+ "context"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/audio/tts"
+ "github.com/sipeed/picoclaw/pkg/media"
+)
+
+type SendTTSTool struct {
+ provider tts.TTSProvider
+ mediaStore media.MediaStore
+}
+
+func NewSendTTSTool(provider tts.TTSProvider, store media.MediaStore) *SendTTSTool {
+ return &SendTTSTool{
+ provider: provider,
+ mediaStore: store,
+ }
+}
+
+func (t *SendTTSTool) Name() string { return "send_tts" }
+
+func (t *SendTTSTool) Description() string {
+ return "Synthesize speech from text and send it as an audio file to the user."
+}
+
+func (t *SendTTSTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "text": map[string]any{
+ "type": "string",
+ "description": "The text to synthesize into speech. NOTE: Reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally.",
+ },
+ "filename": map[string]any{
+ "type": "string",
+ "description": "Optional filename for the audio file (e.g., response.ogg).",
+ },
+ },
+ "required": []string{"text"},
+ }
+}
+
+func (t *SendTTSTool) SetMediaStore(store media.MediaStore) {
+ t.mediaStore = store
+}
+
+func (t *SendTTSTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ text, _ := args["text"].(string)
+ text = strings.TrimSpace(text)
+ if text == "" {
+ return ErrorResult("text is required")
+ }
+
+ channel := ToolChannel(ctx)
+ chatID := ToolChatID(ctx)
+ filename, _ := args["filename"].(string)
+
+ ref, err := tts.SynthesizeAndStore(
+ ctx,
+ t.provider,
+ t.mediaStore,
+ text,
+ filename,
+ channel,
+ chatID,
+ )
+ if err != nil {
+ return ErrorResult(err.Error()).WithError(err)
+ }
+
+ // Return with ForUser set to original text, Media containing the audio ref,
+ // and mark as ResponseHandled so the audio is sent immediately without LLM intervention.
+ return &ToolResult{
+ ForLLM: "TTS audio sent",
+ ForUser: text,
+ Media: []string{ref},
+ ResponseHandled: true,
+ }
+}
diff --git a/pkg/updater/updater.go b/pkg/updater/updater.go
new file mode 100644
index 000000000..e73c1e859
--- /dev/null
+++ b/pkg/updater/updater.go
@@ -0,0 +1,707 @@
+package updater
+
+import (
+ "archive/tar"
+ "archive/zip"
+ "compress/gzip"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/minio/selfupdate"
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// httpClient is a shared HTTP client used for release checks and downloads.
+// The Timeout value applies to the entire HTTP request: dialing, TLS
+// handshake, redirects, and reading the response body. It is NOT only
+// a connection (dial) timeout. To control lower-level timeouts (dial,
+// TLS handshake, response header wait), supply a custom Transport with
+// an appropriately configured net.Dialer.
+var httpClient = &http.Client{Timeout: 2 * time.Minute}
+
+// DownloadAndExtractRelease downloads a release archive (or uses a direct
+// asset URL) and extracts it to a temporary directory. It returns the
+// extraction directory on success. If releaseURL is empty, the latest
+// release of the current project is used. platform/arch can be used to
+// select the correct asset (e.g. "linux", "amd64").
+func DownloadAndExtractRelease(releaseURL, platform, arch string) (string, error) {
+ assetURL, checksum, err := findAssetInfo(releaseURL, platform, arch)
+ if err != nil {
+ return "", err
+ }
+
+ // Download asset to temp file. Use the asset URL extension so
+ // extractArchive can detect the archive format (zip/tar.gz/tar).
+ tmpPattern := "picoclaw-release-*"
+ if u, perr := url.Parse(assetURL); perr == nil {
+ base := filepath.Base(u.Path)
+ lbase := strings.ToLower(base)
+ switch {
+ case strings.HasSuffix(lbase, ".zip"):
+ tmpPattern += ".zip"
+ case strings.HasSuffix(lbase, ".tar.gz") || strings.HasSuffix(lbase, ".tgz"):
+ tmpPattern += ".tar.gz"
+ case strings.HasSuffix(lbase, ".tar"):
+ tmpPattern += ".tar"
+ default:
+ tmpPattern += ".archive"
+ }
+ } else {
+ tmpPattern += ".archive"
+ }
+
+ tmpFile, err := os.CreateTemp("", tmpPattern)
+ if err != nil {
+ return "", err
+ }
+ tmpPath := tmpFile.Name()
+ defer tmpFile.Close()
+
+ resp, err := httpClient.Get(assetURL)
+ if err != nil {
+ os.Remove(tmpPath)
+ return "", err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ os.Remove(tmpPath)
+ return "", fmt.Errorf("failed to download asset: status %d", resp.StatusCode)
+ }
+
+ // Stream download while computing SHA256 to avoid a second download.
+ // Also show a simple progress line to stderr so users see activity.
+ h := sha256.New()
+ pw := &progressWriter{total: resp.ContentLength}
+ mw := io.MultiWriter(tmpFile, h, pw)
+ if _, err = io.Copy(mw, resp.Body); err != nil {
+ _ = os.Remove(tmpPath)
+ return "", err
+ }
+ // ensure final progress line ends with newline
+ pw.Finish()
+
+ // verify checksum if available
+ if checksum != "" {
+ got := hex.EncodeToString(h.Sum(nil))
+ if !strings.EqualFold(got, checksum) {
+ _ = os.Remove(tmpPath)
+ return "", fmt.Errorf("checksum mismatch: got %s expected %s", got, checksum)
+ }
+ }
+
+ // Extract
+ destDir, err := os.MkdirTemp("", "picoclaw-extract-*")
+ if err != nil {
+ os.Remove(tmpPath)
+ return "", err
+ }
+
+ if err := extractArchive(tmpPath, destDir); err != nil {
+ os.Remove(tmpPath)
+ os.RemoveAll(destDir)
+ return "", err
+ }
+
+ // cleanup archive file; keep extracted contents
+ _ = os.Remove(tmpPath)
+ return destDir, nil
+}
+
+// UpdateSelfFromRelease downloads the release matching the given parameters,
+// extracts it and applies the binary named programName to update the
+// currently running executable using minio/selfupdate.
+// If releaseURL is empty, the latest release is used. If platform or arch
+// is empty, runtime values are used.
+func UpdateSelfFromRelease(releaseURL, platform, arch, programName string) error {
+ if platform == "" {
+ platform = runtime.GOOS
+ }
+ if arch == "" {
+ arch = runtime.GOARCH
+ }
+
+ dir, err := DownloadAndExtractRelease(releaseURL, platform, arch)
+ if err != nil {
+ return err
+ }
+ defer os.RemoveAll(dir)
+
+ binPath, err := findBinaryInDir(dir, programName)
+ if err != nil {
+ return err
+ }
+
+ // ensure executable bit on non-windows
+ if runtime.GOOS != "windows" {
+ _ = os.Chmod(binPath, 0o755)
+ }
+
+ f, err := os.Open(binPath)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ // Backup current executable so we can roll back if needed.
+ var opts selfupdate.Options
+ if exePath, err := os.Executable(); err == nil {
+ opts.OldSavePath = exePath + ".old"
+ }
+
+ if err := selfupdate.Apply(f, opts); err != nil {
+ return fmt.Errorf("apply update: %w", err)
+ }
+
+ return nil
+}
+
+// UpdateSelf updates the running executable by fetching the latest release
+// and applying the binary matching programName.
+func UpdateSelf(programName string) error {
+ // By default, select the latest stable release when no explicit
+ // release URL is provided. Use --nightly or a custom URL to override.
+ return UpdateSelfFromRelease("", runtime.GOOS, runtime.GOARCH, programName)
+}
+
+// GetReleaseAPIURL returns the GitHub Releases API URL for the given repo owner.
+// Example: owner="sky5454" -> https://api.github.com/repos/sky5454/picoclaw/releases/latest
+func GetReleaseAPIURL(owner string) string {
+ return fmt.Sprintf("https://api.github.com/repos/%s/picoclaw/releases/latest", owner)
+}
+
+// GetProdReleaseAPIURL returns the production release API URL (upstream).
+func GetProdReleaseAPIURL() string {
+ return GetReleaseAPIURL("sipeed")
+}
+
+// GetReleaseTagAPIURL returns the GitHub Releases API URL for a specific tag.
+// Example: owner="sipeed", tag="nightly" -> https://api.github.com/repos/sipeed/picoclaw/releases/tags/nightly
+func GetReleaseTagAPIURL(owner, tag string) string {
+ return fmt.Sprintf("https://api.github.com/repos/%s/picoclaw/releases/tags/%s", owner, tag)
+}
+
+// GetNightlyReleaseAPIURL returns the nightly release API URL for the production repo.
+func GetNightlyReleaseAPIURL() string {
+ return GetReleaseTagAPIURL("sipeed", "nightly")
+}
+
+// findAssetURL resolves the appropriate asset URL for the given release
+// selector. It accepts direct archive URLs as well as GitHub release URLs
+// or empty (latest release for the project).
+func findAssetInfo(releaseURL, platform, arch string) (string, string, error) {
+ // returns (assetURL, sha256ChecksumHex, error)
+ if looksLikeDirectAssetURL(releaseURL) {
+ return "", "", fmt.Errorf("no checksum found for asset %s", releaseURL)
+ }
+
+ apiURL := buildReleaseAPIURL(releaseURL)
+ if apiURL == "" {
+ // If caller provided an empty releaseURL, default to the
+ // production latest release API URL (stable release).
+ apiURL = GetProdReleaseAPIURL()
+ }
+
+ resp, err := httpClient.Get(apiURL)
+ if err != nil {
+ return "", "", err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return "", "", fmt.Errorf("failed to query releases: status %d", resp.StatusCode)
+ }
+
+ var data struct {
+ TagName string `json:"tag_name"`
+ Assets []struct {
+ Name string `json:"name"`
+ BrowserDownloadURL string `json:"browser_download_url"`
+ Digest string `json:"digest"`
+ } `json:"assets"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
+ return "", "", err
+ }
+
+ // Selection order: platform -> arch -> extension.
+ platformLower := strings.ToLower(platform)
+ archLower := strings.ToLower(arch)
+
+ isZip := func(name string) bool {
+ return strings.HasSuffix(name, ".zip")
+ }
+ isTarGz := func(name string) bool {
+ return strings.HasSuffix(name, ".tar.gz") || strings.HasSuffix(name, ".tgz")
+ }
+ isTar := func(name string) bool { return strings.HasSuffix(name, ".tar") }
+
+ // collect indices of assets that contain platform (if provided)
+ var platformIdx []int
+ for i, a := range data.Assets {
+ n := strings.ToLower(a.Name)
+ if platform == "" || strings.Contains(n, platformLower) {
+ platformIdx = append(platformIdx, i)
+ }
+ }
+
+ pickBest := func(idxs []int) (string, int, bool) {
+ if len(idxs) == 0 {
+ return "", -1, false
+ }
+ // prefer arch matches within idxs; if arch was specified but
+ // no arch match exists among idxs, treat as no candidate.
+ var archIdx []int
+ if arch != "" {
+ aliases := archAliases(archLower)
+ for _, i := range idxs {
+ n := strings.ToLower(data.Assets[i].Name)
+ for _, ali := range aliases {
+ if strings.Contains(n, ali) {
+ archIdx = append(archIdx, i)
+ break
+ }
+ }
+ }
+ if len(archIdx) == 0 {
+ return "", -1, false
+ }
+ }
+ candidates := archIdx
+ if len(candidates) == 0 {
+ candidates = idxs
+ }
+
+ // extension preference
+ if platformLower == "windows" {
+ // prefer .zip only
+ for _, i := range candidates {
+ if isZip(strings.ToLower(data.Assets[i].Name)) {
+ return data.Assets[i].BrowserDownloadURL, i, true
+ }
+ }
+ // if no zip found, fallthrough to first candidate
+ return data.Assets[candidates[0]].BrowserDownloadURL, candidates[0], true
+ }
+
+ // non-windows: prefer tar.gz/tgz, then tar, then zip
+ for _, i := range candidates {
+ if isTarGz(strings.ToLower(data.Assets[i].Name)) {
+ return data.Assets[i].BrowserDownloadURL, i, true
+ }
+ }
+ for _, i := range candidates {
+ if isTar(strings.ToLower(data.Assets[i].Name)) {
+ return data.Assets[i].BrowserDownloadURL, i, true
+ }
+ }
+ for _, i := range candidates {
+ if isZip(strings.ToLower(data.Assets[i].Name)) {
+ return data.Assets[i].BrowserDownloadURL, i, true
+ }
+ }
+ // fallback to first candidate
+ return data.Assets[candidates[0]].BrowserDownloadURL, candidates[0], true
+ }
+
+ // Try platform matches first
+ if url, idx, ok := pickBest(platformIdx); ok {
+ // attempt to find checksum: prefer asset digest from API if present
+ if d := strings.TrimSpace(data.Assets[idx].Digest); d != "" {
+ dLower := strings.ToLower(d)
+ if strings.HasPrefix(dLower, "sha256:") {
+ hexpart := strings.TrimPrefix(dLower, "sha256:")
+ return url, hexpart, nil
+ }
+ // If digest already looks like a 64-hex, return it
+ if ok, _ := regexp.MatchString("(?i)^[a-f0-9]{64}$", dLower); ok {
+ return url, dLower, nil
+ }
+ }
+ // Look for checksum assets and verify by computing the asset's sha256.
+ for j, a := range data.Assets {
+ n := strings.ToLower(a.Name)
+ if strings.Contains(n, "sha256") ||
+ strings.Contains(n, "sha256sum") ||
+ strings.Contains(n, "checksums") ||
+ strings.HasSuffix(n, ".sha256") ||
+ strings.HasSuffix(n, ".sha256sum") {
+ resp2, err := httpClient.Get(data.Assets[j].BrowserDownloadURL)
+ if err != nil {
+ continue
+ }
+ bs, err := io.ReadAll(resp2.Body)
+ resp2.Body.Close()
+ if err != nil {
+ continue
+ }
+ if h, ok := findHashInChecksumContent(bs, url); ok {
+ return url, h, nil
+ }
+ }
+ }
+ // No checksum found for the selected platform asset -> error
+ return "", "", fmt.Errorf("no checksum found for asset %s", url)
+ }
+
+ // No platform match — require explicit platform+arch; fail fast.
+ return "", "", fmt.Errorf("no release asset matching platform %q and arch %q", platform, arch)
+}
+
+func looksLikeDirectAssetURL(u string) bool {
+ if u == "" {
+ return false
+ }
+ lower := strings.ToLower(u)
+ if strings.HasSuffix(lower, ".zip") ||
+ strings.HasSuffix(lower, ".tar.gz") ||
+ strings.HasSuffix(lower, ".tgz") ||
+ strings.HasSuffix(lower, ".tar") {
+ return true
+ }
+ if strings.Contains(lower, "/releases/download/") {
+ return true
+ }
+ return false
+}
+
+func buildReleaseAPIURL(releaseURL string) string {
+ if releaseURL == "" {
+ return ""
+ }
+ if strings.Contains(releaseURL, "api.github.com") {
+ return releaseURL
+ }
+ u, err := url.Parse(releaseURL)
+ if err != nil {
+ return ""
+ }
+ if u.Host != "github.com" {
+ return ""
+ }
+ parts := strings.Split(strings.Trim(u.Path, "/"), "/")
+ if len(parts) < 2 {
+ return ""
+ }
+ owner := parts[0]
+ repo := parts[1]
+ // if tag specified
+ if len(parts) >= 5 && parts[2] == "releases" && parts[3] == "tag" {
+ tag := parts[4]
+ return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/tags/%s", owner, repo, tag)
+ }
+ // default to latest
+ return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo)
+}
+
+// NOTE: helper functions to compute SHA256 from URL/path were removed
+// after refactoring to stream the download and verify the checksum
+// during the single download to avoid double-transfer.
+
+// findHashInChecksumContent attempts to locate a 64-hex SHA256 in the
+// checksum file content that corresponds to assetURL. It returns the
+// found hash (lowercase) and true, or "", false if not found.
+func findHashInChecksumContent(bs []byte, assetURL string) (string, bool) {
+ s := strings.ToLower(string(bs))
+ var assetBase string
+ if u, err := url.Parse(assetURL); err == nil {
+ assetBase = strings.ToLower(filepath.Base(u.Path))
+ } else {
+ assetBase = strings.ToLower(filepath.Base(assetURL))
+ }
+ re := regexp.MustCompile(`(?i)\b([a-f0-9]{64})\b`)
+ // prefer a line containing the asset filename
+ for _, line := range strings.Split(s, "\n") {
+ if strings.Contains(line, assetBase) {
+ if m := re.FindString(line); m != "" {
+ return m, true
+ }
+ }
+ }
+ // fallback: if there's exactly one unique 64-hex value, return it
+ matches := re.FindAllString(s, -1)
+ uniq := map[string]struct{}{}
+ for _, m := range matches {
+ uniq[m] = struct{}{}
+ }
+ if len(uniq) == 1 {
+ for k := range uniq {
+ return k, true
+ }
+ }
+ return "", false
+}
+
+// progressWriter implements io.Writer and prints a simple progress
+// line to stderr while bytes are written. It is intended to be used
+// as one writer in an io.MultiWriter so we can stream-to-disk, compute
+// the sha256, and update the progress display in a single pass.
+type progressWriter struct {
+ total int64
+ written int64
+ last time.Time
+}
+
+func (pw *progressWriter) Write(p []byte) (int, error) {
+ n := len(p)
+ pw.written += int64(n)
+ now := time.Now()
+ if pw.last.IsZero() || now.Sub(pw.last) >= 200*time.Millisecond || (pw.total > 0 && pw.written == pw.total) {
+ pw.print()
+ pw.last = now
+ }
+ return n, nil
+}
+
+func (pw *progressWriter) print() {
+ if pw.total > 0 {
+ pct := float64(pw.written) * 100.0 / float64(pw.total)
+ fmt.Fprintf(os.Stderr, "\rDownloading: %s / %s (%.1f%%)", humanBytes(pw.written), humanBytes(pw.total), pct)
+ } else {
+ fmt.Fprintf(os.Stderr, "\rDownloading: %s", humanBytes(pw.written))
+ }
+}
+
+func (pw *progressWriter) Finish() {
+ pw.print()
+ fmt.Fprintln(os.Stderr, "")
+}
+
+func humanBytes(n int64) string {
+ f := float64(n)
+ const (
+ KB = 1024.0
+ MB = KB * 1024.0
+ GB = MB * 1024.0
+ )
+ switch {
+ case f >= GB:
+ return fmt.Sprintf("%.2f GB", f/GB)
+ case f >= MB:
+ return fmt.Sprintf("%.2f MB", f/MB)
+ case f >= KB:
+ return fmt.Sprintf("%.2f KB", f/KB)
+ default:
+ return fmt.Sprintf("%d B", n)
+ }
+}
+
+// archAliases returns common name variants for an architecture string
+// so we can match release asset names like "x86_64" vs Go's "amd64".
+// archAliases returns name variants for an architecture string.
+// If `arch` is empty or matches the local runtime.GOARCH, prefer the
+// compile-time architecture aliases provided by archAliasesForLocal
+// (implemented per-architecture via build tags). For other `arch`
+// values we use a small synonyms map.
+func archAliases(arch string) []string {
+ a := strings.ToLower(arch)
+ if syns, ok := archSynonyms[a]; ok {
+ return syns
+ }
+ return []string{a}
+}
+
+var archSynonyms = map[string][]string{
+ "amd64": {"amd64", "x86_64", "x64"},
+ "x86_64": {"amd64", "x86_64", "x64"},
+ "x64": {"amd64", "x86_64", "x64"},
+ "386": {"386", "x86"},
+ "x86": {"386", "x86"},
+ "arm64": {"arm64", "aarch64"},
+ "aarch64": {"arm64", "aarch64"},
+ "arm": {"arm"},
+}
+
+func extractArchive(archivePath, destDir string) error {
+ lower := strings.ToLower(archivePath)
+ if strings.HasSuffix(lower, ".zip") {
+ return extractZip(archivePath, destDir)
+ }
+ // treat .tar.gz and .tgz as gzip+tar
+ if strings.HasSuffix(lower, ".tar.gz") || strings.HasSuffix(lower, ".tgz") {
+ return extractTarGz(archivePath, destDir)
+ }
+ if strings.HasSuffix(lower, ".tar") {
+ return extractTar(archivePath, destDir)
+ }
+ // fallback: try tar.gz
+ return extractTarGz(archivePath, destDir)
+}
+
+func extractZip(archivePath, destDir string) error {
+ r, err := zip.OpenReader(archivePath)
+ if err != nil {
+ return err
+ }
+ defer r.Close()
+ destClean := filepath.Clean(destDir)
+ for _, f := range r.File {
+ target := filepath.Clean(filepath.Join(destClean, f.Name))
+ if !strings.HasPrefix(target, destClean+string(os.PathSeparator)) && target != destClean {
+ return fmt.Errorf("path traversal detected: %s", f.Name)
+ }
+ if f.FileInfo().IsDir() {
+ if err := os.MkdirAll(target, f.FileInfo().Mode()); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
+ return err
+ }
+ rc, err := f.Open()
+ if err != nil {
+ return err
+ }
+ out, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, f.FileInfo().Mode())
+ if err != nil {
+ rc.Close()
+ return err
+ }
+ if _, err := io.Copy(out, rc); err != nil {
+ rc.Close()
+ out.Close()
+ return err
+ }
+ rc.Close()
+ out.Close()
+ }
+ return nil
+}
+
+func extractTarGz(archivePath, destDir string) error {
+ f, err := os.Open(archivePath)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ gzr, err := gzip.NewReader(f)
+ if err != nil {
+ return err
+ }
+ defer gzr.Close()
+ tr := tar.NewReader(gzr)
+ return extractTarFromReader(tr, destDir)
+}
+
+func extractTar(archivePath, destDir string) error {
+ f, err := os.Open(archivePath)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ tr := tar.NewReader(f)
+ return extractTarFromReader(tr, destDir)
+}
+
+// extractTarFromReader contains logic common to extracting entries from a
+// tar.Reader and is used by both extractTarGz and extractTar to avoid
+// duplicated code (golangci-lint: dupl).
+func extractTarFromReader(tr *tar.Reader, destDir string) error {
+ for {
+ hdr, err := tr.Next()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return err
+ }
+ target := filepath.Clean(filepath.Join(filepath.Clean(destDir), hdr.Name))
+ if !strings.HasPrefix(target, filepath.Clean(destDir)+string(os.PathSeparator)) &&
+ target != filepath.Clean(destDir) {
+ return fmt.Errorf("path traversal detected: %s", hdr.Name)
+ }
+ switch hdr.Typeflag {
+ case tar.TypeDir:
+ if err := os.MkdirAll(target, 0o755); err != nil {
+ return err
+ }
+ case tar.TypeReg:
+ if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
+ return err
+ }
+ out, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(hdr.Mode))
+ if err != nil {
+ return err
+ }
+ if _, err := io.Copy(out, tr); err != nil {
+ out.Close()
+ return err
+ }
+ out.Close()
+ }
+ }
+ return nil
+}
+
+func findBinaryInDir(dir, programName string) (string, error) {
+ wanted := []string{programName}
+ if runtime.GOOS == "windows" {
+ wanted = append([]string{programName + ".exe"}, wanted...)
+ } else {
+ // also accept programs with .exe in archives targeting windows
+ wanted = append(wanted, programName+".exe")
+ }
+
+ var found string
+ if err := filepath.WalkDir(dir, func(p string, d os.DirEntry, err error) error {
+ if err != nil || found != "" {
+ return err
+ }
+ if d.IsDir() {
+ return nil
+ }
+ base := filepath.Base(p)
+ for _, w := range wanted {
+ if base == w {
+ found = p
+ return io.EOF // use EOF to stop walking early
+ }
+ }
+ return nil
+ }); err != nil && err != io.EOF {
+ return "", err
+ }
+ if found == "" {
+ return "", fmt.Errorf("binary %q not found in archive", programName)
+ }
+ return found, nil
+}
+
+// NewUpdateCommand returns a cobra command that triggers UpdateSelfFromRelease.
+func NewUpdateCommand(binaryName string) *cobra.Command {
+ var urlStr, platform, arch string
+ cmd := &cobra.Command{
+ Use: "update",
+ Short: "Check and apply updates from GitHub releases",
+ RunE: func(cmd *cobra.Command, args []string) error {
+ if platform == "" {
+ platform = runtime.GOOS
+ }
+ if arch == "" {
+ arch = runtime.GOARCH
+ }
+ fmt.Printf("Current version: %s\n", config.FormatVersion())
+ if err := UpdateSelfFromRelease(urlStr, platform, arch, binaryName); err != nil {
+ return err
+ }
+ fmt.Println("Update applied; restart to use the new version.")
+ return nil
+ },
+ }
+ cmd.Flags().StringVarP(&urlStr, "url", "u", "", "Direct URL to download release asset or release page")
+ cmd.Flags().StringVar(&platform, "platform", "", "Target platform (default: runtime.GOOS)")
+ cmd.Flags().StringVar(&arch, "arch", "", "Target arch (default: runtime.GOARCH)")
+ return cmd
+}
diff --git a/pkg/updater/updater_test.go b/pkg/updater/updater_test.go
new file mode 100644
index 000000000..ff75432e4
--- /dev/null
+++ b/pkg/updater/updater_test.go
@@ -0,0 +1,97 @@
+package updater
+
+import (
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// matchesMagic checks whether the file at path looks like a platform binary
+// by inspecting magic bytes (ELF for linux, MZ for windows).
+func matchesMagic(path, platform string) (bool, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return false, err
+ }
+ defer f.Close()
+ buf := make([]byte, 4)
+ n, err := f.Read(buf)
+ if err != nil && err != io.EOF {
+ return false, err
+ }
+ if n >= 4 && buf[0] == 0x7f && buf[1] == 'E' && buf[2] == 'L' && buf[3] == 'F' {
+ return strings.Contains(platform, "linux"), nil
+ }
+ if n >= 2 && buf[0] == 'M' && buf[1] == 'Z' {
+ return strings.Contains(platform, "windows"), nil
+ }
+ return false, nil
+}
+
+// TestDownloadAndExtractRelease_RealPlatforms downloads the latest release
+// asset for multiple platform/arch combos and inspects the extracted
+// artifacts to ensure a binary-like file is present. This is a network test
+// and is skipped in short mode.
+func TestDownloadAndExtractRelease_RealPlatforms(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping network tests in short mode")
+ }
+
+ combos := []struct{ platform, arch string }{
+ {"linux", "amd64"},
+ {"linux", "arm64"},
+ {"windows", "amd64"},
+ {"windows", "arm64"},
+ }
+
+ apiURL := GetProdReleaseAPIURL()
+ for _, c := range combos {
+ t.Run(c.platform+"_"+c.arch, func(t *testing.T) {
+ assetURL, checksum, err := findAssetInfo(apiURL, c.platform, c.arch)
+ if err != nil {
+ // If no checksum could be located for this asset, skip this
+ // combo rather than failing — we require signed/checksummed
+ // releases for real-network tests.
+ t.Skipf("skipping %s/%s: %v", c.platform, c.arch, err)
+ }
+ t.Logf("asset URL: %s checksum: %s", assetURL, checksum)
+
+ // Pass the release API URL (not the direct asset URL) so
+ // DownloadAndExtractRelease can locate and verify the asset.
+ dir, err := DownloadAndExtractRelease(apiURL, c.platform, c.arch)
+ if err != nil {
+ t.Fatalf("DownloadAndExtractRelease failed for %s/%s: %v", c.platform, c.arch, err)
+ }
+ defer os.RemoveAll(dir)
+
+ var found bool
+ _ = filepath.WalkDir(dir, func(path string, d os.DirEntry, err error) error {
+ if err != nil || d.IsDir() {
+ return err
+ }
+ info, err := d.Info()
+ if err != nil {
+ return err
+ }
+ if info.Size() < 64 {
+ return nil
+ }
+ ok, err := matchesMagic(path, c.platform)
+ if err != nil {
+ return err
+ }
+ if ok {
+ found = true
+ t.Logf("found artifact: %s (size=%d)", path, info.Size())
+ // continue walking to list all
+ }
+ return nil
+ })
+ if !found {
+ t.Fatalf("no binary-like artifact found for %s/%s", c.platform, c.arch)
+ }
+ })
+ }
+}
diff --git a/pkg/voice/groq_transcriber.go b/pkg/voice/groq_transcriber.go
deleted file mode 100644
index b42e598f7..000000000
--- a/pkg/voice/groq_transcriber.go
+++ /dev/null
@@ -1,151 +0,0 @@
-package voice
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "os"
- "path/filepath"
- "time"
-
- "github.com/sipeed/picoclaw/pkg/logger"
- "github.com/sipeed/picoclaw/pkg/utils"
-)
-
-type GroqTranscriber struct {
- apiKey string
- apiBase string
- httpClient *http.Client
-}
-
-func NewGroqTranscriber(apiKey string) *GroqTranscriber {
- logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
-
- apiBase := "https://api.groq.com/openai/v1"
- return &GroqTranscriber{
- apiKey: apiKey,
- apiBase: apiBase,
- httpClient: &http.Client{
- Timeout: 60 * time.Second,
- },
- }
-}
-
-func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
- logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
-
- audioFile, err := os.Open(audioFilePath)
- if err != nil {
- logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
- return nil, fmt.Errorf("failed to open audio file: %w", err)
- }
- defer audioFile.Close()
-
- fileInfo, err := audioFile.Stat()
- if err != nil {
- logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
- return nil, fmt.Errorf("failed to get file info: %w", err)
- }
-
- logger.DebugCF("voice", "Audio file details", map[string]any{
- "size_bytes": fileInfo.Size(),
- "file_name": filepath.Base(audioFilePath),
- })
-
- var requestBody bytes.Buffer
- writer := multipart.NewWriter(&requestBody)
-
- part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
- if err != nil {
- logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to create form file: %w", err)
- }
-
- copied, err := io.Copy(part, audioFile)
- if err != nil {
- logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to copy file content: %w", err)
- }
-
- logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
-
- if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
- logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to write model field: %w", err)
- }
-
- if err = writer.WriteField("response_format", "json"); err != nil {
- logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to write response_format field: %w", err)
- }
-
- if err = writer.Close(); err != nil {
- logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to close multipart writer: %w", err)
- }
-
- url := t.apiBase + "/audio/transcriptions"
- req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
- if err != nil {
- logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- req.Header.Set("Content-Type", writer.FormDataContentType())
- req.Header.Set("Authorization", "Bearer "+t.apiKey)
-
- logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
- "url": url,
- "request_size_bytes": requestBody.Len(),
- "file_size_bytes": fileInfo.Size(),
- })
-
- resp, err := t.httpClient.Do(req)
- if err != nil {
- logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to send request: %w", err)
- }
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to read response: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- logger.ErrorCF("voice", "API error", map[string]any{
- "status_code": resp.StatusCode,
- "response": string(body),
- })
- return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
- }
-
- logger.DebugCF("voice", "Received response from Groq API", map[string]any{
- "status_code": resp.StatusCode,
- "response_size_bytes": len(body),
- })
-
- var result TranscriptionResponse
- if err := json.Unmarshal(body, &result); err != nil {
- logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
- "text_length": len(result.Text),
- "language": result.Language,
- "duration_seconds": result.Duration,
- "transcription_preview": utils.Truncate(result.Text, 50),
- })
-
- return &result, nil
-}
-
-func (t *GroqTranscriber) Name() string {
- return "groq"
-}
diff --git a/pkg/voice/groq_transcriber_test.go b/pkg/voice/groq_transcriber_test.go
deleted file mode 100644
index fdcaa7580..000000000
--- a/pkg/voice/groq_transcriber_test.go
+++ /dev/null
@@ -1,84 +0,0 @@
-package voice
-
-import (
- "context"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "os"
- "path/filepath"
- "testing"
-)
-
-var _ Transcriber = (*GroqTranscriber)(nil)
-
-func TestGroqTranscriberName(t *testing.T) {
- tr := NewGroqTranscriber("sk-test")
- if got := tr.Name(); got != "groq" {
- t.Errorf("Name() = %q, want %q", got, "groq")
- }
-}
-
-func TestGroqTranscribe(t *testing.T) {
- // Write a minimal fake audio file so the transcriber can open and send it.
- tmpDir := t.TempDir()
- audioPath := filepath.Join(tmpDir, "clip.ogg")
- if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil {
- t.Fatalf("failed to write fake audio file: %v", err)
- }
-
- t.Run("success", func(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/audio/transcriptions" {
- t.Errorf("unexpected path: %s", r.URL.Path)
- }
- if r.Header.Get("Authorization") != "Bearer sk-test" {
- t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization"))
- }
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(TranscriptionResponse{
- Text: "hello world",
- Language: "en",
- Duration: 1.5,
- })
- }))
- defer srv.Close()
-
- tr := NewGroqTranscriber("sk-test")
- tr.apiBase = srv.URL
-
- resp, err := tr.Transcribe(context.Background(), audioPath)
- if err != nil {
- t.Fatalf("Transcribe() error: %v", err)
- }
- if resp.Text != "hello world" {
- t.Errorf("Text = %q, want %q", resp.Text, "hello world")
- }
- if resp.Language != "en" {
- t.Errorf("Language = %q, want %q", resp.Language, "en")
- }
- })
-
- t.Run("api error", func(t *testing.T) {
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized)
- }))
- defer srv.Close()
-
- tr := NewGroqTranscriber("sk-bad")
- tr.apiBase = srv.URL
-
- _, err := tr.Transcribe(context.Background(), audioPath)
- if err == nil {
- t.Fatal("expected error for non-200 response, got nil")
- }
- })
-
- t.Run("missing file", func(t *testing.T) {
- tr := NewGroqTranscriber("sk-test")
- _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg"))
- if err == nil {
- t.Fatal("expected error for missing file, got nil")
- }
- })
-}
diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go
deleted file mode 100644
index f56fdeedd..000000000
--- a/pkg/voice/transcriber.go
+++ /dev/null
@@ -1,68 +0,0 @@
-package voice
-
-import (
- "context"
- "strings"
-
- "github.com/sipeed/picoclaw/pkg/config"
- "github.com/sipeed/picoclaw/pkg/providers"
-)
-
-type Transcriber interface {
- Name() string
- Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error)
-}
-
-type TranscriptionResponse struct {
- Text string `json:"text"`
- Language string `json:"language,omitempty"`
- Duration float64 `json:"duration,omitempty"`
-}
-
-func supportsAudioTranscription(model string) bool {
- protocol, _ := providers.ExtractProtocol(model)
-
- switch protocol {
- case "openai", "azure", "azure-openai",
- "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
- "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
- "vivgrid", "volcengine", "vllm", "qwen", "qwen-intl", "qwen-international", "dashscope-intl",
- "qwen-us", "dashscope-us", "mistral", "avian", "minimax", "longcat", "modelscope", "novita",
- "coding-plan", "alibaba-coding", "qwen-coding":
- // These protocols all go through the OpenAI-compatible or Azure provider path in
- // providers.CreateProviderFromConfig, so they are the only ones that can supply
- // the audio media payload shape expected by NewAudioModelTranscriber.
-
- // TODO: Further restrict this by modelID, since not every model under these
- // protocols supports audio transcription.
- return true
- default:
- return false
- }
-}
-
-// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or
-// nil if no supported transcription provider is configured.
-func DetectTranscriber(cfg *config.Config) Transcriber {
- if modelName := strings.TrimSpace(cfg.Voice.ModelName); modelName != "" {
- modelCfg, err := cfg.GetModelConfig(modelName)
- if err != nil {
- return nil
- }
- if supportsAudioTranscription(modelCfg.Model) {
- return NewAudioModelTranscriber(modelCfg)
- }
- }
-
- // ElevenLabs voice config (supports Scribe STT).
- if key := strings.TrimSpace(cfg.Voice.ElevenLabsAPIKey); key != "" {
- return NewElevenLabsTranscriber(key)
- }
- // Fall back to any model-list entry that uses the groq/ protocol.
- for _, mc := range cfg.ModelList {
- if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey() != "" {
- return NewGroqTranscriber(mc.APIKey())
- }
- }
- return nil
-}
diff --git a/scripts/build-macos-app.sh b/scripts/build-macos-app.sh
index 76cc72938..df2100aec 100755
--- a/scripts/build-macos-app.sh
+++ b/scripts/build-macos-app.sh
@@ -10,6 +10,8 @@ if [ -z "$EXECUTABLE" ]; then
exit 1
fi
+LAUNCHER_EXECUTABLE="picoclaw-launcher-${EXECUTABLE}"
+EXECUTABLE="picoclaw-${EXECUTABLE}"
echo "executable: $EXECUTABLE"
APP_NAME="PicoClaw Launcher"
@@ -33,17 +35,17 @@ mkdir -p "$APP_RESOURCES"
# Copy executable
echo "Copying executable..."
-if [ -f "./web/build/${APP_EXECUTABLE}" ]; then
- cp "./web/build/${APP_EXECUTABLE}" "${APP_MACOS}/"
+if [ -f "./build/${LAUNCHER_EXECUTABLE}" ]; then
+ cp "./build/${LAUNCHER_EXECUTABLE}" "${APP_MACOS}/${APP_EXECUTABLE}"
else
- echo "Error: ./web/build/${APP_EXECUTABLE} not found. Please build the web backend first."
- echo "Run: make build in web dir"
+ echo "Error: ./build/${LAUNCHER_EXECUTABLE} not found. Please build the web backend first."
+ echo "Run: make build-launcher"
exit 1
fi
-if [ -f "./build/picoclaw" ]; then
- cp "./build/picoclaw" "${APP_MACOS}/"
+if [ -f "./build/${EXECUTABLE}" ]; then
+ cp "./build/${EXECUTABLE}" "${APP_MACOS}/picoclaw"
else
- echo "Error: ./build/picoclaw not found. Please build the main file first."
+ echo "Error: ./build/${EXECUTABLE} not found. Please build the main file first."
echo "Run: make build"
exit 1
fi
@@ -76,10 +78,10 @@ cat > "${APP_CONTENTS}/Info.plist" << 'EOF'
NSSupportsAutomaticGraphicsSwitching
- LSRequiresCarbon
- LSUIElement
- 1
+
+ LSMinimumSystemVersion
+ 10.11
EOF
diff --git a/web/backend/api/auth.go b/web/backend/api/auth.go
index b9b4d5f66..22f7ec2c2 100644
--- a/web/backend/api/auth.go
+++ b/web/backend/api/auth.go
@@ -23,6 +23,7 @@ type LauncherAuthRouteOpts struct {
type LauncherAuthTokenHelp struct {
EnvVarName string `json:"env_var_name"`
LogFileAbs string `json:"log_file,omitempty"`
+ ConfigFileAbs string `json:"config_file,omitempty"`
TrayCopyMenu bool `json:"tray_copy_menu"`
ConsoleStdout bool `json:"console_stdout"`
}
diff --git a/web/backend/api/channels.go b/web/backend/api/channels.go
index dd4c9af3d..88e6ec27c 100644
--- a/web/backend/api/channels.go
+++ b/web/backend/api/channels.go
@@ -3,6 +3,8 @@ package api
import (
"encoding/json"
"net/http"
+
+ "github.com/sipeed/picoclaw/pkg/config"
)
type channelCatalogItem struct {
@@ -30,9 +32,22 @@ var channelCatalog = []channelCatalogItem{
{Name: "irc", ConfigKey: "irc"},
}
+type channelConfigResponse struct {
+ Config any `json:"config"`
+ ConfiguredSecrets []string `json:"configured_secrets"`
+ ConfigKey string `json:"config_key"`
+ Variant string `json:"variant,omitempty"`
+}
+
+type channelSecretPresence struct {
+ key string
+ configured bool
+}
+
// registerChannelRoutes binds read-only channel catalog endpoints to the ServeMux.
func (h *Handler) registerChannelRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/channels/catalog", h.handleListChannelCatalog)
+ mux.HandleFunc("GET /api/channels/{name}/config", h.handleGetChannelConfig)
}
// handleListChannelCatalog returns the channels supported by backend.
@@ -44,3 +59,172 @@ func (h *Handler) handleListChannelCatalog(w http.ResponseWriter, r *http.Reques
"channels": channelCatalog,
})
}
+
+// handleGetChannelConfig returns safe channel config plus secret presence metadata.
+//
+// GET /api/channels/{name}/config
+func (h *Handler) handleGetChannelConfig(w http.ResponseWriter, r *http.Request) {
+ channelName := r.PathValue("name")
+ item, ok := findChannelCatalogItem(channelName)
+ if !ok {
+ http.Error(w, "Channel not found", http.StatusNotFound)
+ return
+ }
+
+ cfg, err := config.LoadConfig(h.configPath)
+ if err != nil {
+ http.Error(w, "Failed to load config", http.StatusInternalServerError)
+ return
+ }
+
+ resp := buildChannelConfigResponse(cfg, item)
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(resp); err != nil {
+ http.Error(w, "Failed to encode response", http.StatusInternalServerError)
+ }
+}
+
+func findChannelCatalogItem(name string) (channelCatalogItem, bool) {
+ for _, item := range channelCatalog {
+ if item.Name == name {
+ return item, true
+ }
+ }
+ return channelCatalogItem{}, false
+}
+
+func buildChannelConfigResponse(cfg *config.Config, item channelCatalogItem) channelConfigResponse {
+ resp := channelConfigResponse{
+ ConfiguredSecrets: []string{},
+ ConfigKey: item.ConfigKey,
+ Variant: item.Variant,
+ }
+
+ switch item.Name {
+ case "weixin":
+ channelCfg := cfg.Channels.Weixin
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
+ )
+ channelCfg.Token = config.SecureString{}
+ resp.Config = channelCfg
+ case "telegram":
+ channelCfg := cfg.Channels.Telegram
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
+ )
+ channelCfg.Token = config.SecureString{}
+ resp.Config = channelCfg
+ case "discord":
+ channelCfg := cfg.Channels.Discord
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
+ )
+ channelCfg.Token = config.SecureString{}
+ resp.Config = channelCfg
+ case "slack":
+ channelCfg := cfg.Channels.Slack
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "bot_token", configured: channelCfg.BotToken.String() != ""},
+ channelSecretPresence{key: "app_token", configured: channelCfg.AppToken.String() != ""},
+ )
+ channelCfg.BotToken = config.SecureString{}
+ channelCfg.AppToken = config.SecureString{}
+ resp.Config = channelCfg
+ case "feishu":
+ channelCfg := cfg.Channels.Feishu
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
+ channelSecretPresence{key: "encrypt_key", configured: channelCfg.EncryptKey.String() != ""},
+ channelSecretPresence{key: "verification_token", configured: channelCfg.VerificationToken.String() != ""},
+ )
+ channelCfg.AppSecret = config.SecureString{}
+ channelCfg.EncryptKey = config.SecureString{}
+ channelCfg.VerificationToken = config.SecureString{}
+ resp.Config = channelCfg
+ case "dingtalk":
+ channelCfg := cfg.Channels.DingTalk
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "client_secret", configured: channelCfg.ClientSecret.String() != ""},
+ )
+ channelCfg.ClientSecret = config.SecureString{}
+ resp.Config = channelCfg
+ case "line":
+ channelCfg := cfg.Channels.LINE
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "channel_secret", configured: channelCfg.ChannelSecret.String() != ""},
+ channelSecretPresence{
+ key: "channel_access_token",
+ configured: channelCfg.ChannelAccessToken.String() != "",
+ },
+ )
+ channelCfg.ChannelSecret = config.SecureString{}
+ channelCfg.ChannelAccessToken = config.SecureString{}
+ resp.Config = channelCfg
+ case "qq":
+ channelCfg := cfg.Channels.QQ
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
+ )
+ channelCfg.AppSecret = config.SecureString{}
+ resp.Config = channelCfg
+ case "onebot":
+ channelCfg := cfg.Channels.OneBot
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
+ )
+ channelCfg.AccessToken = config.SecureString{}
+ resp.Config = channelCfg
+ case "wecom":
+ channelCfg := cfg.Channels.WeCom
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "secret", configured: channelCfg.Secret.String() != ""},
+ )
+ channelCfg.Secret = config.SecureString{}
+ resp.Config = channelCfg
+ case "whatsapp", "whatsapp_native":
+ resp.Config = cfg.Channels.WhatsApp
+ case "pico":
+ channelCfg := cfg.Channels.Pico
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
+ )
+ channelCfg.Token = config.SecureString{}
+ resp.Config = channelCfg
+ case "maixcam":
+ resp.Config = cfg.Channels.MaixCam
+ case "matrix":
+ channelCfg := cfg.Channels.Matrix
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
+ )
+ channelCfg.AccessToken = config.SecureString{}
+ resp.Config = channelCfg
+ case "irc":
+ channelCfg := cfg.Channels.IRC
+ resp.ConfiguredSecrets = collectConfiguredSecrets(
+ channelSecretPresence{key: "password", configured: channelCfg.Password.String() != ""},
+ channelSecretPresence{key: "nickserv_password", configured: channelCfg.NickServPassword.String() != ""},
+ channelSecretPresence{key: "sasl_password", configured: channelCfg.SASLPassword.String() != ""},
+ )
+ channelCfg.Password = config.SecureString{}
+ channelCfg.NickServPassword = config.SecureString{}
+ channelCfg.SASLPassword = config.SecureString{}
+ resp.Config = channelCfg
+ default:
+ resp.Config = map[string]any{}
+ }
+
+ return resp
+}
+
+func collectConfiguredSecrets(secrets ...channelSecretPresence) []string {
+ configured := make([]string, 0, len(secrets))
+ for _, secret := range secrets {
+ if secret.configured {
+ configured = append(configured, secret.key)
+ }
+ }
+ return configured
+}
diff --git a/web/backend/api/channels_test.go b/web/backend/api/channels_test.go
new file mode 100644
index 000000000..73a4b39f3
--- /dev/null
+++ b/web/backend/api/channels_test.go
@@ -0,0 +1,87 @@
+package api
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func TestHandleGetChannelConfig_ReturnsSecretPresenceWithoutLeakingSecrets(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ cfg.Channels.Feishu.Enabled = true
+ cfg.Channels.Feishu.AppID = "cli_test_app"
+ cfg.Channels.Feishu.AppSecret = *config.NewSecureString("feishu-secret-from-security")
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/channels/feishu/config", nil)
+ rec := httptest.NewRecorder()
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf(
+ "GET /api/channels/feishu/config status = %d, want %d, body=%s",
+ rec.Code,
+ http.StatusOK,
+ rec.Body.String(),
+ )
+ }
+ if strings.Contains(rec.Body.String(), "feishu-secret-from-security") {
+ t.Fatalf("response leaked secret value: %s", rec.Body.String())
+ }
+
+ var resp struct {
+ Config map[string]any `json:"config"`
+ ConfiguredSecrets []string `json:"configured_secrets"`
+ ConfigKey string `json:"config_key"`
+ Variant string `json:"variant"`
+ }
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("json.Unmarshal() error = %v", err)
+ }
+
+ if got := resp.ConfigKey; got != "feishu" {
+ t.Fatalf("config_key = %q, want %q", got, "feishu")
+ }
+ if got := resp.Config["app_id"]; got != "cli_test_app" {
+ t.Fatalf("config.app_id = %#v, want %q", got, "cli_test_app")
+ }
+ if _, exists := resp.Config["app_secret"]; exists {
+ t.Fatalf("config should omit app_secret, got %#v", resp.Config["app_secret"])
+ }
+ if len(resp.ConfiguredSecrets) != 1 || resp.ConfiguredSecrets[0] != "app_secret" {
+ t.Fatalf("configured_secrets = %#v, want [\"app_secret\"]", resp.ConfiguredSecrets)
+ }
+}
+
+func TestHandleGetChannelConfig_ReturnsNotFoundForUnknownChannel(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/channels/not-a-channel/config", nil)
+ rec := httptest.NewRecorder()
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusNotFound {
+ t.Fatalf("GET /api/channels/not-a-channel/config status = %d, want %d", rec.Code, http.StatusNotFound)
+ }
+}
diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go
index 6f5f5dd5d..139f2c8c8 100644
--- a/web/backend/api/gateway.go
+++ b/web/backend/api/gateway.go
@@ -357,7 +357,13 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
return true
}
- return cmd.Process.Signal(syscall.Signal(0)) == nil
+ err := cmd.Process.Signal(syscall.Signal(0))
+ if err == nil {
+ return true
+ }
+ var errno syscall.Errno
+ // EPERM means the process exists but cannot be signaled by this user.
+ return errors.As(err, &errno) && errno == syscall.EPERM
}
func setGatewayRuntimeStatusLocked(status string) {
@@ -401,6 +407,15 @@ func gatewayStatusWithoutHealthLocked() string {
return "error"
}
if gateway.runtimeStatus == "running" {
+ // For attached processes there is no waiter goroutine; degrade stale
+ // running state once the tracked process exits.
+ if !isCmdProcessAliveLocked(gateway.cmd) {
+ gateway.cmd = nil
+ gateway.owned = false
+ gateway.bootDefaultModel = ""
+ gateway.bootConfigSignature = ""
+ return "stopped"
+ }
return "running"
}
if gateway.runtimeStatus == "error" {
@@ -614,6 +629,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Start a goroutine to probe pidFile and health, update runtime state once ready.
go func() {
+ healthConfirmed := false
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
gateway.mu.Lock()
@@ -628,6 +644,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.mu.Lock()
if gateway.cmd == cmd {
gateway.pidData = pd
+ gateway.picoToken = cfg.Channels.Pico.Token.String()
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
@@ -647,7 +664,11 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
- return
+ if !healthConfirmed {
+ healthConfirmed = true
+ logger.InfoC("gateway", "Gateway health endpoint reachable; waiting for pid file")
+ }
+ continue
}
}
}()
@@ -922,34 +943,19 @@ func (h *Handler) gatewayStatusData() map[string]any {
data["pid"] = pidData.PID
gateway.mu.Unlock()
} else {
- // Fallback: probe health endpoint to get pid and status
- _, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
- if err != nil {
- gateway.mu.Lock()
- data["gateway_status"] = gatewayStatusWithoutHealthLocked()
+ // Intentionally skip health probe here; the startup goroutine
+ // (startGatewayLocked) already handles liveness detection via
+ // pidFile polling and health fallback.
+ gateway.mu.Lock()
+ status := gatewayStatusWithoutHealthLocked()
+ data["gateway_status"] = status
+ // Keep last known pidData while gateway is still in a transient
+ // running state; otherwise websocket proxy may lose auth token
+ // during short pid-file races.
+ if status == "stopped" || status == "error" {
gateway.pidData = nil
- gateway.mu.Unlock()
- logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
- } else {
- logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
- if statusCode != http.StatusOK {
- gateway.mu.Lock()
- setGatewayRuntimeStatusLocked("error")
- gateway.pidData = nil
- gateway.mu.Unlock()
- data["gateway_status"] = "error"
- data["status_code"] = statusCode
- } else {
- gateway.mu.Lock()
- setGatewayRuntimeStatusLocked("running")
- bootDefaultModel := gateway.bootDefaultModel
- if bootDefaultModel != "" {
- data["boot_default_model"] = bootDefaultModel
- }
- data["gateway_status"] = "running"
- gateway.mu.Unlock()
- }
}
+ gateway.mu.Unlock()
}
gatewayStatus, _ := data["gateway_status"].(string)
diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go
index fc8ee13f3..1f5f13e27 100644
--- a/web/backend/api/gateway_test.go
+++ b/web/backend/api/gateway_test.go
@@ -15,8 +15,11 @@ import (
"testing"
"time"
+ "github.com/stretchr/testify/require"
+
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
+ ppid "github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/web/backend/utils"
)
@@ -444,7 +447,93 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
}
}
-func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
+func TestGatewayStatusKeepsPidDataWhileTrackedProcessAliveWhenPidFileUnavailable(t *testing.T) {
+ resetGatewayTestState(t)
+
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ cmd := startLongRunningProcess(t)
+ t.Cleanup(func() {
+ if cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ }
+ _ = cmd.Wait()
+ })
+
+ gateway.mu.Lock()
+ gateway.cmd = cmd
+ gateway.pidData = &ppid.PidFileData{
+ PID: cmd.Process.Pid,
+ Token: "existing-token",
+ }
+ setGatewayRuntimeStatusLocked("running")
+ gateway.mu.Unlock()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ gateway.mu.Lock()
+ defer gateway.mu.Unlock()
+ if gateway.pidData == nil {
+ t.Fatal("gateway.pidData was cleared while runtime status remained running")
+ }
+}
+
+func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing(t *testing.T) {
+ resetGatewayTestState(t)
+
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ cmd := startLongRunningProcess(t)
+ if cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ }
+ _ = cmd.Wait()
+
+ gateway.mu.Lock()
+ gateway.cmd = cmd
+ gateway.pidData = &ppid.PidFileData{
+ PID: cmd.Process.Pid,
+ Token: "stale-token",
+ }
+ setGatewayRuntimeStatusLocked("running")
+ gateway.mu.Unlock()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ var body map[string]any
+ if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
+ t.Fatalf("unmarshal response: %v", err)
+ }
+ if got := body["gateway_status"]; got != "stopped" {
+ t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
+ }
+
+ gateway.mu.Lock()
+ defer gateway.mu.Unlock()
+ if gateway.pidData != nil {
+ t.Fatal("gateway.pidData should be cleared when tracked process has exited")
+ }
+}
+
+func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
@@ -468,6 +557,9 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
}
+ _, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0)
+ require.NoError(t, err)
+
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
@@ -513,6 +605,8 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
+ _, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0)
+ require.NoError(t, err)
bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
diff --git a/web/backend/api/launcher_config.go b/web/backend/api/launcher_config.go
index e149d5671..d16cd9267 100644
--- a/web/backend/api/launcher_config.go
+++ b/web/backend/api/launcher_config.go
@@ -4,14 +4,16 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "strings"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
)
type launcherConfigPayload struct {
- Port int `json:"port"`
- Public bool `json:"public"`
- AllowedCIDRs []string `json:"allowed_cidrs"`
+ Port int `json:"port"`
+ Public bool `json:"public"`
+ AllowedCIDRs []string `json:"allowed_cidrs"`
+ LauncherToken string `json:"launcher_token"`
}
func (h *Handler) registerLauncherConfigRoutes(mux *http.ServeMux) {
@@ -48,9 +50,10 @@ func (h *Handler) handleGetLauncherConfig(w http.ResponseWriter, r *http.Request
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(launcherConfigPayload{
- Port: cfg.Port,
- Public: cfg.Public,
- AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
+ Port: cfg.Port,
+ Public: cfg.Public,
+ AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
+ LauncherToken: cfg.LauncherToken,
})
}
@@ -62,9 +65,10 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
}
cfg := launcherconfig.Config{
- Port: payload.Port,
- Public: payload.Public,
- AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
+ Port: payload.Port,
+ Public: payload.Public,
+ AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
+ LauncherToken: strings.TrimSpace(payload.LauncherToken),
}
if err := launcherconfig.Validate(cfg); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -78,8 +82,9 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(launcherConfigPayload{
- Port: cfg.Port,
- Public: cfg.Public,
- AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
+ Port: cfg.Port,
+ Public: cfg.Public,
+ AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
+ LauncherToken: cfg.LauncherToken,
})
}
diff --git a/web/backend/api/launcher_config_test.go b/web/backend/api/launcher_config_test.go
index 0d6af823c..4e0acf5d0 100644
--- a/web/backend/api/launcher_config_test.go
+++ b/web/backend/api/launcher_config_test.go
@@ -34,6 +34,9 @@ func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
if got.Port != 19999 || !got.Public {
t.Fatalf("response = %+v, want port=19999 public=true", got)
}
+ if got.LauncherToken != "" {
+ t.Fatalf("response launcher_token = %q, want empty", got.LauncherToken)
+ }
if len(got.AllowedCIDRs) != 1 || got.AllowedCIDRs[0] != "192.168.1.0/24" {
t.Fatalf("response allowed_cidrs = %v, want [192.168.1.0/24]", got.AllowedCIDRs)
}
@@ -50,7 +53,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
req := httptest.NewRequest(
http.MethodPut,
"/api/system/launcher-config",
- strings.NewReader(`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"]}`),
+ strings.NewReader(
+ `{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"],"launcher_token":"saved-token"}`,
+ ),
)
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
@@ -67,6 +72,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
if cfg.Port != 18080 || !cfg.Public {
t.Fatalf("saved config = %+v, want port=18080 public=true", cfg)
}
+ if cfg.LauncherToken != "saved-token" {
+ t.Fatalf("saved launcher_token = %q, want %q", cfg.LauncherToken, "saved-token")
+ }
if len(cfg.AllowedCIDRs) != 1 || cfg.AllowedCIDRs[0] != "192.168.1.0/24" {
t.Fatalf("saved config allowed_cidrs = %v, want [192.168.1.0/24]", cfg.AllowedCIDRs)
}
diff --git a/web/backend/api/model_status.go b/web/backend/api/model_status.go
index 160c4d257..98bd501f5 100644
--- a/web/backend/api/model_status.go
+++ b/web/backend/api/model_status.go
@@ -1,19 +1,36 @@
package api
import (
+ "context"
"encoding/json"
"fmt"
+ "hash/fnv"
"net"
"net/http"
"net/url"
+ "strconv"
"strings"
+ "sync"
"time"
+ "golang.org/x/sync/singleflight"
+
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
-const modelProbeTimeout = 800 * time.Millisecond
+const (
+ modelProbeTimeout = 800 * time.Millisecond
+ modelProbeSuccessBaseInterval = 2 * time.Second
+ modelProbeSuccessMaxInterval = 60 * time.Second
+ modelProbeFailureBaseInterval = 1 * time.Second
+ modelProbeFailureMaxInterval = 30 * time.Second
+ modelProbeBackoffMaxShift = 8
+ modelProbeCacheMaxEntries = 1024
+ modelProbeCacheEntryTTL = 30 * time.Minute
+ modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10
+ modelProbeTTLGCInterval = 1 * time.Minute
+)
const (
modelStatusAvailable = "available"
@@ -30,8 +47,41 @@ var (
probeTCPServiceFunc = probeTCPService
probeOllamaModelFunc = probeOllamaModel
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
+ modelProbeNowFunc = time.Now
+ modelProbeState = newModelProbeCacheState()
)
+type modelProbeCacheState struct {
+ mu sync.RWMutex
+ cache map[string]*modelProbeCacheEntry
+ group singleflight.Group
+ nextTTLGCAt time.Time
+}
+
+type modelProbeCacheEntry struct {
+ lastResult bool
+ hasResult bool
+ successStreak int
+ failureStreak int
+ nextProbeAt time.Time
+ updatedAt time.Time
+}
+
+func newModelProbeCacheState() *modelProbeCacheState {
+ return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}}
+}
+
+func resetModelProbeCache() {
+ modelProbeState.resetForTest()
+}
+
+func (s *modelProbeCacheState) resetForTest() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.cache = map[string]*modelProbeCacheEntry{}
+ s.nextTTLGCAt = time.Time{}
+}
+
func hasModelConfiguration(m *config.ModelConfig) bool {
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
apiKey := strings.TrimSpace(m.APIKey())
@@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
}
func probeLocalModelAvailability(m *config.ModelConfig) bool {
+ cacheKey := modelProbeCacheKey(m)
+ return modelProbeState.probe(cacheKey, func() bool {
+ return runLocalModelProbe(m)
+ })
+}
+
+func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool {
+ now := modelProbeNowFunc()
+ if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
+ return cachedResult
+ }
+
+ v, _, _ := s.group.Do(cacheKey, func() (any, error) {
+ now = modelProbeNowFunc()
+ if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
+ return cachedResult, nil
+ }
+
+ result := probeFunc()
+ s.setCachedResult(cacheKey, result, now)
+ return result, nil
+ })
+
+ result, _ := v.(bool)
+ return result
+}
+
+func runLocalModelProbe(m *config.ModelConfig) bool {
apiBase := modelProbeAPIBase(m)
protocol, modelID := splitModel(m.Model)
switch protocol {
@@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool {
}
}
+func modelProbeCacheKey(m *config.ModelConfig) string {
+ protocol, modelID := splitModel(m.Model)
+
+ apiBaseRaw := modelProbeAPIBase(m)
+ apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
+ apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey())
+
+ var b strings.Builder
+ b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8)
+ b.WriteString(protocol)
+ b.WriteByte('|')
+ b.WriteString(modelID)
+ b.WriteByte('|')
+ b.WriteString(apiBase)
+ b.WriteByte('|')
+ b.WriteString(apiKeyFingerprint)
+
+ return b.String()
+}
+
+func modelProbeAPIKeyFingerprint(raw string) string {
+ apiKey := strings.TrimSpace(raw)
+ if apiKey == "" {
+ return "none"
+ }
+
+ h := fnv.New64a()
+ _, _ = h.Write([]byte(apiKey))
+ return strconv.FormatUint(h.Sum64(), 36)
+}
+
+func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ entry, ok := s.cache[cacheKey]
+ if !ok || !entry.hasResult {
+ return false, false
+ }
+ if now.Before(entry.nextProbeAt) {
+ return entry.lastResult, true
+ }
+ return false, false
+}
+
+func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) {
+ s.mu.Lock()
+
+ entry, ok := s.cache[cacheKey]
+ if !ok {
+ entry = &modelProbeCacheEntry{}
+ s.cache[cacheKey] = entry
+ }
+
+ entry.lastResult = result
+ entry.hasResult = true
+ entry.updatedAt = now
+
+ var delay time.Duration
+ if result {
+ entry.successStreak++
+ entry.failureStreak = 0
+ delay = modelProbeBackoffDelay(
+ modelProbeSuccessBaseInterval,
+ modelProbeSuccessMaxInterval,
+ entry.successStreak,
+ )
+ } else {
+ entry.failureStreak++
+ entry.successStreak = 0
+ delay = modelProbeBackoffDelay(
+ modelProbeFailureBaseInterval,
+ modelProbeFailureMaxInterval,
+ entry.failureStreak,
+ )
+ }
+
+ entry.nextProbeAt = now.Add(delay)
+
+ shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt))
+ if shouldRunTTLGC {
+ s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval)
+ }
+ shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries
+ s.mu.Unlock()
+
+ if shouldRunTTLGC || shouldRunSizeGC {
+ s.gc(now, shouldRunTTLGC)
+ }
+}
+
+func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) {
+ type evictionCandidate struct {
+ key string
+ updatedAt time.Time
+ }
+
+ var expireBefore time.Time
+ if runTTL && modelProbeCacheEntryTTL > 0 {
+ expireBefore = now.Add(-modelProbeCacheEntryTTL)
+ }
+
+ s.mu.RLock()
+ cacheLen := len(s.cache)
+ if cacheLen == 0 {
+ s.mu.RUnlock()
+ return
+ }
+
+ expiredKeys := make([]string, 0)
+ if !expireBefore.IsZero() {
+ expiredKeys = make([]string, 0, min(cacheLen/8+1, 64))
+ for key, entry := range s.cache {
+ if entry.updatedAt.Before(expireBefore) {
+ expiredKeys = append(expiredKeys, key)
+ }
+ }
+ }
+
+ effectiveLen := cacheLen - len(expiredKeys)
+ removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0)
+
+ candidates := make([]evictionCandidate, 0)
+ if removeCount > 0 {
+ candidates = make([]evictionCandidate, 0, effectiveLen)
+ for key, entry := range s.cache {
+ if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) {
+ continue
+ }
+ candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt})
+ }
+ }
+ s.mu.RUnlock()
+
+ if len(expiredKeys) == 0 && len(candidates) == 0 {
+ return
+ }
+
+ toEvict := map[string]time.Time{}
+ for i := 0; i < removeCount && len(candidates) > 0; i++ {
+ oldest := 0
+ for j := 1; j < len(candidates); j++ {
+ if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) {
+ oldest = j
+ }
+ }
+ victim := candidates[oldest]
+ toEvict[victim.key] = victim.updatedAt
+ candidates[oldest] = candidates[len(candidates)-1]
+ candidates = candidates[:len(candidates)-1]
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !expireBefore.IsZero() {
+ for _, key := range expiredKeys {
+ entry, ok := s.cache[key]
+ if ok && entry.updatedAt.Before(expireBefore) {
+ delete(s.cache, key)
+ }
+ }
+ }
+
+ for key, victimUpdatedAt := range toEvict {
+ entry, ok := s.cache[key]
+ if ok && !entry.updatedAt.After(victimUpdatedAt) {
+ delete(s.cache, key)
+ }
+ }
+}
+
+func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration {
+ if streak <= 0 {
+ streak = 1
+ }
+
+ shift := min(streak-1, modelProbeBackoffMaxShift)
+
+ delay := base * time.Duration(1< 0 && (delay > maxDelay || delay < 0) {
+ return maxDelay
+ }
+ if delay <= 0 {
+ return base
+ }
+ return delay
+}
+
func modelProbeAPIBase(m *config.ModelConfig) string {
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
return normalizeModelProbeAPIBase(apiBase)
@@ -207,7 +474,11 @@ func probeTCPService(raw string) bool {
return false
}
- conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
+ defer cancel()
+
+ dialer := &net.Dialer{}
+ conn, err := dialer.DialContext(ctx, "tcp", hostPort)
if err != nil {
return false
}
@@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool {
}
func getJSON(rawURL string, out any, apiKey string) error {
- req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+ ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return err
}
@@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
- client := &http.Client{Timeout: modelProbeTimeout}
+ client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
@@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool {
if candidate == "" || want == "" {
return false
}
- if strings.EqualFold(candidate, want) {
- return true
+
+ candidateBase, candidateTag := splitOllamaModel(candidate)
+ wantBase, wantTag := splitOllamaModel(want)
+ if candidateBase == "" || wantBase == "" {
+ return false
}
- base, _, _ := strings.Cut(candidate, ":")
- return strings.EqualFold(base, want)
+ if candidateTag == "" {
+ candidateTag = "latest"
+ }
+ if wantTag == "" {
+ wantTag = "latest"
+ }
+
+ return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag)
+}
+
+func splitOllamaModel(raw string) (base, tag string) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return "", ""
+ }
+
+ base, tag, _ = strings.Cut(raw, ":")
+ return strings.TrimSpace(base), strings.TrimSpace(tag)
}
diff --git a/web/backend/api/model_status_test.go b/web/backend/api/model_status_test.go
index bfeadf1fe..d5463a856 100644
--- a/web/backend/api/model_status_test.go
+++ b/web/backend/api/model_status_test.go
@@ -3,7 +3,10 @@ package api
import (
"net/http"
"net/http/httptest"
+ "sync"
+ "sync/atomic"
"testing"
+ "time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin
t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio")
}
}
+
+func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) {
+ base := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ AuthMethod: "local",
+ ConnectMode: "",
+ }
+
+ m1 := *base
+ m1.SetAPIKey("key-a")
+ m2 := *base
+ m2.SetAPIKey("key-b")
+
+ k1 := modelProbeCacheKey(&m1)
+ k2 := modelProbeCacheKey(&m2)
+ if k1 == k2 {
+ t.Fatal("modelProbeCacheKey() should differ when api key changes")
+ }
+}
+
+func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) {
+ m1 := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ }
+ m2 := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1/",
+ }
+
+ k1 := modelProbeCacheKey(m1)
+ k2 := modelProbeCacheKey(m2)
+ if k1 != k2 {
+ t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2)
+ }
+}
+
+func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) {
+ base := &config.ModelConfig{
+ ModelName: "vllm-one",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ AuthMethod: "none",
+ ConnectMode: "http",
+ }
+ changed := &config.ModelConfig{
+ ModelName: "vllm-two",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ AuthMethod: "token",
+ ConnectMode: "ws",
+ }
+
+ k1 := modelProbeCacheKey(base)
+ k2 := modelProbeCacheKey(changed)
+ if k1 != k2 {
+ t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2)
+ }
+}
+
+func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) {
+ resetModelProbeHooks(t)
+
+ now := time.Unix(1700000000, 0)
+ modelProbeNowFunc = func() time.Time { return now }
+
+ calls := 0
+ probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
+ calls++
+ return true
+ }
+
+ model := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ }
+
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("first probe result = false, want true")
+ }
+ if calls != 1 {
+ t.Fatalf("probe calls after first probe = %d, want 1", calls)
+ }
+
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("cached probe result = false, want true")
+ }
+ if calls != 1 {
+ t.Fatalf("probe calls after immediate re-check = %d, want 1", calls)
+ }
+
+ now = now.Add(modelProbeSuccessBaseInterval)
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("second probe result = false, want true")
+ }
+ if calls != 2 {
+ t.Fatalf("probe calls after success backoff window = %d, want 2", calls)
+ }
+
+ now = now.Add(modelProbeSuccessBaseInterval)
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("cached result after doubled backoff = false, want true")
+ }
+ if calls != 2 {
+ t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls)
+ }
+
+ now = now.Add(modelProbeSuccessBaseInterval)
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("third probe result = false, want true")
+ }
+ if calls != 3 {
+ t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls)
+ }
+}
+
+func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) {
+ resetModelProbeHooks(t)
+
+ now := time.Unix(1700000100, 0)
+ modelProbeNowFunc = func() time.Time { return now }
+
+ calls := 0
+ probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
+ calls++
+ return false
+ }
+
+ model := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ }
+
+ if probeLocalModelAvailability(model) {
+ t.Fatal("first probe result = true, want false")
+ }
+ if calls != 1 {
+ t.Fatalf("probe calls after first failure = %d, want 1", calls)
+ }
+
+ if probeLocalModelAvailability(model) {
+ t.Fatal("cached failed probe result = true, want false")
+ }
+ if calls != 1 {
+ t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls)
+ }
+
+ now = now.Add(modelProbeFailureBaseInterval)
+ if probeLocalModelAvailability(model) {
+ t.Fatal("second failed probe result = true, want false")
+ }
+ if calls != 2 {
+ t.Fatalf("probe calls after failure backoff window = %d, want 2", calls)
+ }
+
+ now = now.Add(modelProbeFailureBaseInterval)
+ if probeLocalModelAvailability(model) {
+ t.Fatal("cached failure after doubled backoff = true, want false")
+ }
+ if calls != 2 {
+ t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls)
+ }
+
+ now = now.Add(modelProbeFailureBaseInterval)
+ if probeLocalModelAvailability(model) {
+ t.Fatal("third failed probe result = true, want false")
+ }
+ if calls != 3 {
+ t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls)
+ }
+}
+
+func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) {
+ resetModelProbeHooks(t)
+
+ now := time.Unix(1700000200, 0)
+ modelProbeNowFunc = func() time.Time { return now }
+
+ results := []bool{true, false, false}
+ index := 0
+ probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
+ if index >= len(results) {
+ return false
+ }
+ result := results[index]
+ index++
+ return result
+ }
+
+ model := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ }
+
+ if !probeLocalModelAvailability(model) {
+ t.Fatal("first probe result = false, want true")
+ }
+
+ now = now.Add(modelProbeSuccessBaseInterval)
+ if probeLocalModelAvailability(model) {
+ t.Fatal("second probe result = true, want false")
+ }
+
+ now = now.Add(modelProbeFailureBaseInterval)
+ if probeLocalModelAvailability(model) {
+ t.Fatal("third probe result = true, want false")
+ }
+
+ if index != 3 {
+ t.Fatalf("probe invocations = %d, want 3", index)
+ }
+}
+
+func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) {
+ resetModelProbeHooks(t)
+
+ now := time.Unix(1700000300, 0)
+ modelProbeNowFunc = func() time.Time { return now }
+
+ var calls int32
+ probeStarted := make(chan struct{})
+ releaseProbe := make(chan struct{})
+
+ probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
+ if atomic.AddInt32(&calls, 1) == 1 {
+ close(probeStarted)
+ }
+ <-releaseProbe
+ return true
+ }
+
+ model := &config.ModelConfig{
+ ModelName: "local-vllm",
+ Model: "vllm/custom-model",
+ APIBase: "http://127.0.0.1:8000/v1",
+ }
+
+ const workers = 8
+ var wg sync.WaitGroup
+ results := make(chan bool, workers)
+ workerStarted := make(chan struct{}, workers)
+
+ for range workers {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ workerStarted <- struct{}{}
+ results <- probeLocalModelAvailability(model)
+ }()
+ }
+
+ for range workers {
+ <-workerStarted
+ }
+
+ select {
+ case <-probeStarted:
+ case <-time.After(200 * time.Millisecond):
+ t.Fatal("probe did not start in time")
+ }
+
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("concurrent probe calls = %d, want 1", got)
+ }
+
+ close(releaseProbe)
+ wg.Wait()
+ close(results)
+
+ for result := range results {
+ if !result {
+ t.Fatal("deduplicated probe result = false, want true")
+ }
+ }
+
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("final probe calls = %d, want 1", got)
+ }
+}
+
+func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) {
+ if ollamaModelMatches("llama3:8b", "llama3:7b") {
+ t.Fatal("ollamaModelMatches() = true, want false for mismatched tags")
+ }
+ if !ollamaModelMatches("llama3:7b", "llama3:7b") {
+ t.Fatal("ollamaModelMatches() = false, want true for exact tagged match")
+ }
+ if ollamaModelMatches("llama3:8b", "llama3") {
+ t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)")
+ }
+ if !ollamaModelMatches("llama3:latest", "llama3") {
+ t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest")
+ }
+ if !ollamaModelMatches("llama3", "llama3") {
+ t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)")
+ }
+}
diff --git a/web/backend/api/models.go b/web/backend/api/models.go
index e6749b56e..aa4a775eb 100644
--- a/web/backend/api/models.go
+++ b/web/backend/api/models.go
@@ -32,13 +32,14 @@ type modelResponse struct {
Proxy string `json:"proxy,omitempty"`
AuthMethod string `json:"auth_method,omitempty"`
// Advanced fields
- ConnectMode string `json:"connect_mode,omitempty"`
- Workspace string `json:"workspace,omitempty"`
- RPM int `json:"rpm,omitempty"`
- MaxTokensField string `json:"max_tokens_field,omitempty"`
- RequestTimeout int `json:"request_timeout,omitempty"`
- ThinkingLevel string `json:"thinking_level,omitempty"`
- ExtraBody map[string]any `json:"extra_body,omitempty"`
+ ConnectMode string `json:"connect_mode,omitempty"`
+ Workspace string `json:"workspace,omitempty"`
+ RPM int `json:"rpm,omitempty"`
+ MaxTokensField string `json:"max_tokens_field,omitempty"`
+ RequestTimeout int `json:"request_timeout,omitempty"`
+ ThinkingLevel string `json:"thinking_level,omitempty"`
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
+ CustomHeaders map[string]string `json:"custom_headers,omitempty"`
// Meta
Enabled bool `json:"enabled"`
Available bool `json:"available"`
@@ -87,6 +88,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
+ CustomHeaders: m.CustomHeaders,
Enabled: m.Enabled,
Available: modelStatuses[i].Available,
Status: modelStatuses[i].Status,
@@ -216,6 +218,14 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
} else if len(mc.ExtraBody) == 0 {
mc.ExtraBody = nil
}
+ // Preserve existing CustomHeaders when omitted (nil), but clear it when
+ // the frontend sends an empty object {} to indicate the field should
+ // be removed.
+ if mc.CustomHeaders == nil {
+ mc.CustomHeaders = cfg.ModelList[idx].CustomHeaders
+ } else if len(mc.CustomHeaders) == 0 {
+ mc.CustomHeaders = nil
+ }
cfg.ModelList[idx] = &mc.ModelConfig
diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go
index e78de1606..e4297f679 100644
--- a/web/backend/api/models_test.go
+++ b/web/backend/api/models_test.go
@@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) {
origTCPProbe := probeTCPServiceFunc
origOllamaProbe := probeOllamaModelFunc
origOpenAIProbe := probeOpenAICompatibleModelFunc
+ origNow := modelProbeNowFunc
+ resetModelProbeCache()
t.Cleanup(func() {
probeTCPServiceFunc = origTCPProbe
probeOllamaModelFunc = origOllamaProbe
probeOpenAICompatibleModelFunc = origOpenAIProbe
+ modelProbeNowFunc = origNow
+ resetModelProbeCache()
})
}
@@ -426,6 +430,112 @@ func TestHandleAddModel_PersistsAPIKey(t *testing.T) {
}
}
+func TestHandleAddModel_PersistsCustomHeaders(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{
+ "model_name":"new-model-headers",
+ "model":"openai/gpt-4o-mini",
+ "custom_headers":{"X-Source":"coding-plan","X-Agent":"openclaw"}
+ }`))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ if len(cfg.ModelList) != 2 {
+ t.Fatalf("len(model_list) = %d, want 2", len(cfg.ModelList))
+ }
+
+ added := cfg.ModelList[1]
+ if added.CustomHeaders == nil {
+ t.Fatal("custom_headers should not be nil")
+ }
+ if got := added.CustomHeaders["X-Source"]; got != "coding-plan" {
+ t.Fatalf("custom_headers[X-Source] = %q, want %q", got, "coding-plan")
+ }
+ if got := added.CustomHeaders["X-Agent"]; got != "openclaw" {
+ t.Fatalf("custom_headers[X-Agent] = %q, want %q", got, "openclaw")
+ }
+}
+
+func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ cfg.ModelList = []*config.ModelConfig{{
+ ModelName: "editable",
+ Model: "openai/gpt-4o-mini",
+ APIKeys: config.SimpleSecureStrings("sk-existing"),
+ CustomHeaders: map[string]string{"X-Source": "coding-plan"},
+ }}
+ err = config.SaveConfig(configPath, cfg)
+ if err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ // Omitted custom_headers should preserve existing value.
+ recPreserve := httptest.NewRecorder()
+ reqPreserve := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
+ "model_name":"editable",
+ "model":"openai/gpt-4o-mini"
+ }`))
+ reqPreserve.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(recPreserve, reqPreserve)
+ if recPreserve.Code != http.StatusOK {
+ t.Fatalf("preserve status = %d, want %d, body=%s", recPreserve.Code, http.StatusOK, recPreserve.Body.String())
+ }
+
+ afterPreserve, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() after preserve error = %v", err)
+ }
+ if got := afterPreserve.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
+ t.Fatalf("preserved custom_headers[X-Source] = %q, want %q", got, "coding-plan")
+ }
+
+ // Empty object should clear custom_headers.
+ recClear := httptest.NewRecorder()
+ reqClear := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
+ "model_name":"editable",
+ "model":"openai/gpt-4o-mini",
+ "custom_headers":{}
+ }`))
+ reqClear.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(recClear, reqClear)
+ if recClear.Code != http.StatusOK {
+ t.Fatalf("clear status = %d, want %d, body=%s", recClear.Code, http.StatusOK, recClear.Body.String())
+ }
+
+ afterClear, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() after clear error = %v", err)
+ }
+ if afterClear.ModelList[0].CustomHeaders != nil {
+ t.Fatalf("custom_headers = %#v, want nil", afterClear.ModelList[0].CustomHeaders)
+ }
+}
+
// TestHandleSetDefaultModel_RejectsNonexistentModel tests that setting a non-existent
// model as default returns 404. This covers the case where virtual models (which are
// filtered by SaveConfig) cannot be set as default.
diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go
index c8ef47308..95bbfd2c1 100644
--- a/web/backend/api/pico.go
+++ b/web/backend/api/pico.go
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
+ ppid "github.com/sipeed/picoclaw/pkg/pid"
)
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
@@ -57,9 +58,34 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
gateway.mu.Lock()
ensurePicoTokenCachedLocked(h.configPath)
- gatewayAvailable := gateway.pidData != nil
+ cachedPID := gateway.pidData
+ trackedCmd := gateway.cmd
gateway.mu.Unlock()
+ gatewayAvailable := false
+ // Prefer fresh PID file data when available.
+ if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil {
+ gateway.mu.Lock()
+ gateway.pidData = pidData
+ setGatewayRuntimeStatusLocked("running")
+ gatewayAvailable = true
+ gateway.mu.Unlock()
+ } else if cachedPID != nil {
+ // No PID file now: keep availability only while tracked process is
+ // still alive (covers short PID-file races at startup/restart).
+ if isCmdProcessAliveLocked(trackedCmd) {
+ gatewayAvailable = true
+ } else {
+ gateway.mu.Lock()
+ if gateway.cmd == trackedCmd {
+ gateway.pidData = nil
+ setGatewayRuntimeStatusLocked("stopped")
+ }
+ gatewayAvailable = gateway.pidData != nil
+ gateway.mu.Unlock()
+ }
+ }
+
if !gatewayAvailable {
logger.Warnf("Gateway not available for WebSocket proxy")
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go
index ee5586746..04888fde7 100644
--- a/web/backend/api/pico_test.go
+++ b/web/backend/api/pico_test.go
@@ -11,6 +11,7 @@ import (
"strconv"
"testing"
+ "github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
@@ -307,6 +308,9 @@ func TestHandlePicoSetup_Response(t *testing.T) {
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
+ home := t.TempDir()
+ t.Setenv("PICOCLAW_HOME", home)
+
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
@@ -335,6 +339,16 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
+ if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
+ t.Fatalf("WritePidFile() error = %v", err)
+ }
+ origPidData := gateway.pidData
+ origPicoToken := gateway.picoToken
+ t.Cleanup(func() {
+ ppid.RemovePidFile(globalConfigDir())
+ gateway.pidData = origPidData
+ gateway.picoToken = origPicoToken
+ })
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
@@ -378,6 +392,9 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
}
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
+ home := t.TempDir()
+ t.Setenv("PICOCLAW_HOME", home)
+
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
@@ -399,6 +416,12 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
+ if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
+ t.Fatalf("WritePidFile() error = %v", err)
+ }
+ t.Cleanup(func() {
+ ppid.RemovePidFile(globalConfigDir())
+ })
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
@@ -426,6 +449,134 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
}
}
+func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
+ home := t.TempDir()
+ t.Setenv("PICOCLAW_HOME", home)
+
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ handler := h.handleWebSocketProxy()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/pico/ws" {
+ t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
+ }
+ w.WriteHeader(http.StatusOK)
+ _, _ = io.WriteString(w, r.Header.Get(protocolKey))
+ }))
+ defer server.Close()
+
+ cfg := config.DefaultConfig()
+ cfg.Gateway.Host = "127.0.0.1"
+ cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
+ cfg.Channels.Pico.Enabled = true
+ cfg.Channels.Pico.SetToken("ui-token")
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port)
+ if err != nil {
+ t.Fatalf("WritePidFile() error = %v", err)
+ }
+ t.Cleanup(func() {
+ ppid.RemovePidFile(globalConfigDir())
+ })
+
+ origPidData := gateway.pidData
+ origPicoToken := gateway.picoToken
+ origStatus := gateway.runtimeStatus
+ t.Cleanup(func() {
+ gateway.mu.Lock()
+ gateway.pidData = origPidData
+ gateway.picoToken = origPicoToken
+ gateway.runtimeStatus = origStatus
+ gateway.mu.Unlock()
+ })
+
+ gateway.mu.Lock()
+ gateway.pidData = nil
+ gateway.picoToken = ""
+ setGatewayRuntimeStatusLocked("stopped")
+ gateway.mu.Unlock()
+
+ req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
+ req.Header.Set(protocolKey, tokenPrefix+"ui-token")
+ rec := httptest.NewRecorder()
+ handler(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
+ if got := rec.Body.String(); got != expected {
+ t.Fatalf("forwarded protocol = %q, want %q", got, expected)
+ }
+
+ gateway.mu.Lock()
+ defer gateway.mu.Unlock()
+ if gateway.pidData == nil {
+ t.Fatal("gateway.pidData should be loaded from pid file")
+ }
+ if gateway.runtimeStatus != "running" {
+ t.Fatalf("runtimeStatus = %q, want %q", gateway.runtimeStatus, "running")
+ }
+}
+
+func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
+ configPath := filepath.Join(t.TempDir(), "config.json")
+ h := NewHandler(configPath)
+ handler := h.handleWebSocketProxy()
+
+ cfg := config.DefaultConfig()
+ cfg.Channels.Pico.Enabled = true
+ cfg.Channels.Pico.SetToken("ui-token")
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ cmd := startLongRunningProcess(t)
+ if cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ }
+ _ = cmd.Wait()
+
+ origPidData := gateway.pidData
+ origPicoToken := gateway.picoToken
+ origCmd := gateway.cmd
+ origStatus := gateway.runtimeStatus
+ t.Cleanup(func() {
+ gateway.mu.Lock()
+ gateway.pidData = origPidData
+ gateway.picoToken = origPicoToken
+ gateway.cmd = origCmd
+ gateway.runtimeStatus = origStatus
+ gateway.mu.Unlock()
+ })
+
+ gateway.mu.Lock()
+ gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid, Token: "stale-token"}
+ gateway.picoToken = "ui-token"
+ gateway.cmd = cmd
+ setGatewayRuntimeStatusLocked("running")
+ gateway.mu.Unlock()
+
+ req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
+ req.Header.Set(protocolKey, tokenPrefix+"ui-token")
+ rec := httptest.NewRecorder()
+ handler(rec, req)
+
+ if rec.Code != http.StatusServiceUnavailable {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
+ }
+ gateway.mu.Lock()
+ defer gateway.mu.Unlock()
+ if gateway.pidData != nil {
+ t.Fatal("gateway.pidData should be cleared after stale process exit is detected")
+ }
+}
+
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
diff --git a/web/backend/api/router.go b/web/backend/api/router.go
index 3823fe08c..c6781baf1 100644
--- a/web/backend/api/router.go
+++ b/web/backend/api/router.go
@@ -81,6 +81,9 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
// Launcher service parameters (port/public)
h.registerLauncherConfigRoutes(mux)
+ // Self-update endpoint (requires dashboard auth)
+ h.registerUpdateRoutes(mux)
+
// Runtime build/version metadata
h.registerVersionRoutes(mux)
diff --git a/web/backend/api/session.go b/web/backend/api/session.go
index 052f085d6..914e075f9 100644
--- a/web/backend/api/session.go
+++ b/web/backend/api/session.go
@@ -44,12 +44,24 @@ type sessionListItem struct {
Updated string `json:"updated"`
}
+type sessionChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Media []string `json:"media,omitempty"`
+}
+
// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL
// sessions before structured scope metadata existed.
const (
legacyPicoSessionPrefix = "agent:main:pico:direct:pico:"
- maxSessionJSONLLineSize = 10 * 1024 * 1024 // 10 MB
+ picoSessionPrefix = legacyPicoSessionPrefix
+
+ // Keep the session API aligned with the shared JSONL store reader limit in
+ // pkg/memory/jsonl.go so oversized lines fail consistently everywhere.
+ maxSessionJSONLLineSize = 10 * 1024 * 1024
maxSessionTitleRunes = 60
+
+ handledToolResponseSummaryText = "Requested output delivered via tool attachment."
)
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
@@ -327,32 +339,21 @@ func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessio
func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
preview := ""
for _, msg := range sess.Messages {
- if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
- preview = msg.Content
+ if msg.Role == "user" {
+ preview = sessionMessagePreview(msg)
+ }
+ if preview != "" {
break
}
}
- title := strings.TrimSpace(sess.Summary)
- if title == "" {
- title = preview
- }
-
- title = truncateRunes(title, maxSessionTitleRunes)
preview = truncateRunes(preview, maxSessionTitleRunes)
if preview == "" {
preview = "(empty)"
}
- if title == "" {
- title = preview
- }
+ title := preview
- validMessageCount := 0
- for _, msg := range sess.Messages {
- if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
- validMessageCount++
- }
- }
+ validMessageCount := len(visibleSessionMessages(sess.Messages))
return sessionListItem{
ID: sessionID,
@@ -379,6 +380,99 @@ func truncateRunes(s string, maxLen int) string {
return string(runes[:maxLen]) + "..."
}
+func sessionMessageVisible(msg providers.Message) bool {
+ return strings.TrimSpace(msg.Content) != "" || len(msg.Media) > 0
+}
+
+func sessionMessagePreview(msg providers.Message) string {
+ if content := strings.TrimSpace(msg.Content); content != "" {
+ return content
+ }
+ if len(msg.Media) > 0 {
+ return "[image]"
+ }
+ return ""
+}
+
+func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
+ transcript := make([]sessionChatMessage, 0, len(messages))
+
+ for _, msg := range messages {
+ switch msg.Role {
+ case "user":
+ if sessionMessageVisible(msg) {
+ transcript = append(transcript, sessionChatMessage{
+ Role: "user",
+ Content: msg.Content,
+ Media: append([]string(nil), msg.Media...),
+ })
+ }
+
+ case "assistant":
+ visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls)
+ if len(visibleToolMessages) > 0 {
+ transcript = append(transcript, visibleToolMessages...)
+ }
+
+ // Pico web chat can persist both visible `message` tool output and a
+ // later plain assistant reply in the same turn. Hide only the fixed
+ // internal summary that marks handled tool delivery.
+ if len(visibleToolMessages) > 0 || !sessionMessageVisible(msg) || assistantMessageInternalOnly(msg) {
+ continue
+ }
+
+ transcript = append(transcript, sessionChatMessage{
+ Role: "assistant",
+ Content: msg.Content,
+ Media: append([]string(nil), msg.Media...),
+ })
+ }
+ }
+
+ return transcript
+}
+
+func assistantMessageInternalOnly(msg providers.Message) bool {
+ return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
+}
+
+func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
+ if len(toolCalls) == 0 {
+ return nil
+ }
+
+ messages := make([]sessionChatMessage, 0, len(toolCalls))
+ for _, tc := range toolCalls {
+ name := tc.Name
+ argsJSON := ""
+ if tc.Function != nil {
+ if name == "" {
+ name = tc.Function.Name
+ }
+ argsJSON = tc.Function.Arguments
+ }
+
+ switch name {
+ case "message":
+ var args struct {
+ Content string `json:"content"`
+ }
+ if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
+ continue
+ }
+ if strings.TrimSpace(args.Content) == "" {
+ continue
+ }
+ messages = append(messages, sessionChatMessage{
+ Role: "assistant",
+ Content: args.Content,
+ })
+ }
+ }
+
+ return messages
+}
+
// sessionsDir resolves the path to the gateway's session storage directory.
// It reads the workspace from config, falling back to ~/.picoclaw/workspace.
func (h *Handler) sessionsDir() (string, error) {
@@ -530,22 +624,7 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
}
}
- // Convert to a simpler format for the frontend
- type chatMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
- }
-
- messages := make([]chatMessage, 0, len(sess.Messages))
- for _, msg := range sess.Messages {
- // Only include user and assistant messages that have actual content
- if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
- messages = append(messages, chatMessage{
- Role: msg.Role,
- Content: msg.Content,
- })
- }
- }
+ messages := visibleSessionMessages(sess.Messages)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go
index 40e53b0b0..4c871ee30 100644
--- a/web/backend/api/session_test.go
+++ b/web/backend/api/session_test.go
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
+ "strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
@@ -34,9 +35,9 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
- store, err := memory.NewJSONLStore(dir)
- if err != nil {
- t.Fatalf("NewJSONLStore() error = %v", err)
+ store, storeErr := memory.NewJSONLStore(dir)
+ if storeErr != nil {
+ t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := legacyPicoSessionPrefix + "history-jsonl"
@@ -87,22 +88,26 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
if items[0].MessageCount != 2 {
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
}
- if items[0].Title != "JSONL-backed session" {
- t.Fatalf("items[0].Title = %q, want %q", items[0].Title, "JSONL-backed session")
+ if items[0].Title != "Explain why the history API is empty after migration." {
+ t.Fatalf(
+ "items[0].Title = %q, want %q",
+ items[0].Title,
+ "Explain why the history API is empty after migration.",
+ )
}
if items[0].Preview != "Explain why the history API is empty after migration." {
t.Fatalf("items[0].Preview = %q", items[0].Preview)
}
}
-func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
+func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
- store, err := memory.NewJSONLStore(dir)
- if err != nil {
- t.Fatalf("NewJSONLStore() error = %v", err)
+ store, storeErr := memory.NewJSONLStore(dir)
+ if storeErr != nil {
+ t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := legacyPicoSessionPrefix + "summary-title"
@@ -139,10 +144,7 @@ func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
- expectedTitle := truncateRunes(
- "This summary is intentionally longer than sixty characters so it must be truncated in the history menu.",
- maxSessionTitleRunes,
- )
+ expectedTitle := truncateRunes("fallback preview", maxSessionTitleRunes)
if items[0].Title != expectedTitle {
t.Fatalf("items[0].Title = %q", items[0].Title)
}
@@ -220,22 +222,20 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
- store, err := memory.NewJSONLStore(dir)
- if err != nil {
- t.Fatalf("NewJSONLStore() error = %v", err)
+ store, storeErr := memory.NewJSONLStore(dir)
+ if storeErr != nil {
+ t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := "sk_v1_scope_discovery"
- addErr := store.AddFullMessage(nil, sessionKey, providers.Message{
+ if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "scope discovered session",
- })
- if addErr != nil {
- t.Fatalf("AddFullMessage() error = %v", addErr)
+ }); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
}
- summaryErr := store.SetSummary(nil, sessionKey, "scope summary")
- if summaryErr != nil {
- t.Fatalf("SetSummary() error = %v", summaryErr)
+ if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil {
+ t.Fatalf("SetSummary() error = %v", err)
}
scopeData, err := json.Marshal(session.SessionScope{
@@ -292,6 +292,359 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
}
}
+func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "detail-message-tool"
+ for _, msg := range []providers.Message{
+ {Role: "user", Content: "test"},
+ {
+ Role: "assistant",
+ Content: "",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: "message",
+ Arguments: `{"content":"visible tool output"}`,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "Message sent to pico:pico:detail-message-tool", ToolCallID: "call_1"},
+ {Role: "assistant", Content: handledToolResponseSummaryText},
+ } {
+ if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
+ }
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp struct {
+ Messages []struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ } `json:"messages"`
+ }
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(resp.Messages) != 2 {
+ t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
+ }
+ if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
+ t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[1])
+ }
+}
+
+func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "detail-message-tool-final-reply"
+ for _, msg := range []providers.Message{
+ {Role: "user", Content: "test"},
+ {
+ Role: "assistant",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: "message",
+ Arguments: `{"content":"visible tool output"}`,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "Message sent to pico:pico:detail-message-tool-final-reply", ToolCallID: "call_1"},
+ {Role: "assistant", Content: "final assistant reply"},
+ } {
+ if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
+ }
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool-final-reply", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp struct {
+ Messages []struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ } `json:"messages"`
+ }
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(resp.Messages) != 3 {
+ t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
+ }
+ if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
+ t.Fatalf("interim assistant message = %#v, want visible tool output", resp.Messages[1])
+ }
+ if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final assistant reply" {
+ t.Fatalf("final assistant message = %#v, want final assistant reply", resp.Messages[2])
+ }
+}
+
+func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "list-visible-count"
+ for _, msg := range []providers.Message{
+ {Role: "user", Content: "test"},
+ {
+ Role: "assistant",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Function: &providers.FunctionCall{
+ Name: "message",
+ Arguments: `{"content":"visible tool output"}`,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "Message sent to pico:pico:list-visible-count", ToolCallID: "call_1"},
+ {Role: "assistant", Content: handledToolResponseSummaryText},
+ } {
+ if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
+ }
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var items []sessionListItem
+ if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(items) != 1 {
+ t.Fatalf("len(items) = %d, want 1", len(items))
+ }
+ if items[0].MessageCount != 2 {
+ t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
+ }
+}
+
+func TestHandleGetSession_IncludesMediaOnlyMessages(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "detail-media-only"
+ if err := store.AddFullMessage(nil, sessionKey, providers.Message{
+ Role: "user",
+ Media: []string{"data:image/png;base64,abc123"},
+ }); err != nil {
+ t.Fatalf("AddFullMessage(user) error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-media-only", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp struct {
+ Messages []struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Media []string `json:"media"`
+ } `json:"messages"`
+ }
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(resp.Messages) != 1 {
+ t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
+ }
+ if resp.Messages[0].Role != "user" || len(resp.Messages[0].Media) != 1 {
+ t.Fatalf("message = %#v, want user message with media", resp.Messages[0])
+ }
+}
+
+func TestHandleSessions_SupportsJSONLMessagesUpToStoreCap(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "detail-large-jsonl"
+ largeContent := strings.Repeat("x", 9*1024*1024)
+ if err := store.AddFullMessage(nil, sessionKey, providers.Message{
+ Role: "user",
+ Content: largeContent,
+ }); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ listRec := httptest.NewRecorder()
+ listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
+ mux.ServeHTTP(listRec, listReq)
+
+ if listRec.Code != http.StatusOK {
+ t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
+ }
+
+ var items []sessionListItem
+ if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
+ t.Fatalf("list Unmarshal() error = %v", err)
+ }
+ if len(items) != 1 {
+ t.Fatalf("len(items) = %d, want 1", len(items))
+ }
+
+ detailRec := httptest.NewRecorder()
+ detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-large-jsonl", nil)
+ mux.ServeHTTP(detailRec, detailReq)
+
+ if detailRec.Code != http.StatusOK {
+ t.Fatalf(
+ "detail status = %d, want %d, body=%s",
+ detailRec.Code,
+ http.StatusOK,
+ detailRec.Body.String(),
+ )
+ }
+
+ var resp struct {
+ Messages []struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ } `json:"messages"`
+ }
+ if err := json.Unmarshal(detailRec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("detail Unmarshal() error = %v", err)
+ }
+ if len(resp.Messages) != 1 {
+ t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
+ }
+ if resp.Messages[0].Role != "user" {
+ t.Fatalf("resp.Messages[0].Role = %q, want %q", resp.Messages[0].Role, "user")
+ }
+ if got := len(resp.Messages[0].Content); got != len(largeContent) {
+ t.Fatalf("len(resp.Messages[0].Content) = %d, want %d", got, len(largeContent))
+ }
+}
+
+func TestHandleListSessions_UsesImagePreviewForMediaOnlyMessage(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ dir := sessionsTestDir(t, configPath)
+ store, err := memory.NewJSONLStore(dir)
+ if err != nil {
+ t.Fatalf("NewJSONLStore() error = %v", err)
+ }
+
+ sessionKey := picoSessionPrefix + "preview-media-only"
+ if err := store.AddFullMessage(nil, sessionKey, providers.Message{
+ Role: "user",
+ Media: []string{"data:image/png;base64,abc123"},
+ }); err != nil {
+ t.Fatalf("AddFullMessage() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var items []sessionListItem
+ if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(items) != 1 {
+ t.Fatalf("len(items) = %d, want 1", len(items))
+ }
+ if items[0].Preview != "[image]" {
+ t.Fatalf("items[0].Preview = %q, want %q", items[0].Preview, "[image]")
+ }
+ if items[0].MessageCount != 1 {
+ t.Fatalf("items[0].MessageCount = %d, want 1", items[0].MessageCount)
+ }
+}
+
func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
diff --git a/web/backend/api/skills.go b/web/backend/api/skills.go
index b2036f66c..2c054c41b 100644
--- a/web/backend/api/skills.go
+++ b/web/backend/api/skills.go
@@ -1,40 +1,115 @@
package api
import (
+ "bytes"
"encoding/json"
+ "errors"
"fmt"
"io"
+ "io/fs"
"net/http"
+ "net/url"
"os"
"path/filepath"
"regexp"
+ "strconv"
"strings"
+ "sync"
+ "time"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/skills"
+ "github.com/sipeed/picoclaw/pkg/utils"
)
type skillSupportResponse struct {
- Skills []skills.SkillInfo `json:"skills"`
+ Skills []skillSupportItem `json:"skills"`
+}
+
+type skillSupportItem struct {
+ Name string `json:"name"`
+ Path string `json:"path"`
+ Source string `json:"source"`
+ Description string `json:"description"`
+ OriginKind string `json:"origin_kind"`
+ RegistryName string `json:"registry_name,omitempty"`
+ RegistryURL string `json:"registry_url,omitempty"`
+ InstalledVersion string `json:"installed_version,omitempty"`
+ InstalledAt int64 `json:"installed_at,omitempty"`
}
type skillDetailResponse struct {
- Name string `json:"name"`
- Path string `json:"path"`
- Source string `json:"source"`
- Description string `json:"description"`
- Content string `json:"content"`
+ skillSupportItem
+ Content string `json:"content"`
+}
+
+type skillSearchResultItem struct {
+ Score float64 `json:"score"`
+ Slug string `json:"slug"`
+ DisplayName string `json:"display_name"`
+ Summary string `json:"summary"`
+ Version string `json:"version"`
+ RegistryName string `json:"registry_name"`
+ URL string `json:"url,omitempty"`
+ Installed bool `json:"installed"`
+ InstalledName string `json:"installed_name,omitempty"`
+}
+
+type skillSearchResponse struct {
+ Results []skillSearchResultItem `json:"results"`
+ Limit int `json:"limit"`
+ Offset int `json:"offset"`
+ NextOffset int `json:"next_offset,omitempty"`
+ HasMore bool `json:"has_more"`
+}
+
+type installSkillRequest struct {
+ Slug string `json:"slug"`
+ Registry string `json:"registry"`
+ Version string `json:"version,omitempty"`
+ Force bool `json:"force,omitempty"`
+}
+
+type installSkillResponse struct {
+ Status string `json:"status"`
+ Slug string `json:"slug"`
+ Registry string `json:"registry"`
+ Version string `json:"version"`
+ Summary string `json:"summary,omitempty"`
+ IsSuspicious bool `json:"is_suspicious,omitempty"`
+ InstalledSkill *skillSupportItem `json:"skill,omitempty"`
+}
+
+type installedSkillOriginMeta struct {
+ Version int `json:"version"`
+ OriginKind string `json:"origin_kind,omitempty"`
+ Registry string `json:"registry,omitempty"`
+ Slug string `json:"slug,omitempty"`
+ RegistryURL string `json:"registry_url,omitempty"`
+ InstalledVersion string `json:"installed_version,omitempty"`
+ InstalledAt int64 `json:"installed_at"`
}
var (
skillNameSanitizer = regexp.MustCompile(`[^a-z0-9-]+`)
importedSkillFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
skillFrontmatterStripper = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
+ persistSkillOriginMeta = writeSkillOriginMeta
+ workspaceSkillWriteMu sync.Mutex
+ errImportedSkillExists = errors.New("skill already exists")
+)
+
+const (
+ maxImportedSkillSize = 1 << 20
+ maxRegistrySearchFanout = 1000
)
func (h *Handler) registerSkillRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/skills", h.handleListSkills)
mux.HandleFunc("GET /api/skills/{name}", h.handleGetSkill)
+ mux.HandleFunc("GET /api/skills/search", h.handleSearchSkills)
+ mux.HandleFunc("POST /api/skills/install", h.handleInstallSkill)
mux.HandleFunc("POST /api/skills/import", h.handleImportSkill)
mux.HandleFunc("DELETE /api/skills/{name}", h.handleDeleteSkill)
}
@@ -46,11 +121,15 @@ func (h *Handler) handleListSkills(w http.ResponseWriter, r *http.Request) {
return
}
- loader := newSkillsLoader(cfg.WorkspacePath())
+ items, err := buildSkillSupportItems(cfg)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
+ return
+ }
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skillSupportResponse{
- Skills: loader.ListSkills(),
+ Skills: items,
})
}
@@ -61,16 +140,18 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
return
}
- loader := newSkillsLoader(cfg.WorkspacePath())
+ skillItems, err := buildSkillSupportItems(cfg)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
+ return
+ }
name := r.PathValue("name")
- allSkills := loader.ListSkills()
-
- for _, skill := range allSkills {
- if skill.Name != name {
+ for _, skillItem := range skillItems {
+ if skillItem.Name != name {
continue
}
- content, err := loadSkillContent(skill.Path)
+ content, err := loadSkillContent(skillItem.Path)
if err != nil {
http.Error(w, "Skill content not found", http.StatusNotFound)
return
@@ -78,11 +159,8 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skillDetailResponse{
- Name: skill.Name,
- Path: skill.Path,
- Source: skill.Source,
- Description: skill.Description,
- Content: content,
+ skillSupportItem: skillItem,
+ Content: content,
})
return
}
@@ -90,6 +168,266 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Skill not found", http.StatusNotFound)
}
+func (h *Handler) handleSearchSkills(w http.ResponseWriter, r *http.Request) {
+ cfg, loadErr := config.LoadConfig(h.configPath)
+ if loadErr != nil {
+ http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
+ return
+ }
+ if registryErr := ensureSkillRegistryToolEnabled(cfg, "find_skills"); registryErr != nil {
+ http.Error(w, registryErr.Error(), http.StatusBadRequest)
+ return
+ }
+
+ query := strings.TrimSpace(r.URL.Query().Get("q"))
+
+ limit := 20
+ if rawLimit := strings.TrimSpace(r.URL.Query().Get("limit")); rawLimit != "" {
+ parsedLimit, parseErr := strconv.Atoi(rawLimit)
+ if parseErr != nil || parsedLimit < 1 || parsedLimit > 50 {
+ http.Error(w, "limit must be between 1 and 50", http.StatusBadRequest)
+ return
+ }
+ limit = parsedLimit
+ }
+ offset := 0
+ if rawOffset := strings.TrimSpace(r.URL.Query().Get("offset")); rawOffset != "" {
+ parsedOffset, parseErr := strconv.Atoi(rawOffset)
+ if parseErr != nil || parsedOffset < 0 {
+ http.Error(w, "offset must be 0 or greater", http.StatusBadRequest)
+ return
+ }
+ offset = parsedOffset
+ }
+
+ installedSkills, err := buildOccupiedWorkspaceSkillsByDirectory(cfg)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to inspect installed skills: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ if query == "" {
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(skillSearchResponse{
+ Results: []skillSearchResultItem{},
+ Limit: limit,
+ Offset: offset,
+ HasMore: false,
+ })
+ return
+ }
+
+ registryMgr := newSkillsRegistryManager(cfg)
+ searchLimit := offset + limit + 1
+ if searchLimit > maxRegistrySearchFanout {
+ searchLimit = maxRegistrySearchFanout
+ }
+ results, err := registryMgr.SearchAll(r.Context(), query, searchLimit)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to search skills: %v", err), http.StatusBadGateway)
+ return
+ }
+
+ if offset > len(results) {
+ offset = len(results)
+ }
+
+ end := offset + limit
+ if end > len(results) {
+ end = len(results)
+ }
+
+ pageResults := results[offset:end]
+ response := make([]skillSearchResultItem, 0, len(pageResults))
+ for _, result := range pageResults {
+ installedSkill, installed := installedSkills[result.Slug]
+ item := skillSearchResultItem{
+ Score: result.Score,
+ Slug: result.Slug,
+ DisplayName: result.DisplayName,
+ Summary: result.Summary,
+ Version: result.Version,
+ RegistryName: result.RegistryName,
+ URL: registrySkillURL(cfg, result.RegistryName, result.Slug),
+ Installed: installed,
+ }
+ if installed {
+ item.InstalledName = installedSkill.Name
+ }
+ response = append(response, item)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ nextOffset := 0
+ hasMore := len(results) > end
+ if hasMore {
+ nextOffset = end
+ }
+ json.NewEncoder(w).Encode(skillSearchResponse{
+ Results: response,
+ Limit: limit,
+ Offset: offset,
+ NextOffset: nextOffset,
+ HasMore: hasMore,
+ })
+}
+
+func (h *Handler) handleInstallSkill(w http.ResponseWriter, r *http.Request) {
+ cfg, loadErr := config.LoadConfig(h.configPath)
+ if loadErr != nil {
+ http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
+ return
+ }
+ if registryErr := ensureSkillRegistryToolEnabled(cfg, "install_skill"); registryErr != nil {
+ http.Error(w, registryErr.Error(), http.StatusBadRequest)
+ return
+ }
+
+ var req installSkillRequest
+ if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
+ http.Error(w, fmt.Sprintf("Invalid JSON: %v", decodeErr), http.StatusBadRequest)
+ return
+ }
+
+ req.Slug = strings.TrimSpace(req.Slug)
+ req.Registry = strings.TrimSpace(req.Registry)
+ req.Version = strings.TrimSpace(req.Version)
+
+ if validateErr := utils.ValidateSkillIdentifier(req.Slug); validateErr != nil {
+ http.Error(
+ w,
+ fmt.Sprintf("invalid slug %q: error: %s", req.Slug, validateErr.Error()),
+ http.StatusBadRequest,
+ )
+ return
+ }
+ if validateErr := utils.ValidateSkillIdentifier(req.Registry); validateErr != nil {
+ http.Error(
+ w,
+ fmt.Sprintf("invalid registry %q: error: %s", req.Registry, validateErr.Error()),
+ http.StatusBadRequest,
+ )
+ return
+ }
+
+ registryMgr := newSkillsRegistryManager(cfg)
+ registry := registryMgr.GetRegistry(req.Registry)
+ if registry == nil {
+ http.Error(w, fmt.Sprintf("registry %q not found", req.Registry), http.StatusBadRequest)
+ return
+ }
+
+ workspace := cfg.WorkspacePath()
+ skillsRoot := filepath.Join(workspace, "skills")
+ targetDir := filepath.Join(workspace, "skills", req.Slug)
+ workspaceSkillWriteMu.Lock()
+ defer workspaceSkillWriteMu.Unlock()
+
+ targetExists := false
+ if _, statErr := os.Stat(targetDir); statErr == nil {
+ targetExists = true
+ } else if !os.IsNotExist(statErr) {
+ http.Error(w, fmt.Sprintf("Failed to inspect install target: %v", statErr), http.StatusInternalServerError)
+ return
+ }
+
+ if !req.Force && targetExists {
+ http.Error(w, fmt.Sprintf("skill %q already installed at %s", req.Slug, targetDir), http.StatusConflict)
+ return
+ }
+ if err := os.MkdirAll(skillsRoot, 0o755); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to create skills directory: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ stagedWorkspaceRoot, stagedTargetDir, err := createStagedSkillInstall(skillsRoot, req.Slug)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to prepare staged install: %v", err), http.StatusInternalServerError)
+ return
+ }
+ defer os.RemoveAll(stagedWorkspaceRoot)
+
+ result, err := registry.DownloadAndInstall(r.Context(), req.Slug, req.Version, stagedTargetDir)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Failed to install skill: %v", err), http.StatusBadGateway)
+ return
+ }
+ if result.IsMalwareBlocked {
+ http.Error(
+ w,
+ fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", req.Slug),
+ http.StatusForbidden,
+ )
+ return
+ }
+
+ if findWorkspaceSkillInfoByDirectory(stagedWorkspaceRoot, req.Slug) == nil {
+ http.Error(
+ w,
+ fmt.Sprintf("Failed to install skill: registry archive for %q is not a valid skill", req.Slug),
+ http.StatusBadGateway,
+ )
+ return
+ }
+
+ installedAt := time.Now().UnixMilli()
+ if err := persistSkillOriginMeta(stagedTargetDir, installedSkillOriginMeta{
+ Version: 1,
+ OriginKind: "third_party",
+ Registry: registry.Name(),
+ Slug: req.Slug,
+ RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
+ InstalledVersion: result.Version,
+ InstalledAt: installedAt,
+ }); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to persist skill metadata: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ if err := commitStagedSkillInstall(
+ stagedWorkspaceRoot,
+ stagedTargetDir,
+ targetDir,
+ req.Force && targetExists,
+ ); err != nil {
+ http.Error(w, fmt.Sprintf("Failed to activate installed skill: %v", err), http.StatusInternalServerError)
+ return
+ }
+
+ validatedSkill := findWorkspaceSkillByDirectory(cfg, req.Slug)
+ if validatedSkill == nil {
+ http.Error(
+ w,
+ fmt.Sprintf("Failed to install skill: activated archive for %q is not a valid skill", req.Slug),
+ http.StatusBadGateway,
+ )
+ return
+ }
+
+ installedSkill := &skillSupportItem{
+ Name: validatedSkill.Name,
+ Path: validatedSkill.Path,
+ Source: validatedSkill.Source,
+ Description: validatedSkill.Description,
+ OriginKind: "third_party",
+ RegistryName: registry.Name(),
+ RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
+ InstalledVersion: result.Version,
+ InstalledAt: installedAt,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(installSkillResponse{
+ Status: "ok",
+ Slug: req.Slug,
+ Registry: registry.Name(),
+ Version: result.Version,
+ Summary: result.Summary,
+ IsSuspicious: result.IsSuspicious,
+ InstalledSkill: installedSkill,
+ })
+}
+
func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
@@ -110,54 +448,26 @@ func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
}
defer uploadedFile.Close()
- content, err := io.ReadAll(io.LimitReader(uploadedFile, (1<<20)+1))
+ content, err := io.ReadAll(io.LimitReader(uploadedFile, maxImportedSkillSize+1))
if err != nil {
http.Error(w, fmt.Sprintf("Failed to read file: %v", err), http.StatusBadRequest)
return
}
- if len(content) > 1<<20 {
+ if len(content) > maxImportedSkillSize {
http.Error(w, "file exceeds 1MB limit", http.StatusBadRequest)
return
}
+ workspaceSkillWriteMu.Lock()
+ defer workspaceSkillWriteMu.Unlock()
- skillName, err := normalizeImportedSkillName(fileHeader.Filename, content)
+ importedSkill, statusCode, err := importUploadedSkill(cfg, fileHeader.Filename, content)
if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
+ http.Error(w, err.Error(), statusCode)
return
}
- content = normalizeImportedSkillContent(content, skillName)
-
- workspace := cfg.WorkspacePath()
- skillDir := filepath.Join(workspace, "skills", skillName)
- skillFile := filepath.Join(skillDir, "SKILL.md")
- if _, err := os.Stat(skillDir); err == nil {
- http.Error(w, "skill already exists", http.StatusConflict)
- return
- }
-
- if err := os.MkdirAll(skillDir, 0o755); err != nil {
- http.Error(w, fmt.Sprintf("Failed to create skill directory: %v", err), http.StatusInternalServerError)
- return
- }
- if err := os.WriteFile(skillFile, content, 0o644); err != nil {
- http.Error(w, fmt.Sprintf("Failed to save skill: %v", err), http.StatusInternalServerError)
- return
- }
-
- loader := newSkillsLoader(workspace)
- for _, skill := range loader.ListSkills() {
- if skill.Path == skillFile || (skill.Name == skillName && skill.Source == "workspace") {
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(skill)
- return
- }
- }
w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(map[string]string{
- "name": skillName,
- "path": skillFile,
- })
+ json.NewEncoder(w).Encode(importedSkill)
}
func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
@@ -169,6 +479,9 @@ func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
loader := newSkillsLoader(cfg.WorkspacePath())
name := r.PathValue("name")
+ workspaceSkillWriteMu.Lock()
+ defer workspaceSkillWriteMu.Unlock()
+
for _, skill := range loader.ListSkills() {
if skill.Name != name {
continue
@@ -197,12 +510,274 @@ func newSkillsLoader(workspace string) *skills.SkillsLoader {
)
}
+func newSkillsRegistryManager(cfg *config.Config) *skills.RegistryManager {
+ clawHubConfig := cfg.Tools.Skills.Registries.ClawHub
+ return skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
+ MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
+ ClawHub: skills.ClawHubConfig{
+ Enabled: clawHubConfig.Enabled,
+ BaseURL: clawHubConfig.BaseURL,
+ AuthToken: clawHubConfig.AuthToken.String(),
+ SearchPath: clawHubConfig.SearchPath,
+ SkillsPath: clawHubConfig.SkillsPath,
+ DownloadPath: clawHubConfig.DownloadPath,
+ Timeout: clawHubConfig.Timeout,
+ MaxZipSize: clawHubConfig.MaxZipSize,
+ MaxResponseSize: clawHubConfig.MaxResponseSize,
+ },
+ })
+}
+
+func ensureSkillRegistryToolEnabled(cfg *config.Config, toolName string) error {
+ if !cfg.Tools.IsToolEnabled("skills") {
+ return fmt.Errorf("tools.skills is disabled")
+ }
+ if !cfg.Tools.IsToolEnabled(toolName) {
+ return fmt.Errorf("%s is disabled", toolName)
+ }
+ return nil
+}
+
+func buildSkillSupportItems(cfg *config.Config) ([]skillSupportItem, error) {
+ rawSkills := newSkillsLoader(cfg.WorkspacePath()).ListSkills()
+ items := make([]skillSupportItem, 0, len(rawSkills))
+ for _, skill := range rawSkills {
+ item, err := enrichSkillInfo(cfg, skill)
+ if err != nil {
+ return nil, err
+ }
+ items = append(items, item)
+ }
+ return items, nil
+}
+
+func buildWorkspaceSkillItemsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
+ result := make(map[string]skillSupportItem)
+ items, err := buildSkillSupportItems(cfg)
+ if err != nil {
+ return nil, err
+ }
+ for _, skill := range items {
+ if skill.Source != "workspace" {
+ continue
+ }
+ dir := filepath.Base(filepath.Dir(skill.Path))
+ if dir == "" {
+ continue
+ }
+ result[dir] = skill
+ }
+ return result, nil
+}
+
+func buildOccupiedWorkspaceSkillsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
+ result := make(map[string]skillSupportItem)
+ items, err := buildSkillSupportItems(cfg)
+ if err != nil {
+ return nil, err
+ }
+ for _, skill := range items {
+ if skill.Source != "workspace" {
+ continue
+ }
+
+ key := filepath.Base(filepath.Dir(skill.Path))
+ if meta, err := readInstalledSkillOriginMeta(skill.Path); err == nil && meta != nil && meta.Slug != "" {
+ key = meta.Slug
+ }
+ if key == "" {
+ continue
+ }
+ result[key] = skill
+ }
+ return result, nil
+}
+
+func findWorkspaceSkillByDirectory(cfg *config.Config, directory string) *skillSupportItem {
+ items, err := buildWorkspaceSkillItemsByDirectory(cfg)
+ if err != nil {
+ return nil
+ }
+ skill, ok := items[directory]
+ if !ok {
+ return nil
+ }
+ return &skill
+}
+
+func findWorkspaceSkillInfoByDirectory(workspace, directory string) *skills.SkillInfo {
+ loader := skills.NewSkillsLoader(workspace, "", "")
+ for _, skill := range loader.ListSkills() {
+ if skill.Source != "workspace" {
+ continue
+ }
+ if filepath.Base(filepath.Dir(skill.Path)) != directory {
+ continue
+ }
+ skillCopy := skill
+ return &skillCopy
+ }
+ return nil
+}
+
+func createStagedSkillInstall(skillsRoot, slug string) (string, string, error) {
+ stagedWorkspaceRoot, err := os.MkdirTemp(skillsRoot, "."+slug+"-install-*")
+ if err != nil {
+ return "", "", err
+ }
+ stagedTargetDir := filepath.Join(stagedWorkspaceRoot, "skills", slug)
+ return stagedWorkspaceRoot, stagedTargetDir, nil
+}
+
+func commitStagedSkillInstall(stagedWorkspaceRoot, stagedTargetDir, targetDir string, replaceExisting bool) error {
+ if !replaceExisting {
+ return os.Rename(stagedTargetDir, targetDir)
+ }
+
+ backupDir, err := reserveTempDirPath(filepath.Dir(targetDir), "."+filepath.Base(targetDir)+"-backup-*")
+ if err != nil {
+ return err
+ }
+
+ if err := os.Rename(targetDir, backupDir); err != nil {
+ return fmt.Errorf("failed to move existing skill aside: %w", err)
+ }
+
+ if err := os.Rename(stagedTargetDir, targetDir); err != nil {
+ if rollbackErr := os.Rename(backupDir, targetDir); rollbackErr != nil {
+ return fmt.Errorf("failed to activate replacement: %w (rollback failed: %v)", err, rollbackErr)
+ }
+ return fmt.Errorf("failed to activate replacement: %w", err)
+ }
+
+ _ = os.RemoveAll(backupDir)
+ _ = os.RemoveAll(stagedWorkspaceRoot)
+ return nil
+}
+
+func reserveTempDirPath(parent, pattern string) (string, error) {
+ tempDir, err := os.MkdirTemp(parent, pattern)
+ if err != nil {
+ return "", err
+ }
+ if err := os.Remove(tempDir); err != nil {
+ return "", err
+ }
+ return tempDir, nil
+}
+
+func enrichSkillInfo(cfg *config.Config, skill skills.SkillInfo) (skillSupportItem, error) {
+ item := skillSupportItem{
+ Name: skill.Name,
+ Path: skill.Path,
+ Source: skill.Source,
+ Description: skill.Description,
+ OriginKind: "builtin",
+ }
+
+ switch skill.Source {
+ case "builtin":
+ item.OriginKind = "builtin"
+ case "global":
+ item.OriginKind = "builtin"
+ case "workspace":
+ meta, err := readInstalledSkillOriginMeta(skill.Path)
+ if err == nil && meta != nil {
+ switch meta.OriginKind {
+ case "manual":
+ item.OriginKind = "manual"
+ item.InstalledAt = meta.InstalledAt
+ case "third_party":
+ item.OriginKind = "third_party"
+ item.RegistryName = meta.Registry
+ item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
+ item.InstalledVersion = meta.InstalledVersion
+ item.InstalledAt = meta.InstalledAt
+ default:
+ if meta.Registry != "" || meta.Slug != "" || meta.InstalledVersion != "" {
+ item.OriginKind = "third_party"
+ item.RegistryName = meta.Registry
+ item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
+ item.InstalledVersion = meta.InstalledVersion
+ item.InstalledAt = meta.InstalledAt
+ } else {
+ item.OriginKind = "builtin"
+ item.InstalledAt = meta.InstalledAt
+ }
+ }
+ } else {
+ item.OriginKind = "builtin"
+ }
+ default:
+ item.OriginKind = "builtin"
+ }
+
+ return item, nil
+}
+
+func readInstalledSkillOriginMeta(skillPath string) (*installedSkillOriginMeta, error) {
+ metaPath := filepath.Join(filepath.Dir(skillPath), ".skill-origin.json")
+ data, err := os.ReadFile(metaPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ var meta installedSkillOriginMeta
+ if err := json.Unmarshal(data, &meta); err != nil {
+ return nil, err
+ }
+ return &meta, nil
+}
+
+func writeSkillOriginMeta(targetDir string, meta installedSkillOriginMeta) error {
+ data, err := json.MarshalIndent(meta, "", " ")
+ if err != nil {
+ return err
+ }
+ return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
+}
+
+func registrySkillURL(cfg *config.Config, registryName, slug string) string {
+ switch registryName {
+ case "clawhub":
+ baseURL := strings.TrimRight(cfg.Tools.Skills.Registries.ClawHub.BaseURL, "/")
+ if baseURL == "" {
+ baseURL = "https://clawhub.ai"
+ }
+ return baseURL + "/skills/" + url.PathEscape(slug)
+ default:
+ return ""
+ }
+}
+
+func registrySkillURLFromMeta(cfg *config.Config, meta *installedSkillOriginMeta) string {
+ if meta == nil || meta.Slug == "" {
+ return ""
+ }
+ if meta.RegistryURL != "" {
+ return meta.RegistryURL
+ }
+ if cfg == nil || meta.Registry == "" {
+ return ""
+ }
+ return registrySkillURL(cfg, meta.Registry, meta.Slug)
+}
+
func normalizeImportedSkillName(filename string, content []byte) (string, error) {
+ return normalizeImportedSkillNameWithHint(filename, "", content)
+}
+
+func normalizeImportedSkillNameWithHint(filename, directoryHint string, content []byte) (string, error) {
rawContent := strings.ReplaceAll(string(content), "\r\n", "\n")
rawContent = strings.ReplaceAll(rawContent, "\r", "\n")
metadata, _ := extractImportedSkillMetadata(rawContent)
raw := strings.TrimSpace(metadata["name"])
+ if raw == "" {
+ raw = strings.TrimSpace(directoryHint)
+ }
if raw == "" {
raw = strings.TrimSpace(strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename)))
}
@@ -259,6 +834,210 @@ func normalizeImportedSkillContent(content []byte, skillName string) []byte {
return []byte(builder.String())
}
+func importUploadedSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
+ if isImportedSkillArchive(filename, content) {
+ return importUploadedSkillArchive(cfg, filename, content)
+ }
+ return importUploadedMarkdownSkill(cfg, filename, content)
+}
+
+func importUploadedMarkdownSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
+ skillName, err := normalizeImportedSkillName(filename, content)
+ if err != nil {
+ return nil, http.StatusBadRequest, err
+ }
+
+ normalizedContent := normalizeImportedSkillContent(content, skillName)
+ workspace := cfg.WorkspacePath()
+ skillDir := filepath.Join(workspace, "skills", skillName)
+ skillFile := filepath.Join(skillDir, "SKILL.md")
+
+ if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
+ return nil, statusCodeForImportedSkillWriteError(err), err
+ }
+ if err := os.MkdirAll(skillDir, 0o755); err != nil {
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create skill directory: %v", err)
+ }
+ if err := fileutil.WriteFileAtomic(skillFile, normalizedContent, 0o644); err != nil {
+ _ = os.RemoveAll(skillDir)
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
+ }
+
+ return finalizeImportedSkill(cfg, skillDir, skillName, false)
+}
+
+func importUploadedSkillArchive(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
+ tmpDir, tempDirErr := os.MkdirTemp("", "picoclaw-skill-import-*")
+ if tempDirErr != nil {
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create temp directory: %v", tempDirErr)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ archivePath := filepath.Join(tmpDir, "import.zip")
+ if writeErr := fileutil.WriteFileAtomic(archivePath, content, 0o600); writeErr != nil {
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to stage uploaded archive: %v", writeErr)
+ }
+
+ extractDir := filepath.Join(tmpDir, "extract")
+ if extractErr := utils.ExtractZipFile(archivePath, extractDir); extractErr != nil {
+ return nil, http.StatusBadRequest, fmt.Errorf("invalid ZIP archive: %w", extractErr)
+ }
+
+ skillRoot, err := findImportedSkillRoot(extractDir)
+ if err != nil {
+ return nil, http.StatusBadRequest, err
+ }
+
+ skillFile := filepath.Join(skillRoot, "SKILL.md")
+ skillContent, err := os.ReadFile(skillFile)
+ if err != nil {
+ return nil, http.StatusBadRequest, fmt.Errorf("failed to read SKILL.md from archive: %w", err)
+ }
+
+ directoryHint := ""
+ if filepath.Clean(skillRoot) != filepath.Clean(extractDir) {
+ directoryHint = filepath.Base(skillRoot)
+ }
+ skillName, err := normalizeImportedSkillNameWithHint(filename, directoryHint, skillContent)
+ if err != nil {
+ return nil, http.StatusBadRequest, err
+ }
+
+ workspace := cfg.WorkspacePath()
+ skillDir := filepath.Join(workspace, "skills", skillName)
+ if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
+ return nil, statusCodeForImportedSkillWriteError(err), err
+ }
+ if err := copyImportedSkillTree(skillRoot, skillDir); err != nil {
+ _ = os.RemoveAll(skillDir)
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
+ }
+
+ normalizedContent := normalizeImportedSkillContent(skillContent, skillName)
+ if err := fileutil.WriteFileAtomic(filepath.Join(skillDir, "SKILL.md"), normalizedContent, 0o644); err != nil {
+ _ = os.RemoveAll(skillDir)
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to normalize skill: %v", err)
+ }
+
+ return finalizeImportedSkill(cfg, skillDir, skillName, true)
+}
+
+func isImportedSkillArchive(filename string, content []byte) bool {
+ if strings.EqualFold(filepath.Ext(filename), ".zip") {
+ return true
+ }
+ return len(content) >= 4 && bytes.HasPrefix(content, []byte("PK\x03\x04"))
+}
+
+func ensureWorkspaceSkillDoesNotExist(skillDir string) error {
+ if _, err := os.Stat(skillDir); err == nil {
+ return errImportedSkillExists
+ } else if !os.IsNotExist(err) {
+ return fmt.Errorf("failed to inspect skill directory: %w", err)
+ }
+ return nil
+}
+
+func statusCodeForImportedSkillWriteError(err error) int {
+ if err == nil {
+ return http.StatusOK
+ }
+ if errors.Is(err, errImportedSkillExists) {
+ return http.StatusConflict
+ }
+ return http.StatusInternalServerError
+}
+
+func finalizeImportedSkill(
+ cfg *config.Config,
+ skillDir string,
+ skillName string,
+ requireValidatedSkill bool,
+) (*skillSupportItem, int, error) {
+ if err := persistSkillOriginMeta(skillDir, installedSkillOriginMeta{
+ Version: 1,
+ OriginKind: "manual",
+ InstalledAt: time.Now().UnixMilli(),
+ }); err != nil {
+ _ = os.RemoveAll(skillDir)
+ return nil, http.StatusInternalServerError, fmt.Errorf("Failed to persist skill metadata: %v", err)
+ }
+
+ if importedSkill := findWorkspaceSkillByDirectory(cfg, skillName); importedSkill != nil {
+ return importedSkill, http.StatusOK, nil
+ }
+
+ if requireValidatedSkill {
+ _ = os.RemoveAll(skillDir)
+ return nil, http.StatusBadRequest, fmt.Errorf("imported archive is not a valid skill")
+ }
+
+ return &skillSupportItem{
+ Name: skillName,
+ Path: filepath.Join(skillDir, "SKILL.md"),
+ Source: "workspace",
+ Description: "Imported skill",
+ OriginKind: "manual",
+ }, http.StatusOK, nil
+}
+
+func findImportedSkillRoot(extractDir string) (string, error) {
+ skillFiles := make([]string, 0, 1)
+ err := filepath.WalkDir(extractDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ return walkErr
+ }
+ if d.IsDir() {
+ return nil
+ }
+ if d.Name() == "SKILL.md" {
+ skillFiles = append(skillFiles, path)
+ }
+ return nil
+ })
+ if err != nil {
+ return "", fmt.Errorf("failed to inspect ZIP archive: %w", err)
+ }
+
+ switch len(skillFiles) {
+ case 0:
+ return "", fmt.Errorf("ZIP archive must contain a SKILL.md file")
+ case 1:
+ return filepath.Dir(skillFiles[0]), nil
+ default:
+ return "", fmt.Errorf("ZIP archive must contain exactly one SKILL.md file")
+ }
+}
+
+func copyImportedSkillTree(srcDir, destDir string) error {
+ return filepath.WalkDir(srcDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ return walkErr
+ }
+
+ relPath, err := filepath.Rel(srcDir, path)
+ if err != nil {
+ return err
+ }
+ if relPath == "." {
+ return os.MkdirAll(destDir, 0o755)
+ }
+
+ destPath := filepath.Join(destDir, relPath)
+ info, err := d.Info()
+ if err != nil {
+ return err
+ }
+ if d.IsDir() {
+ return os.MkdirAll(destPath, 0o755)
+ }
+ if !info.Mode().IsRegular() {
+ return fmt.Errorf("archive contains unsupported file %q", relPath)
+ }
+ return fileutil.CopyFile(path, destPath, info.Mode().Perm())
+ })
+}
+
func extractImportedSkillMetadata(raw string) (map[string]string, string) {
matches := importedSkillFrontmatter.FindStringSubmatch(raw)
if len(matches) != 2 {
diff --git a/web/backend/api/skills_test.go b/web/backend/api/skills_test.go
index 3289d5b33..17aef485e 100644
--- a/web/backend/api/skills_test.go
+++ b/web/backend/api/skills_test.go
@@ -1,15 +1,19 @@
package api
import (
+ "archive/zip"
"bytes"
"encoding/json"
+ "errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
+ "strconv"
"testing"
+ "time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -99,8 +103,10 @@ func TestHandleListSkills(t *testing.T) {
}
gotSkills := make(map[string]string, len(resp.Skills))
+ gotOriginKinds := make(map[string]string, len(resp.Skills))
for _, skill := range resp.Skills {
gotSkills[skill.Name] = skill.Source
+ gotOriginKinds[skill.Name] = skill.OriginKind
}
if gotSkills["workspace-skill"] != "workspace" {
t.Fatalf("workspace-skill source = %q, want workspace", gotSkills["workspace-skill"])
@@ -111,6 +117,15 @@ func TestHandleListSkills(t *testing.T) {
if gotSkills["builtin-skill"] != "builtin" {
t.Fatalf("builtin-skill source = %q, want builtin", gotSkills["builtin-skill"])
}
+ if gotOriginKinds["workspace-skill"] != "builtin" {
+ t.Fatalf("workspace-skill origin_kind = %q, want builtin", gotOriginKinds["workspace-skill"])
+ }
+ if gotOriginKinds["global-skill"] != "builtin" {
+ t.Fatalf("global-skill origin_kind = %q, want builtin", gotOriginKinds["global-skill"])
+ }
+ if gotOriginKinds["builtin-skill"] != "builtin" {
+ t.Fatalf("builtin-skill origin_kind = %q, want builtin", gotOriginKinds["builtin-skill"])
+ }
}
func TestHandleGetSkill(t *testing.T) {
@@ -162,6 +177,9 @@ func TestHandleGetSkill(t *testing.T) {
if resp.Name != "viewer-skill" || resp.Source != "workspace" || resp.Description != "Viewable skill" {
t.Fatalf("unexpected response: %#v", resp)
}
+ if resp.OriginKind != "builtin" {
+ t.Fatalf("resp.OriginKind = %q, want builtin", resp.OriginKind)
+ }
if resp.Content != "# Viewer Skill\n\nThis is visible content.\n" {
t.Fatalf("content = %q", resp.Content)
}
@@ -271,6 +289,17 @@ func TestHandleImportSkill(t *testing.T) {
if string(content) != expected {
t.Fatalf("saved skill content mismatch:\n%s", string(content))
}
+ metaContent, err := os.ReadFile(filepath.Join(workspace, "skills", "plain-skill", ".skill-origin.json"))
+ if err != nil {
+ t.Fatalf("ReadFile(origin metadata) error = %v", err)
+ }
+ var originMeta installedSkillOriginMeta
+ if err := json.Unmarshal(metaContent, &originMeta); err != nil {
+ t.Fatalf("Unmarshal(origin metadata) error = %v", err)
+ }
+ if originMeta.OriginKind != "manual" {
+ t.Fatalf("originMeta.OriginKind = %q, want manual", originMeta.OriginKind)
+ }
rec2 := httptest.NewRecorder()
req2 := httptest.NewRequest(http.MethodGet, "/api/skills", nil)
@@ -293,6 +322,174 @@ func TestHandleImportSkill(t *testing.T) {
}
}
+func TestHandleImportSkillZip(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "Wrapped Skill/SKILL.md": "---\nname: wrapped-skill\ndescription: Wrapped skill\n---\n# Wrapped Skill\n\nUse this skill from zip.\n",
+ "Wrapped Skill/docs/README.md": "# Extra file\n",
+ })
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ part, createErr := writer.CreateFormFile("file", "Wrapped Skill.zip")
+ if createErr != nil {
+ t.Fatalf("CreateFormFile() error = %v", createErr)
+ }
+ if _, writeErr := part.Write(zipContent); writeErr != nil {
+ t.Fatalf("Write(zipContent) error = %v", writeErr)
+ }
+ if closeErr := writer.Close(); closeErr != nil {
+ t.Fatalf("Close() error = %v", closeErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ skillDir := filepath.Join(workspace, "skills", "wrapped-skill")
+ skillFile := filepath.Join(skillDir, "SKILL.md")
+ content, err := os.ReadFile(skillFile)
+ if err != nil {
+ t.Fatalf("ReadFile() error = %v", err)
+ }
+ expected := "---\nname: wrapped-skill\ndescription: Wrapped skill\n---\n\n# Wrapped Skill\n\nUse this skill from zip.\n"
+ if string(content) != expected {
+ t.Fatalf("saved skill content mismatch:\n%s", string(content))
+ }
+
+ extraFile := filepath.Join(skillDir, "docs", "README.md")
+ extraContent, err := os.ReadFile(extraFile)
+ if err != nil {
+ t.Fatalf("ReadFile(extra file) error = %v", err)
+ }
+ if string(extraContent) != "# Extra file\n" {
+ t.Fatalf("extra file content = %q", string(extraContent))
+ }
+}
+
+func TestHandleImportSkillZipRejectsArchiveWithoutSkill(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "README.md": "# Not a skill\n",
+ })
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ part, err := writer.CreateFormFile("file", "invalid.zip")
+ if err != nil {
+ t.Fatalf("CreateFormFile() error = %v", err)
+ }
+ if _, err := part.Write(zipContent); err != nil {
+ t.Fatalf("Write(zipContent) error = %v", err)
+ }
+ if err := writer.Close(); err != nil {
+ t.Fatalf("Close() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
+ }
+ if _, err := os.Stat(filepath.Join(workspace, "skills", "invalid")); !os.IsNotExist(err) {
+ t.Fatalf("invalid archive should not leave behind a skill dir, stat err=%v", err)
+ }
+}
+
+func TestHandleImportSkillRollsBackOnOriginMetadataWriteFailure(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ previousPersist := persistSkillOriginMeta
+ persistSkillOriginMeta = func(targetDir string, meta installedSkillOriginMeta) error {
+ return errors.New("forced metadata failure")
+ }
+ defer func() {
+ persistSkillOriginMeta = previousPersist
+ }()
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ part, err := writer.CreateFormFile("file", "Rollback Skill.md")
+ if err != nil {
+ t.Fatalf("CreateFormFile() error = %v", err)
+ }
+ if _, err := io.WriteString(part, "# Rollback Skill\n"); err != nil {
+ t.Fatalf("WriteString() error = %v", err)
+ }
+ if err := writer.Close(); err != nil {
+ t.Fatalf("Close() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &body)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusInternalServerError {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusInternalServerError, rec.Body.String())
+ }
+
+ skillDir := filepath.Join(workspace, "skills", "rollback-skill")
+ if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
+ t.Fatalf("skill directory should be removed after metadata write failure, stat err=%v", err)
+ }
+}
+
func TestHandleDeleteSkill(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
@@ -334,3 +531,888 @@ func TestHandleDeleteSkill(t *testing.T) {
t.Fatalf("skill directory should be removed, stat err=%v", err)
}
}
+
+func TestHandleSearchSkills(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ if err := os.MkdirAll(filepath.Join(workspace, "skills", "github"), 0o755); err != nil {
+ t.Fatalf("MkdirAll() error = %v", err)
+ }
+ if err := os.WriteFile(
+ filepath.Join(workspace, "skills", "github", "SKILL.md"),
+ []byte("---\nname: github\ndescription: Installed GitHub skill\n---\n# GitHub\n"),
+ 0o644,
+ ); err != nil {
+ t.Fatalf("WriteFile() error = %v", err)
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/v1/search" {
+ http.NotFound(w, r)
+ return
+ }
+ if got := r.URL.Query().Get("q"); got != "github" {
+ t.Fatalf("query = %q, want github", got)
+ }
+ json.NewEncoder(w).Encode(map[string]any{
+ "results": []map[string]any{
+ {
+ "score": 0.95,
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub integration skill",
+ "version": "1.2.3",
+ },
+ {
+ "score": 0.87,
+ "slug": "jira",
+ "displayName": "Jira",
+ "summary": "Issue tracker skill",
+ "version": "0.9.0",
+ },
+ },
+ })
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/skills/search?q=github&limit=5", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp skillSearchResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if resp.Limit != 5 {
+ t.Fatalf("limit = %d, want 5", resp.Limit)
+ }
+ if resp.Offset != 0 {
+ t.Fatalf("offset = %d, want 0", resp.Offset)
+ }
+ if resp.HasMore {
+ t.Fatalf("has_more = true, want false")
+ }
+ if len(resp.Results) != 2 {
+ t.Fatalf("results count = %d, want 2", len(resp.Results))
+ }
+ if resp.Results[0].URL != server.URL+"/skills/github" {
+ t.Fatalf("first result URL = %q, want %q", resp.Results[0].URL, server.URL+"/skills/github")
+ }
+ if !resp.Results[0].Installed || resp.Results[0].InstalledName != "github" {
+ t.Fatalf("first result should be treated as occupying the workspace slug, got %#v", resp.Results[0])
+ }
+ if resp.Results[1].Installed {
+ t.Fatalf("second result should not be installed, got %#v", resp.Results[1])
+ }
+}
+
+func TestHandleSearchSkillsPagination(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/v1/search" {
+ http.NotFound(w, r)
+ return
+ }
+ if got := r.URL.Query().Get("limit"); got != "5" {
+ t.Fatalf("limit = %q, want 5", got)
+ }
+ json.NewEncoder(w).Encode(map[string]any{
+ "results": []map[string]any{
+ {
+ "score": 0.99,
+ "slug": "skill-1",
+ "displayName": "Skill 1",
+ "summary": "Summary 1",
+ "version": "1.0.0",
+ },
+ {
+ "score": 0.98,
+ "slug": "skill-2",
+ "displayName": "Skill 2",
+ "summary": "Summary 2",
+ "version": "1.0.0",
+ },
+ {
+ "score": 0.97,
+ "slug": "skill-3",
+ "displayName": "Skill 3",
+ "summary": "Summary 3",
+ "version": "1.0.0",
+ },
+ {
+ "score": 0.96,
+ "slug": "skill-4",
+ "displayName": "Skill 4",
+ "summary": "Summary 4",
+ "version": "1.0.0",
+ },
+ },
+ })
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/skills/search?q=github&limit=2&offset=2", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp skillSearchResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if resp.Limit != 2 {
+ t.Fatalf("limit = %d, want 2", resp.Limit)
+ }
+ if resp.Offset != 2 {
+ t.Fatalf("offset = %d, want 2", resp.Offset)
+ }
+ if resp.HasMore {
+ t.Fatalf("has_more = true, want false")
+ }
+ if len(resp.Results) != 2 {
+ t.Fatalf("results count = %d, want 2", len(resp.Results))
+ }
+ if resp.Results[0].Slug != "skill-3" || resp.Results[1].Slug != "skill-4" {
+ t.Fatalf("unexpected paged results: %#v", resp.Results)
+ }
+ if resp.NextOffset != 0 {
+ t.Fatalf("next_offset = %d, want 0", resp.NextOffset)
+ }
+}
+
+func TestHandleSearchSkillsClampsRegistryFanout(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, err := config.LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error = %v", err)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api/v1/search" {
+ http.NotFound(w, r)
+ return
+ }
+ if got := r.URL.Query().Get("limit"); got != strconv.Itoa(maxRegistrySearchFanout) {
+ t.Fatalf("limit = %q, want %d", got, maxRegistrySearchFanout)
+ }
+ json.NewEncoder(w).Encode(map[string]any{
+ "results": []map[string]any{
+ {
+ "score": 0.99,
+ "slug": "skill-1",
+ "displayName": "Skill 1",
+ "summary": "Summary 1",
+ "version": "1.0.0",
+ },
+ },
+ })
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if err := config.SaveConfig(configPath, cfg); err != nil {
+ t.Fatalf("SaveConfig() error = %v", err)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/skills/search?q=github&limit=20&offset=100000", nil)
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp skillSearchResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if len(resp.Results) != 0 {
+ t.Fatalf("results count = %d, want 0", len(resp.Results))
+ }
+}
+
+func TestHandleInstallSkill(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "SKILL.md": "---\nname: github\ndescription: GitHub registry skill\n---\n# GitHub\n\nUse this skill.\n",
+ })
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/search":
+ json.NewEncoder(w).Encode(map[string]any{
+ "results": []map[string]any{
+ {
+ "score": 0.95,
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "version": "1.2.3",
+ },
+ },
+ })
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ if got := r.URL.Query().Get("slug"); got != "github" {
+ t.Fatalf("slug = %q, want github", got)
+ }
+ if got := r.URL.Query().Get("version"); got != "1.2.3" {
+ t.Fatalf("version = %q, want 1.2.3", got)
+ }
+ w.Header().Set("Content-Type", "application/zip")
+ _, _ = w.Write(zipContent)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ body, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var resp installSkillResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("Unmarshal() error = %v", err)
+ }
+ if resp.Status != "ok" || resp.Version != "1.2.3" || resp.InstalledSkill == nil {
+ t.Fatalf("unexpected response: %#v", resp)
+ }
+ if resp.InstalledSkill.OriginKind != "third_party" {
+ t.Fatalf("resp.InstalledSkill.OriginKind = %q, want third_party", resp.InstalledSkill.OriginKind)
+ }
+ if resp.InstalledSkill.RegistryURL != server.URL+"/skills/github" {
+ t.Fatalf(
+ "resp.InstalledSkill.RegistryURL = %q, want %q",
+ resp.InstalledSkill.RegistryURL,
+ server.URL+"/skills/github",
+ )
+ }
+
+ skillFile := filepath.Join(workspace, "skills", "github", "SKILL.md")
+ if _, err := os.Stat(skillFile); err != nil {
+ t.Fatalf("installed skill file missing: %v", err)
+ }
+ if _, err := os.Stat(filepath.Join(workspace, "skills", "github", ".skill-origin.json")); err != nil {
+ t.Fatalf("origin metadata missing: %v", err)
+ }
+
+ detailRec := httptest.NewRecorder()
+ detailReq := httptest.NewRequest(http.MethodGet, "/api/skills/github", nil)
+ mux.ServeHTTP(detailRec, detailReq)
+
+ if detailRec.Code != http.StatusOK {
+ t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String())
+ }
+
+ var detailResp skillDetailResponse
+ if err := json.Unmarshal(detailRec.Body.Bytes(), &detailResp); err != nil {
+ t.Fatalf("Unmarshal(detail response) error = %v", err)
+ }
+ if detailResp.RegistryURL != server.URL+"/skills/github" {
+ t.Fatalf("detailResp.RegistryURL = %q, want %q", detailResp.RegistryURL, server.URL+"/skills/github")
+ }
+
+ searchRec := httptest.NewRecorder()
+ searchReq := httptest.NewRequest(http.MethodGet, "/api/skills/search?q=github&limit=5", nil)
+ mux.ServeHTTP(searchRec, searchReq)
+
+ if searchRec.Code != http.StatusOK {
+ t.Fatalf("search status = %d, want %d, body=%s", searchRec.Code, http.StatusOK, searchRec.Body.String())
+ }
+
+ var searchResp skillSearchResponse
+ if err := json.Unmarshal(searchRec.Body.Bytes(), &searchResp); err != nil {
+ t.Fatalf("Unmarshal(search response) error = %v", err)
+ }
+ if len(searchResp.Results) != 1 {
+ t.Fatalf("search results count = %d, want 1", len(searchResp.Results))
+ }
+ if !searchResp.Results[0].Installed || searchResp.Results[0].InstalledName != "github" {
+ t.Fatalf("search result should be treated as installed after registry install, got %#v", searchResp.Results[0])
+ }
+}
+
+func TestHandleInstallSkillForcePreservesExistingSkillOnFailure(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ skillDir := filepath.Join(workspace, "skills", "github")
+ if err := os.MkdirAll(skillDir, 0o755); err != nil {
+ t.Fatalf("MkdirAll() error = %v", err)
+ }
+ oldContent := []byte("---\nname: github\ndescription: Existing skill\n---\n# Existing\n")
+ if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), oldContent, 0o644); err != nil {
+ t.Fatalf("WriteFile() error = %v", err)
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ http.Error(w, "upstream download failed", http.StatusBadGateway)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ body, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ Force: true,
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadGateway {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadGateway, rec.Body.String())
+ }
+
+ gotContent, err := os.ReadFile(filepath.Join(skillDir, "SKILL.md"))
+ if err != nil {
+ t.Fatalf("ReadFile() error = %v", err)
+ }
+ if !bytes.Equal(gotContent, oldContent) {
+ t.Fatalf("existing skill should remain unchanged, got:\n%s", string(gotContent))
+ }
+}
+
+func TestHandleInstallSkillRollsBackOnOriginMetadataWriteFailure(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "SKILL.md": "---\nname: github\ndescription: GitHub registry skill\n---\n# GitHub\n",
+ })
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ w.Header().Set("Content-Type", "application/zip")
+ _, _ = w.Write(zipContent)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ previousPersist := persistSkillOriginMeta
+ persistSkillOriginMeta = func(targetDir string, meta installedSkillOriginMeta) error {
+ return errors.New("forced metadata failure")
+ }
+ defer func() {
+ persistSkillOriginMeta = previousPersist
+ }()
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ body, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusInternalServerError {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusInternalServerError, rec.Body.String())
+ }
+
+ skillDir := filepath.Join(workspace, "skills", "github")
+ if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
+ t.Fatalf("skill directory should be removed after metadata write failure, stat err=%v", err)
+ }
+}
+
+func TestHandleInstallSkillSerializesConcurrentRequests(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "SKILL.md": "---\nname: github\ndescription: GitHub registry skill\n---\n# GitHub\n",
+ })
+
+ downloadStarted := make(chan struct{}, 2)
+ releaseFirstDownload := make(chan struct{})
+ downloadCount := 0
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ downloadCount++
+ downloadStarted <- struct{}{}
+ if downloadCount == 1 {
+ <-releaseFirstDownload
+ }
+ w.Header().Set("Content-Type", "application/zip")
+ _, _ = w.Write(zipContent)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ body, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ type installResult struct {
+ code int
+ body string
+ }
+ results := make(chan installResult, 2)
+ startInstall := func() {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+ results <- installResult{
+ code: rec.Code,
+ body: rec.Body.String(),
+ }
+ }
+
+ go startInstall()
+
+ select {
+ case <-downloadStarted:
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for first install download to start")
+ }
+
+ go startInstall()
+
+ select {
+ case <-downloadStarted:
+ t.Fatal("second install should not reach registry download before the first request completes")
+ case <-time.After(200 * time.Millisecond):
+ }
+
+ close(releaseFirstDownload)
+
+ firstResult := <-results
+ secondResult := <-results
+
+ codes := map[int]int{
+ firstResult.code: 1,
+ secondResult.code: 1,
+ }
+ if codes[http.StatusOK] != 1 || codes[http.StatusConflict] != 1 {
+ t.Fatalf(
+ "unexpected install results: first=(%d, %q) second=(%d, %q)",
+ firstResult.code,
+ firstResult.body,
+ secondResult.code,
+ secondResult.body,
+ )
+ }
+}
+
+func TestHandleImportSkillWaitsForConcurrentInstall(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "SKILL.md": "---\nname: github\ndescription: GitHub registry skill\n---\n# GitHub\n",
+ })
+
+ downloadStarted := make(chan struct{}, 1)
+ releaseDownload := make(chan struct{})
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ downloadStarted <- struct{}{}
+ <-releaseDownload
+ w.Header().Set("Content-Type", "application/zip")
+ _, _ = w.Write(zipContent)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ installBody, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ type result struct {
+ code int
+ body string
+ }
+ installResults := make(chan result, 1)
+ importResults := make(chan result, 1)
+
+ go func() {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(installBody))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+ installResults <- result{code: rec.Code, body: rec.Body.String()}
+ }()
+
+ select {
+ case <-downloadStarted:
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for install download to start")
+ }
+
+ var importBody bytes.Buffer
+ writer := multipart.NewWriter(&importBody)
+ part, err := writer.CreateFormFile("file", "github.md")
+ if err != nil {
+ t.Fatalf("CreateFormFile() error = %v", err)
+ }
+ if _, err := io.WriteString(part, "# GitHub\n"); err != nil {
+ t.Fatalf("WriteString() error = %v", err)
+ }
+ if err := writer.Close(); err != nil {
+ t.Fatalf("Close() error = %v", err)
+ }
+
+ go func() {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &importBody)
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ mux.ServeHTTP(rec, req)
+ importResults <- result{code: rec.Code, body: rec.Body.String()}
+ }()
+
+ select {
+ case got := <-importResults:
+ t.Fatalf("import should wait for the install lock, got early response (%d, %q)", got.code, got.body)
+ case <-time.After(200 * time.Millisecond):
+ }
+
+ close(releaseDownload)
+
+ installResult := <-installResults
+ importResult := <-importResults
+
+ if installResult.code != http.StatusOK {
+ t.Fatalf("install status = %d, want %d, body=%s", installResult.code, http.StatusOK, installResult.body)
+ }
+ if importResult.code != http.StatusConflict {
+ t.Fatalf("import status = %d, want %d, body=%s", importResult.code, http.StatusConflict, importResult.body)
+ }
+}
+
+func TestHandleInstallSkillRejectsInvalidArchive(t *testing.T) {
+ configPath, cleanup := setupOAuthTestEnv(t)
+ defer cleanup()
+
+ cfg, loadErr := config.LoadConfig(configPath)
+ if loadErr != nil {
+ t.Fatalf("LoadConfig() error = %v", loadErr)
+ }
+ workspace := filepath.Join(t.TempDir(), "workspace")
+ cfg.Agents.Defaults.Workspace = workspace
+
+ zipContent := buildSkillZip(t, map[string]string{
+ "README.md": "# Not a skill\n",
+ })
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/github":
+ json.NewEncoder(w).Encode(map[string]any{
+ "slug": "github",
+ "displayName": "GitHub",
+ "summary": "GitHub registry skill",
+ "latestVersion": map[string]any{
+ "version": "1.2.3",
+ },
+ "moderation": map[string]any{
+ "isMalwareBlocked": false,
+ "isSuspicious": false,
+ },
+ })
+ case "/api/v1/download":
+ w.Header().Set("Content-Type", "application/zip")
+ _, _ = w.Write(zipContent)
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg.Tools.Skills.Registries.ClawHub.BaseURL = server.URL
+ if saveErr := config.SaveConfig(configPath, cfg); saveErr != nil {
+ t.Fatalf("SaveConfig() error = %v", saveErr)
+ }
+
+ h := NewHandler(configPath)
+ mux := http.NewServeMux()
+ h.RegisterRoutes(mux)
+
+ body, err := json.Marshal(installSkillRequest{
+ Slug: "github",
+ Registry: "clawhub",
+ })
+ if err != nil {
+ t.Fatalf("Marshal() error = %v", err)
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/skills/install", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ mux.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadGateway {
+ t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadGateway, rec.Body.String())
+ }
+
+ skillDir := filepath.Join(workspace, "skills", "github")
+ if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
+ t.Fatalf("invalid installed archive should be removed, stat err=%v", err)
+ }
+}
+
+func buildSkillZip(t *testing.T, files map[string]string) []byte {
+ t.Helper()
+
+ var buf bytes.Buffer
+ zipWriter := zip.NewWriter(&buf)
+ for name, content := range files {
+ writer, err := zipWriter.Create(name)
+ if err != nil {
+ t.Fatalf("Create(%q) error = %v", name, err)
+ }
+ if _, err := io.WriteString(writer, content); err != nil {
+ t.Fatalf("WriteString(%q) error = %v", name, err)
+ }
+ }
+ if err := zipWriter.Close(); err != nil {
+ t.Fatalf("Close() error = %v", err)
+ }
+ return buf.Bytes()
+}
diff --git a/web/backend/api/update.go b/web/backend/api/update.go
new file mode 100644
index 000000000..2ba862631
--- /dev/null
+++ b/web/backend/api/update.go
@@ -0,0 +1,52 @@
+package api
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/sipeed/picoclaw/pkg/updater"
+)
+
+// registerUpdateRoutes registers the self-update endpoint.
+func (h *Handler) registerUpdateRoutes(mux *http.ServeMux) {
+ mux.HandleFunc("/api/update", h.handleUpdate)
+}
+
+type updateRequest struct {
+ URL string `json:"url,omitempty"`
+ Binary string `json:"binary,omitempty"`
+}
+
+type updateResponse struct {
+ Status string `json:"status"`
+ Message string `json:"message,omitempty"`
+}
+
+func (h *Handler) handleUpdate(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ _ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "method not allowed"})
+ return
+ }
+
+ dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20))
+ var req updateRequest
+ if err := dec.Decode(&req); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "invalid request body"})
+ return
+ }
+
+ binary := req.Binary
+ if binary == "" {
+ binary = "picoclaw-launcher"
+ }
+
+ if err := updater.UpdateSelfFromRelease(req.URL, "", "", binary); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: err.Error()})
+ return
+ }
+
+ _ = json.NewEncoder(w).Encode(updateResponse{Status: "ok", Message: "update applied; restart to use new version"})
+}
diff --git a/web/backend/launcherconfig/config.go b/web/backend/launcherconfig/config.go
index b8465ef74..60c369f4f 100644
--- a/web/backend/launcherconfig/config.go
+++ b/web/backend/launcherconfig/config.go
@@ -23,11 +23,20 @@ const (
dashboardTokenEntropyBytes = 32
)
+type DashboardTokenSource string
+
+const (
+ DashboardTokenSourceEnv DashboardTokenSource = "env"
+ DashboardTokenSourceConfig DashboardTokenSource = "config"
+ DashboardTokenSourceRandom DashboardTokenSource = "random"
+)
+
// Config stores launch parameters for the web backend service.
type Config struct {
- Port int `json:"port"`
- Public bool `json:"public"`
- AllowedCIDRs []string `json:"allowed_cidrs,omitempty"`
+ Port int `json:"port"`
+ Public bool `json:"public"`
+ AllowedCIDRs []string `json:"allowed_cidrs,omitempty"`
+ LauncherToken string `json:"launcher_token,omitempty"`
}
// Default returns default launcher settings.
@@ -49,23 +58,30 @@ func Validate(cfg Config) error {
}
// EnsureDashboardSecrets returns signing key bytes and the effective dashboard token for this
-// process. The signing key is freshly random each call; the token comes from the environment
-// variable PICOCLAW_LAUNCHER_TOKEN when set, otherwise a new random token.
-func EnsureDashboardSecrets() (effectiveToken string, signingKey []byte, newRandomDashboardToken bool, err error) {
+// process. The signing key is freshly random each call; the token comes from
+// PICOCLAW_LAUNCHER_TOKEN when set, otherwise launcher-config.json launcher_token,
+// otherwise a new random token.
+func EnsureDashboardSecrets(
+ cfg Config,
+) (effectiveToken string, signingKey []byte, source DashboardTokenSource, err error) {
signingKey = make([]byte, dashboardSigningKeyBytes)
if _, err = rand.Read(signingKey); err != nil {
- return "", nil, false, err
+ return "", nil, "", err
}
effectiveToken = strings.TrimSpace(os.Getenv("PICOCLAW_LAUNCHER_TOKEN"))
if effectiveToken != "" {
- return effectiveToken, signingKey, false, nil
+ return effectiveToken, signingKey, DashboardTokenSourceEnv, nil
+ }
+ effectiveToken = strings.TrimSpace(cfg.LauncherToken)
+ if effectiveToken != "" {
+ return effectiveToken, signingKey, DashboardTokenSourceConfig, nil
}
tok, genErr := randomDashboardToken()
if genErr != nil {
- return "", nil, false, genErr
+ return "", nil, "", genErr
}
- return tok, signingKey, true, nil
+ return tok, signingKey, DashboardTokenSourceRandom, nil
}
func randomDashboardToken() (string, error) {
@@ -124,6 +140,7 @@ func Load(path string, fallback Config) (Config, error) {
return Config{}, err
}
cfg.AllowedCIDRs = NormalizeCIDRs(cfg.AllowedCIDRs)
+ cfg.LauncherToken = strings.TrimSpace(cfg.LauncherToken)
if err := Validate(cfg); err != nil {
return Config{}, err
}
@@ -133,6 +150,7 @@ func Load(path string, fallback Config) (Config, error) {
// Save writes launcher settings to disk.
func Save(path string, cfg Config) error {
cfg.AllowedCIDRs = NormalizeCIDRs(cfg.AllowedCIDRs)
+ cfg.LauncherToken = strings.TrimSpace(cfg.LauncherToken)
if err := Validate(cfg); err != nil {
return err
}
diff --git a/web/backend/launcherconfig/config_test.go b/web/backend/launcherconfig/config_test.go
index 4e8a54e41..528116417 100644
--- a/web/backend/launcherconfig/config_test.go
+++ b/web/backend/launcherconfig/config_test.go
@@ -25,9 +25,10 @@ func TestSaveAndLoadRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "launcher-config.json")
want := Config{
- Port: 18080,
- Public: true,
- AllowedCIDRs: []string{"192.168.1.0/24", "10.0.0.0/8"},
+ Port: 18080,
+ Public: true,
+ AllowedCIDRs: []string{"192.168.1.0/24", "10.0.0.0/8"},
+ LauncherToken: "saved-launcher-token",
}
if err := Save(path, want); err != nil {
@@ -40,6 +41,9 @@ func TestSaveAndLoadRoundTrip(t *testing.T) {
if got.Port != want.Port || got.Public != want.Public {
t.Fatalf("Load() = %+v, want %+v", got, want)
}
+ if got.LauncherToken != want.LauncherToken {
+ t.Fatalf("launcher_token = %q, want %q", got.LauncherToken, want.LauncherToken)
+ }
if len(got.AllowedCIDRs) != len(want.AllowedCIDRs) {
t.Fatalf("allowed_cidrs len = %d, want %d", len(got.AllowedCIDRs), len(want.AllowedCIDRs))
}
@@ -80,24 +84,24 @@ func TestValidateRejectsInvalidCIDR(t *testing.T) {
func TestEnsureDashboardSecrets_GeneratesEphemeral(t *testing.T) {
t.Setenv("PICOCLAW_LAUNCHER_TOKEN", "")
- tok, key, newTok, err := EnsureDashboardSecrets()
+ tok, key, source, err := EnsureDashboardSecrets(Default())
if err != nil {
t.Fatalf("EnsureDashboardSecrets() error = %v", err)
}
- if !newTok || tok == "" || len(key) != dashboardSigningKeyBytes {
- t.Fatalf("unexpected first call: newTok=%v tok=%q keyLen=%d", newTok, tok, len(key))
+ if source != DashboardTokenSourceRandom || tok == "" || len(key) != dashboardSigningKeyBytes {
+ t.Fatalf("unexpected first call: source=%q tok=%q keyLen=%d", source, tok, len(key))
}
mac := middleware.SessionCookieValue(key, tok)
if mac == "" {
t.Fatal("empty session mac")
}
- tok2, key2, newTok2, err := EnsureDashboardSecrets()
+ tok2, key2, source2, err := EnsureDashboardSecrets(Default())
if err != nil {
t.Fatalf("EnsureDashboardSecrets() second error = %v", err)
}
- if !newTok2 {
- t.Fatal("second call without env should generate another random token")
+ if source2 != DashboardTokenSourceRandom {
+ t.Fatalf("second call source = %q, want %q", source2, DashboardTokenSourceRandom)
}
if tok2 == tok {
t.Fatal("expected a new random dashboard token")
@@ -110,15 +114,30 @@ func TestEnsureDashboardSecrets_GeneratesEphemeral(t *testing.T) {
func TestEnsureDashboardSecrets_EnvOverridesGenerated(t *testing.T) {
t.Setenv("PICOCLAW_LAUNCHER_TOKEN", "env-only-token-override")
- tok, _, newTok, err := EnsureDashboardSecrets()
+ tok, _, source, err := EnsureDashboardSecrets(Config{LauncherToken: "config-token"})
if err != nil {
t.Fatalf("EnsureDashboardSecrets() error = %v", err)
}
if tok != "env-only-token-override" {
t.Fatalf("token = %q, want env value", tok)
}
- if newTok {
- t.Fatal("newRandomDashboardToken should be false when env is set")
+ if source != DashboardTokenSourceEnv {
+ t.Fatalf("source = %q, want %q", source, DashboardTokenSourceEnv)
+ }
+}
+
+func TestEnsureDashboardSecrets_ConfigOverridesGenerated(t *testing.T) {
+ t.Setenv("PICOCLAW_LAUNCHER_TOKEN", "")
+
+ tok, _, source, err := EnsureDashboardSecrets(Config{LauncherToken: "config-token"})
+ if err != nil {
+ t.Fatalf("EnsureDashboardSecrets() error = %v", err)
+ }
+ if tok != "config-token" {
+ t.Fatalf("token = %q, want config value", tok)
+ }
+ if source != DashboardTokenSourceConfig {
+ t.Fatalf("source = %q, want %q", source, DashboardTokenSourceConfig)
}
}
diff --git a/web/backend/main.go b/web/backend/main.go
index 218e3bfce..5e9f3315f 100644
--- a/web/backend/main.go
+++ b/web/backend/main.go
@@ -59,6 +59,13 @@ func shouldEnableLauncherFileLogging(enableConsole, debug bool) bool {
return !enableConsole || debug
}
+func dashboardTokenConfigHelpPath(source launcherconfig.DashboardTokenSource, launcherPath string) string {
+ if source != launcherconfig.DashboardTokenSourceConfig {
+ return ""
+ }
+ return launcherPath
+}
+
func main() {
port := flag.String("port", "18800", "Port to listen on")
public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
@@ -195,7 +202,9 @@ func main() {
logger.Fatalf("Invalid port %q: %v", effectivePort, err)
}
- dashboardToken, dashboardSigningKey, newDashTok, dashErr := launcherconfig.EnsureDashboardSecrets()
+ dashboardToken, dashboardSigningKey, dashboardTokenSource, dashErr := launcherconfig.EnsureDashboardSecrets(
+ launcherCfg,
+ )
if dashErr != nil {
logger.Fatalf("Dashboard auth setup failed: %v", dashErr)
}
@@ -223,6 +232,7 @@ func main() {
TokenHelp: api.LauncherAuthTokenHelp{
EnvVarName: "PICOCLAW_LAUNCHER_TOKEN",
LogFileAbs: tokenLogFileAbs,
+ ConfigFileAbs: dashboardTokenConfigHelpPath(dashboardTokenSource, launcherPath),
TrayCopyMenu: trayOffersDashboardTokenCopy(),
ConsoleStdout: enableConsole,
},
@@ -272,19 +282,26 @@ func main() {
}
}
fmt.Println()
- if newDashTok {
+ switch dashboardTokenSource {
+ case launcherconfig.DashboardTokenSourceRandom:
fmt.Printf(" Dashboard token (this run): %s\n", dashboardToken)
- } else if os.Getenv("PICOCLAW_LAUNCHER_TOKEN") != "" {
+ case launcherconfig.DashboardTokenSourceEnv:
fmt.Printf(" Dashboard token: %s (from PICOCLAW_LAUNCHER_TOKEN)\n", dashboardToken)
+ case launcherconfig.DashboardTokenSourceConfig:
+ fmt.Printf(" Dashboard token: %s (from %s)\n", dashboardToken, launcherPath)
}
fmt.Println()
}
- if os.Getenv("PICOCLAW_LAUNCHER_TOKEN") != "" {
+ switch dashboardTokenSource {
+ case launcherconfig.DashboardTokenSourceEnv:
logger.InfoC("web", "Dashboard token: environment PICOCLAW_LAUNCHER_TOKEN")
- }
- if !enableConsole && newDashTok {
- logger.InfoC("web", "Dashboard token (this run): "+dashboardToken)
+ case launcherconfig.DashboardTokenSourceConfig:
+ logger.InfoC("web", fmt.Sprintf("Dashboard token: configured in %s", launcherPath))
+ case launcherconfig.DashboardTokenSourceRandom:
+ if !enableConsole {
+ logger.InfoC("web", "Dashboard token (this run): "+dashboardToken)
+ }
}
// Log startup info to file
diff --git a/web/backend/main_test.go b/web/backend/main_test.go
index c24a53704..f69705179 100644
--- a/web/backend/main_test.go
+++ b/web/backend/main_test.go
@@ -1,6 +1,10 @@
package main
-import "testing"
+import (
+ "testing"
+
+ "github.com/sipeed/picoclaw/web/backend/launcherconfig"
+)
func TestShouldEnableLauncherFileLogging(t *testing.T) {
tests := []struct {
@@ -29,3 +33,37 @@ func TestShouldEnableLauncherFileLogging(t *testing.T) {
})
}
}
+
+func TestDashboardTokenConfigHelpPath(t *testing.T) {
+ const launcherPath = "/tmp/launcher-config.json"
+
+ tests := []struct {
+ name string
+ source launcherconfig.DashboardTokenSource
+ want string
+ }{
+ {
+ name: "env token does not expose config path",
+ source: launcherconfig.DashboardTokenSourceEnv,
+ want: "",
+ },
+ {
+ name: "config token exposes config path",
+ source: launcherconfig.DashboardTokenSourceConfig,
+ want: launcherPath,
+ },
+ {
+ name: "random token does not expose config path",
+ source: launcherconfig.DashboardTokenSourceRandom,
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := dashboardTokenConfigHelpPath(tt.source, launcherPath); got != tt.want {
+ t.Fatalf("dashboardTokenConfigHelpPath(%q, %q) = %q, want %q", tt.source, launcherPath, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/web/backend/middleware/middleware.go b/web/backend/middleware/middleware.go
index a0b7eb998..f9eb3149d 100644
--- a/web/backend/middleware/middleware.go
+++ b/web/backend/middleware/middleware.go
@@ -71,6 +71,7 @@ func Recoverer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
+ logger.RecoverPanicNoExit(err)
logger.ErrorC("http", fmt.Sprintf("panic recovered: %v\n%s", err, debug.Stack()))
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
}
diff --git a/web/frontend/eslint.config.js b/web/frontend/eslint.config.js
index bc9c64344..85d380c4f 100644
--- a/web/frontend/eslint.config.js
+++ b/web/frontend/eslint.config.js
@@ -28,4 +28,12 @@ export default defineConfig([
],
},
},
+ {
+ files: ["src/routes/**/*.{ts,tsx}"],
+ rules: {
+ // TanStack Router route modules must export Route objects, so this rule
+ // produces false positives for framework-managed files.
+ "react-refresh/only-export-components": "off",
+ },
+ },
])
diff --git a/web/frontend/package.json b/web/frontend/package.json
index 906425b58..c802c71ff 100644
--- a/web/frontend/package.json
+++ b/web/frontend/package.json
@@ -19,25 +19,25 @@
"@fontsource-variable/inter": "^5.2.8",
"@tabler/icons-react": "^3.40.0",
"@tailwindcss/vite": "^4.2.2",
- "@tanstack/react-query": "^5.90.21",
+ "@tanstack/react-query": "^5.96.1",
"@tanstack/react-router": "^1.167.0",
"@tanstack/react-router-devtools": "^1.163.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"dayjs": "^1.11.20",
- "i18next": "^26.0.1",
+ "i18next": "^26.0.3",
"i18next-browser-languagedetector": "^8.2.1",
"jotai": "^2.18.1",
"radix-ui": "^1.4.3",
"react": "^19.2.0",
"react-dom": "^19.2.0",
- "react-i18next": "^16.5.8",
+ "react-i18next": "^17.0.2",
"react-markdown": "^10.1.0",
"react-textarea-autosize": "^8.5.9",
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-gfm": "^4.0.1",
- "shadcn": "^4.1.0",
+ "shadcn": "^4.1.2",
"sonner": "^2.0.7",
"tailwind-merge": "^3.5.0",
"tailwindcss": "^4.2.2",
@@ -57,7 +57,7 @@
"eslint": "^10.1.0",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-react-hooks": "^7.0.1",
- "eslint-plugin-react-refresh": "^0.4.26",
+ "eslint-plugin-react-refresh": "^0.5.2",
"globals": "^17.4.0",
"prettier": "^3.8.1",
"prettier-plugin-tailwindcss": "^0.7.2",
diff --git a/web/frontend/pnpm-lock.yaml b/web/frontend/pnpm-lock.yaml
index abb906c81..eb464f62d 100644
--- a/web/frontend/pnpm-lock.yaml
+++ b/web/frontend/pnpm-lock.yaml
@@ -18,8 +18,8 @@ importers:
specifier: ^4.2.2
version: 4.2.2(vite@8.0.3(@emnapi/core@1.9.1)(@emnapi/runtime@1.9.1)(@types/node@25.5.0)(esbuild@0.27.4)(jiti@2.6.1)(tsx@4.21.0))
'@tanstack/react-query':
- specifier: ^5.90.21
- version: 5.95.2(react@19.2.4)
+ specifier: ^5.96.1
+ version: 5.96.1(react@19.2.4)
'@tanstack/react-router':
specifier: ^1.167.0
version: 1.168.8(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
@@ -36,8 +36,8 @@ importers:
specifier: ^1.11.20
version: 1.11.20
i18next:
- specifier: ^26.0.1
- version: 26.0.1(typescript@5.9.3)
+ specifier: ^26.0.3
+ version: 26.0.3(typescript@5.9.3)
i18next-browser-languagedetector:
specifier: ^8.2.1
version: 8.2.1
@@ -54,8 +54,8 @@ importers:
specifier: ^19.2.0
version: 19.2.4(react@19.2.4)
react-i18next:
- specifier: ^16.5.8
- version: 16.6.6(i18next@26.0.1(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
+ specifier: ^17.0.2
+ version: 17.0.2(i18next@26.0.3(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
react-markdown:
specifier: ^10.1.0
version: 10.1.0(@types/react@19.2.14)(react@19.2.4)
@@ -72,8 +72,8 @@ importers:
specifier: ^4.0.1
version: 4.0.1
shadcn:
- specifier: ^4.1.0
- version: 4.1.1(@types/node@25.5.0)(typescript@5.9.3)
+ specifier: ^4.1.2
+ version: 4.1.2(@types/node@25.5.0)(typescript@5.9.3)
sonner:
specifier: ^2.0.7
version: 2.0.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
@@ -127,8 +127,8 @@ importers:
specifier: ^7.0.1
version: 7.0.1(eslint@10.1.0(jiti@2.6.1))
eslint-plugin-react-refresh:
- specifier: ^0.4.26
- version: 0.4.26(eslint@10.1.0(jiti@2.6.1))
+ specifier: ^0.5.2
+ version: 0.5.2(eslint@10.1.0(jiti@2.6.1))
globals:
specifier: ^17.4.0
version: 17.4.0
@@ -287,9 +287,9 @@ packages:
resolution: {integrity: sha512-Qg+meC+XFxliuVSDlEPkKnaUjdaJKK6FNx/Wwl2UxhQR8pyPIuLhMavsF7ePdB9qFZUWV1jEK3ckbJir/WmF4w==}
hasBin: true
- '@ecies/ciphers@0.2.5':
- resolution: {integrity: sha512-GalEZH4JgOMHYYcYmVqnFirFsjZHeoGMDt9IxEnM9F7GRUUyUksJ7Ou53L83WHJq3RWKD3AcBpo0iQh0oMpf8A==}
- engines: {bun: '>=1', deno: '>=2', node: '>=16'}
+ '@ecies/ciphers@0.2.6':
+ resolution: {integrity: sha512-patgsRPKGkhhoBjETV4XxD0En4ui5fbX0hzayqI3M8tvNMGUoUvmyYAIWwlxBc1KX5cturfqByYdj5bYGRpN9g==}
+ engines: {bun: '>=1', deno: '>=2.7.10', node: '>=16'}
peerDependencies:
'@noble/ciphers': ^1.0.0
@@ -515,8 +515,8 @@ packages:
'@fontsource-variable/inter@5.2.8':
resolution: {integrity: sha512-kOfP2D+ykbcX/P3IFnokOhVRNoTozo5/JxhAIVYLpea/UBmCQ/YWPBfWIDuBImXX/15KH+eKh4xpEUyS2sQQGQ==}
- '@hono/node-server@1.19.11':
- resolution: {integrity: sha512-dr8/3zEaB+p0D2n/IUrlPF1HZm586qgJNXK1a9fhg/PzdtkK7Ksd5l312tJX2yBuALqDYBlG20QEbayqPyxn+g==}
+ '@hono/node-server@1.19.12':
+ resolution: {integrity: sha512-txsUW4SQ1iilgE0l9/e9VQWmELXifEFvmdA1j6WFh/aFPj99hIntrSsq/if0UWyGVkmrRPKA1wCeP+UCr1B9Uw==}
engines: {node: '>=18.14.1'}
peerDependencies:
hono: ^4
@@ -588,8 +588,8 @@ packages:
'@jridgewell/trace-mapping@0.3.31':
resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==}
- '@modelcontextprotocol/sdk@1.28.0':
- resolution: {integrity: sha512-gmloF+i+flI8ouQK7MWW4mOwuMh4RePBuPFAEPC6+pdqyWOUMDOixb6qZ69owLJpz6XmyllCouc4t8YWO+E2Nw==}
+ '@modelcontextprotocol/sdk@1.29.0':
+ resolution: {integrity: sha512-zo37mZA9hJWpULgkRpowewez1y6ML5GsXJPY8FI0tBBCd77HEvza4jDqRKOXgHNn867PVGCyTdzqpz0izu5ZjQ==}
engines: {node: '>=18'}
peerDependencies:
'@cfworker/json-schema': ^4.1.1
@@ -1543,11 +1543,11 @@ packages:
resolution: {integrity: sha512-NaOGLRrddszbQj9upGat6HG/4TKvXLvu+osAIgfxPYA+eIvYKv8GKDJOrY2D3/U9MRnKfMWD7bU4jeD4xmqyIg==}
engines: {node: '>=20.19'}
- '@tanstack/query-core@5.95.2':
- resolution: {integrity: sha512-o4T8vZHZET4Bib3jZ/tCW9/7080urD4c+0/AUaYVpIqOsr7y0reBc1oX3ttNaSW5mYyvZHctiQ/UOP2PfdmFEQ==}
+ '@tanstack/query-core@5.96.1':
+ resolution: {integrity: sha512-u1yBgtavSy+N8wgtW3PiER6UpxcplMje65yXnnVgiHTqiMwLlxiw4WvQDrXyn+UD6lnn8kHaxmerJUzQcV/MMg==}
- '@tanstack/react-query@5.95.2':
- resolution: {integrity: sha512-/wGkvLj/st5Ud1Q76KF1uFxScV7WeqN1slQx5280ycwAyYkIPGaRZAEgHxe3bjirSd5Zpwkj6zNcR4cqYni/ZA==}
+ '@tanstack/react-query@5.96.1':
+ resolution: {integrity: sha512-2X7KYK5KKWUKGeWCVcqxXAkYefJtrKB7tSKWgeG++b0H6BRHxQaLSSi8AxcgjmUnnosHuh9WsFZqvE16P1WCzA==}
peerDependencies:
react: ^18 || ^19
@@ -1856,8 +1856,8 @@ packages:
resolution: {integrity: sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==}
engines: {node: 18 || 20 || >=22}
- baseline-browser-mapping@2.10.12:
- resolution: {integrity: sha512-qyq26DxfY4awP2gIRXhhLWfwzwI+N5Nxk6iQi8EFizIaWIjqicQTE4sLnZZVdeKPRcVNoJOkkpfzoIYuvCKaIQ==}
+ baseline-browser-mapping@2.10.13:
+ resolution: {integrity: sha512-BL2sTuHOdy0YT1lYieUxTw/QMtPBC3pmlJC6xk8BBYVv6vcw3SGdKemQ+Xsx9ik2F/lYDO9tqsFQH1r9PFuHKw==}
engines: {node: '>=6.0.0'}
hasBin: true
@@ -1880,8 +1880,8 @@ packages:
resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==}
engines: {node: '>=8'}
- browserslist@4.28.1:
- resolution: {integrity: sha512-ZC5Bd0LgJXgwGqUknZY/vkUQ04r8NXnJZ3yYi4vDmSiZmC/pdSN0NbNRPxZpbtO4uAfDUAFffO8IZoM3Gj8IkA==}
+ browserslist@4.28.2:
+ resolution: {integrity: sha512-48xSriZYYg+8qXna9kwqjIVzuQxi+KYWp2+5nCYnYKPTr0LvD89Jqk2Or5ogxz0NUMfIjhh2lIUX/LyX9B4oIg==}
engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7}
hasBin: true
@@ -1905,8 +1905,8 @@ packages:
resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==}
engines: {node: '>=6'}
- caniuse-lite@1.0.30001782:
- resolution: {integrity: sha512-dZcaJLJeDMh4rELYFw1tvSn1bhZWYFOt468FcbHHxx/Z/dFidd1I6ciyFdi3iwfQCyOjqo9upF6lGQYtMiJWxw==}
+ caniuse-lite@1.0.30001784:
+ resolution: {integrity: sha512-WU346nBTklUV9YfUl60fqRbU5ZqyXlqvo1SgigE1OAXK5bFL8LL9q1K7aap3N739l4BvNqnkm3YrGHiY9sfUQw==}
ccount@2.0.1:
resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==}
@@ -2094,8 +2094,8 @@ packages:
resolution: {integrity: sha512-DPi0FmjiSU5EvQV0++GFDOJ9ASQUVFh5kD+OzOnYdi7n3Wpm9hWWGfB/O2blfHcMVTL5WkQXSnRiK9makhrcnw==}
engines: {node: '>=0.3.1'}
- dotenv@17.3.1:
- resolution: {integrity: sha512-IO8C/dzEb6O3F9/twg6ZLXz164a2fhTnEWb95H23Dm4OuN+92NmEAlTrupP9VW6Jm3sO26tQlqyvyi4CsnY9GA==}
+ dotenv@17.4.0:
+ resolution: {integrity: sha512-kCKF62fwtzwYm0IGBNjRUjtJgMfGapII+FslMHIjMR5KTnwEmBmWLDRSnc3XSNP8bNy34tekgQyDT0hr7pERRQ==}
engines: {node: '>=12'}
dunder-proto@1.0.1:
@@ -2109,8 +2109,8 @@ packages:
ee-first@1.1.1:
resolution: {integrity: sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==}
- electron-to-chromium@1.5.328:
- resolution: {integrity: sha512-QNQ5l45DzYytThO21403XN3FvK0hOkWDG8viNf6jqS42msJ8I4tGDSpBCgvDRRPnkffafiwAym2X2eHeGD2V0w==}
+ electron-to-chromium@1.5.331:
+ resolution: {integrity: sha512-IbxXrsTlD3hRodkLnbxAPP4OuJYdWCeM3IOdT+CpcMoIwIoDfCmRpEtSPfwBXxVkg9xmBeY7Lz2Eo2TDn/HC3Q==}
emoji-regex@10.6.0:
resolution: {integrity: sha512-toUI84YS5YmxW219erniWD0CIVOo46xGKColeNQRgOzDorgBi1v4D71/OFzgD9GO2UGKIv1C3Sp8DAn0+j5w7A==}
@@ -2181,10 +2181,10 @@ packages:
peerDependencies:
eslint: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0
- eslint-plugin-react-refresh@0.4.26:
- resolution: {integrity: sha512-1RETEylht2O6FM/MvgnyvT+8K21wLqDNg4qD51Zj3guhjt433XbnnkVttHMyaVyAFD03QSV4LPS5iE3VQmO7XQ==}
+ eslint-plugin-react-refresh@0.5.2:
+ resolution: {integrity: sha512-hmgTH57GfzoTFjVN0yBwTggnsVUF2tcqi7RJZHqi9lIezSs4eFyAMktA68YD4r5kNw1mxyY4dmkyoFDb3FIqrA==}
peerDependencies:
- eslint: '>=8.40'
+ eslint: ^9 || ^10
eslint-scope@9.1.2:
resolution: {integrity: sha512-xS90H51cKw0jltxmvmHy2Iai1LIqrfbw57b79w/J7MfvDfkIkFZ+kj6zC3BjtUwh150HsSSdxXZcsuv72miDFQ==}
@@ -2256,8 +2256,8 @@ packages:
resolution: {integrity: sha512-9Be3ZoN4LmYR90tUoVu2te2BsbzHfhJyfEiAVfz7N5/zv+jduIfLrV2xdQXOHbaD6KgpGdO9PRPM1Y4Q9QkPkA==}
engines: {node: ^18.19.0 || >=20.5.0}
- express-rate-limit@8.3.1:
- resolution: {integrity: sha512-D1dKN+cmyPWuvB+G2SREQDzPY1agpBIcTa9sJxOPMCNeH3gwzhqJRDWCXW3gg0y//+LQ/8j52JbMROWyrKdMdw==}
+ express-rate-limit@8.3.2:
+ resolution: {integrity: sha512-77VmFeJkO0/rvimEDuUC5H30oqUC4EyOhyGccfqoLebB0oiEYfM7nwPrsDsBL1gsTpwfzX8SFy2MT3TDyRq+bg==}
engines: {node: '>= 16'}
peerDependencies:
express: '>= 4.11'
@@ -2463,8 +2463,8 @@ packages:
hermes-parser@0.25.1:
resolution: {integrity: sha512-6pEjquH3rqaI6cYAXYPcz9MS4rY6R4ngRgrgfDshRptUZIc3lw0MCIJIGDj9++mfySOuPTHB4nrSW99BCvOPIA==}
- hono@4.12.9:
- resolution: {integrity: sha512-wy3T8Zm2bsEvxKZM5w21VdHDDcwVS1yUFFY6i8UobSsKfFceT7TOwhbhfKsDyx7tYQlmRM5FLpIuYvNFyjctiA==}
+ hono@4.12.10:
+ resolution: {integrity: sha512-mx/p18PLy5og9ufies2GOSUqep98Td9q4i/EF6X7yJgAiIopxqdfIO3jbqsi3jRgTgw88jMDEzVKi+V2EF+27w==}
engines: {node: '>=16.9.0'}
html-parse-stringify@3.0.1:
@@ -2495,8 +2495,8 @@ packages:
i18next-browser-languagedetector@8.2.1:
resolution: {integrity: sha512-bZg8+4bdmaOiApD7N7BPT9W8MLZG+nPTOFlLiJiT8uzKXFjhxw4v2ierCXOwB5sFDMtuA5G4kgYZ0AznZxQ/cw==}
- i18next@26.0.1:
- resolution: {integrity: sha512-vtz5sXU4+nkCm8yEU+JJ6yYIx0mkg9e68W0G0PXpnOsmzLajNsW5o28DJMqbajxfsfq0gV3XdrBudsDQnwxfsQ==}
+ i18next@26.0.3:
+ resolution: {integrity: sha512-1571kXINxHKY7LksWp8wP+zP0YqHSSpl/OW0Y0owFEf2H3s8gCAffWaZivcz14rMkOvn3R/psiQxVsR9t2Nafg==}
peerDependencies:
typescript: ^5 || ^6
peerDependenciesMeta:
@@ -2988,6 +2988,10 @@ packages:
resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==}
engines: {node: 18 || 20 || >=22}
+ minimatch@10.2.5:
+ resolution: {integrity: sha512-MULkVLfKGYDFYejP07QOurDLLQpcjk7Fw+7jXS2R2czRQzR56yHRveU5NDJEOviH+hETZKSkIk5c+T23GjFUMg==}
+ engines: {node: 18 || 20 || >=22}
+
minimatch@9.0.9:
resolution: {integrity: sha512-OBwBN9AL4dqmETlpS2zasx+vTeWclWzkblfZk7KTA5j3jeOONz/tRCnZomUyvNg83wL5Zv9Ss6HMJXAgL8R2Yg==}
engines: {node: '>=16 || 14 >=14.17'}
@@ -3033,8 +3037,8 @@ packages:
resolution: {integrity: sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA==}
engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0}
- node-releases@2.0.36:
- resolution: {integrity: sha512-TdC8FSgHz8Mwtw9g5L4gR/Sh9XhSP/0DEkQxfEFXOpiul5IiHgHan2VhYYb6agDSfp4KuvltmGApc8HMgUrIkA==}
+ node-releases@2.0.37:
+ resolution: {integrity: sha512-1h5gKZCF+pO/o3Iqt5Jp7wc9rH3eJJ0+nh/CIoiRwjRxde/hAHyLPXYN4V3CqKAbiZPSeJFSWHmJsbkicta0Eg==}
normalize-path@3.0.0:
resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==}
@@ -3144,8 +3148,8 @@ packages:
path-to-regexp@6.3.0:
resolution: {integrity: sha512-Yhpw4T9C6hPpgPeA28us07OJeqZ5EzQTkbfwuhsUg0c237RomFoETJgmp2sa3F/41gfLE6G5cqcYwznmeEeOlQ==}
- path-to-regexp@8.4.0:
- resolution: {integrity: sha512-PuseHIvAnz3bjrM2rGJtSgo1zjgxapTLZ7x2pjhzWwlp4SJQgK3f3iZIQwkpEnBaKz6seKBADpM4B4ySkuYypg==}
+ path-to-regexp@8.4.2:
+ resolution: {integrity: sha512-qRcuIdP69NPm4qbACK+aDogI5CBDMi1jKe0ry5rSQJz8JVLsC7jV8XpiJjGRLLol3N+R5ihGYcrPLTno6pAdBA==}
pathe@2.0.3:
resolution: {integrity: sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==}
@@ -3297,10 +3301,10 @@ packages:
peerDependencies:
react: ^19.2.4
- react-i18next@16.6.6:
- resolution: {integrity: sha512-ZgL2HUoW34UKUkOV7uSQFE1CDnRPD+tCR3ywSuWH7u2iapnz86U8Bi3Vrs620qNDzCf1F47NxglCEkchCTDOHw==}
+ react-i18next@17.0.2:
+ resolution: {integrity: sha512-shBftH2vaTWK2Bsp7FiL+cevx3xFJlvFxmsDFQSrJc+6twHkP0tv/bGa01VVWzpreUVVwU+3Hev5iFqRg65RwA==}
peerDependencies:
- i18next: '>= 25.10.9'
+ i18next: '>= 26.0.1'
react: '>= 16.8.0'
react-dom: '*'
react-native: '*'
@@ -3463,8 +3467,8 @@ packages:
setprototypeof@1.2.0:
resolution: {integrity: sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==}
- shadcn@4.1.1:
- resolution: {integrity: sha512-nBj+7LYC9kzV9v9QmRPpoOhfW4KctJVQejywdAt/K+K+z4RYlJOcO2a4AaF7elrRWkfCbgXeGK02liV0KB9HvQ==}
+ shadcn@4.1.2:
+ resolution: {integrity: sha512-qNQcCavkbYsgBj+X09tF2bTcwRd8abR880bsFkDU2kMqceMCLAm5c+cLg7kWDhfh1H9g08knpQ5ZEf6y/co16g==}
hasBin: true
shebang-command@2.0.0:
@@ -3976,7 +3980,7 @@ snapshots:
dependencies:
'@babel/compat-data': 7.29.0
'@babel/helper-validator-option': 7.27.1
- browserslist: 4.28.1
+ browserslist: 4.28.2
lru-cache: 5.1.1
semver: 6.3.1
@@ -4123,7 +4127,7 @@ snapshots:
'@dotenvx/dotenvx@1.59.1':
dependencies:
commander: 11.1.0
- dotenv: 17.3.1
+ dotenv: 17.4.0
eciesjs: 0.4.18
execa: 5.1.1
fdir: 6.5.0(picomatch@4.0.4)
@@ -4132,7 +4136,7 @@ snapshots:
picomatch: 4.0.4
which: 4.0.0
- '@ecies/ciphers@0.2.5(@noble/ciphers@1.3.0)':
+ '@ecies/ciphers@0.2.6(@noble/ciphers@1.3.0)':
dependencies:
'@noble/ciphers': 1.3.0
@@ -4283,9 +4287,9 @@ snapshots:
'@fontsource-variable/inter@5.2.8': {}
- '@hono/node-server@1.19.11(hono@4.12.9)':
+ '@hono/node-server@1.19.12(hono@4.12.10)':
dependencies:
- hono: 4.12.9
+ hono: 4.12.10
'@humanfs/core@0.19.1': {}
@@ -4345,9 +4349,9 @@ snapshots:
'@jridgewell/resolve-uri': 3.1.2
'@jridgewell/sourcemap-codec': 1.5.5
- '@modelcontextprotocol/sdk@1.28.0(zod@3.25.76)':
+ '@modelcontextprotocol/sdk@1.29.0(zod@3.25.76)':
dependencies:
- '@hono/node-server': 1.19.11(hono@4.12.9)
+ '@hono/node-server': 1.19.12(hono@4.12.10)
ajv: 8.18.0
ajv-formats: 3.0.1(ajv@8.18.0)
content-type: 1.0.5
@@ -4356,8 +4360,8 @@ snapshots:
eventsource: 3.0.7
eventsource-parser: 3.0.6
express: 5.2.1
- express-rate-limit: 8.3.1(express@5.2.1)
- hono: 4.12.9
+ express-rate-limit: 8.3.2(express@5.2.1)
+ hono: 4.12.10
jose: 6.2.2
json-schema-typed: 8.0.2
pkce-challenge: 5.0.1
@@ -5301,11 +5305,11 @@ snapshots:
'@tanstack/history@1.161.6': {}
- '@tanstack/query-core@5.95.2': {}
+ '@tanstack/query-core@5.96.1': {}
- '@tanstack/react-query@5.95.2(react@19.2.4)':
+ '@tanstack/react-query@5.96.1(react@19.2.4)':
dependencies:
- '@tanstack/query-core': 5.95.2
+ '@tanstack/query-core': 5.96.1
react: 19.2.4
'@tanstack/react-router-devtools@1.166.11(@tanstack/react-router@1.168.8(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.168.7)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
@@ -5419,7 +5423,7 @@ snapshots:
'@ts-morph/common@0.27.0':
dependencies:
fast-glob: 3.3.3
- minimatch: 10.2.4
+ minimatch: 10.2.5
path-browserify: 1.0.1
'@tybys/wasm-util@0.10.1':
@@ -5642,7 +5646,7 @@ snapshots:
balanced-match@4.0.4: {}
- baseline-browser-mapping@2.10.12: {}
+ baseline-browser-mapping@2.10.13: {}
binary-extensions@2.3.0: {}
@@ -5672,13 +5676,13 @@ snapshots:
dependencies:
fill-range: 7.1.1
- browserslist@4.28.1:
+ browserslist@4.28.2:
dependencies:
- baseline-browser-mapping: 2.10.12
- caniuse-lite: 1.0.30001782
- electron-to-chromium: 1.5.328
- node-releases: 2.0.36
- update-browserslist-db: 1.2.3(browserslist@4.28.1)
+ baseline-browser-mapping: 2.10.13
+ caniuse-lite: 1.0.30001784
+ electron-to-chromium: 1.5.331
+ node-releases: 2.0.37
+ update-browserslist-db: 1.2.3(browserslist@4.28.2)
bundle-name@4.1.0:
dependencies:
@@ -5698,7 +5702,7 @@ snapshots:
callsites@3.1.0: {}
- caniuse-lite@1.0.30001782: {}
+ caniuse-lite@1.0.30001784: {}
ccount@2.0.1: {}
@@ -5837,7 +5841,7 @@ snapshots:
diff@8.0.4: {}
- dotenv@17.3.1: {}
+ dotenv@17.4.0: {}
dunder-proto@1.0.1:
dependencies:
@@ -5847,14 +5851,14 @@ snapshots:
eciesjs@0.4.18:
dependencies:
- '@ecies/ciphers': 0.2.5(@noble/ciphers@1.3.0)
+ '@ecies/ciphers': 0.2.6(@noble/ciphers@1.3.0)
'@noble/ciphers': 1.3.0
'@noble/curves': 1.9.7
'@noble/hashes': 1.8.0
ee-first@1.1.1: {}
- electron-to-chromium@1.5.328: {}
+ electron-to-chromium@1.5.331: {}
emoji-regex@10.6.0: {}
@@ -5935,7 +5939,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
- eslint-plugin-react-refresh@0.4.26(eslint@10.1.0(jiti@2.6.1)):
+ eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@2.6.1)):
dependencies:
eslint: 10.1.0(jiti@2.6.1)
@@ -6044,7 +6048,7 @@ snapshots:
strip-final-newline: 4.0.0
yoctocolors: 2.1.2
- express-rate-limit@8.3.1(express@5.2.1):
+ express-rate-limit@8.3.2(express@5.2.1):
dependencies:
express: 5.2.1
ip-address: 10.1.0
@@ -6321,7 +6325,7 @@ snapshots:
dependencies:
hermes-estree: 0.25.1
- hono@4.12.9: {}
+ hono@4.12.10: {}
html-parse-stringify@3.0.1:
dependencies:
@@ -6354,7 +6358,7 @@ snapshots:
dependencies:
'@babel/runtime': 7.29.2
- i18next@26.0.1(typescript@5.9.3):
+ i18next@26.0.3(typescript@5.9.3):
dependencies:
'@babel/runtime': 7.29.2
optionalDependencies:
@@ -6949,6 +6953,10 @@ snapshots:
dependencies:
brace-expansion: 5.0.5
+ minimatch@10.2.5:
+ dependencies:
+ brace-expansion: 5.0.5
+
minimatch@9.0.9:
dependencies:
brace-expansion: 2.0.3
@@ -6998,7 +7006,7 @@ snapshots:
fetch-blob: 3.2.0
formdata-polyfill: 4.0.10
- node-releases@2.0.36: {}
+ node-releases@2.0.37: {}
normalize-path@3.0.0: {}
@@ -7118,7 +7126,7 @@ snapshots:
path-to-regexp@6.3.0: {}
- path-to-regexp@8.4.0: {}
+ path-to-regexp@8.4.2: {}
pathe@2.0.3: {}
@@ -7259,11 +7267,11 @@ snapshots:
react: 19.2.4
scheduler: 0.27.0
- react-i18next@16.6.6(i18next@26.0.1(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
+ react-i18next@17.0.2(i18next@26.0.3(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
dependencies:
'@babel/runtime': 7.29.2
html-parse-stringify: 3.0.1
- i18next: 26.0.1(typescript@5.9.3)
+ i18next: 26.0.3(typescript@5.9.3)
react: 19.2.4
use-sync-external-store: 1.6.0(react@19.2.4)
optionalDependencies:
@@ -7430,7 +7438,7 @@ snapshots:
depd: 2.0.0
is-promise: 4.0.0
parseurl: 1.3.3
- path-to-regexp: 8.4.0
+ path-to-regexp: 8.4.2
transitivePeerDependencies:
- supports-color
@@ -7481,16 +7489,16 @@ snapshots:
setprototypeof@1.2.0: {}
- shadcn@4.1.1(@types/node@25.5.0)(typescript@5.9.3):
+ shadcn@4.1.2(@types/node@25.5.0)(typescript@5.9.3):
dependencies:
'@babel/core': 7.29.0
'@babel/parser': 7.29.2
'@babel/plugin-transform-typescript': 7.28.6(@babel/core@7.29.0)
'@babel/preset-typescript': 7.28.5(@babel/core@7.29.0)
'@dotenvx/dotenvx': 1.59.1
- '@modelcontextprotocol/sdk': 1.28.0(zod@3.25.76)
+ '@modelcontextprotocol/sdk': 1.29.0(zod@3.25.76)
'@types/validate-npm-package-name': 4.0.2
- browserslist: 4.28.1
+ browserslist: 4.28.2
commander: 14.0.3
cosmiconfig: 9.0.1(typescript@5.9.3)
dedent: 1.7.2
@@ -7771,9 +7779,9 @@ snapshots:
until-async@3.0.2: {}
- update-browserslist-db@1.2.3(browserslist@4.28.1):
+ update-browserslist-db@1.2.3(browserslist@4.28.2):
dependencies:
- browserslist: 4.28.1
+ browserslist: 4.28.2
escalade: 3.2.0
picocolors: 1.1.1
diff --git a/web/frontend/src/api/channels.ts b/web/frontend/src/api/channels.ts
index eb4d41fd7..42a3a0606 100644
--- a/web/frontend/src/api/channels.ts
+++ b/web/frontend/src/api/channels.ts
@@ -1,5 +1,3 @@
-// API client for channels navigation and channel-specific config flows.
-
import { launcherFetch } from "@/api/http"
export type ChannelConfig = Record
@@ -12,6 +10,13 @@ export interface SupportedChannel {
variant?: string
}
+export interface ChannelConfigResponse {
+ config: ChannelConfig
+ configured_secrets: string[]
+ config_key: string
+ variant?: string
+}
+
interface ChannelsCatalogResponse {
channels: SupportedChannel[]
}
@@ -54,6 +59,14 @@ export async function getAppConfig(): Promise {
return request("/api/config")
}
+export async function getChannelConfig(
+ channelName: string,
+): Promise {
+ return request(
+ `/api/channels/${encodeURIComponent(channelName)}/config`,
+ )
+}
+
export async function patchAppConfig(
patch: Record,
): Promise {
diff --git a/web/frontend/src/api/launcher-auth.ts b/web/frontend/src/api/launcher-auth.ts
index 247d5ab9e..4ca51993b 100644
--- a/web/frontend/src/api/launcher-auth.ts
+++ b/web/frontend/src/api/launcher-auth.ts
@@ -17,6 +17,7 @@ export async function postLauncherDashboardLogin(
export type LauncherAuthTokenHelp = {
env_var_name: string
log_file?: string
+ config_file?: string
tray_copy_menu: boolean
console_stdout: boolean
}
diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts
index eb8d287dd..bfdd80d6d 100644
--- a/web/frontend/src/api/models.ts
+++ b/web/frontend/src/api/models.ts
@@ -19,6 +19,7 @@ export interface ModelInfo {
request_timeout?: number
thinking_level?: string
extra_body?: Record
+ custom_headers?: Record
// Meta
available: boolean
status: "available" | "unconfigured" | "unreachable"
diff --git a/web/frontend/src/api/sessions.ts b/web/frontend/src/api/sessions.ts
index c91495901..dd0fa1f53 100644
--- a/web/frontend/src/api/sessions.ts
+++ b/web/frontend/src/api/sessions.ts
@@ -1,5 +1,3 @@
-// Sessions API — list and retrieve chat session history
-
import { launcherFetch } from "@/api/http"
export interface SessionSummary {
@@ -13,7 +11,11 @@ export interface SessionSummary {
export interface SessionDetail {
id: string
- messages: { role: "user" | "assistant"; content: string }[]
+ messages: {
+ role: "user" | "assistant"
+ content: string
+ media?: string[]
+ }[]
summary: string
created: string
updated: string
diff --git a/web/frontend/src/api/skills.ts b/web/frontend/src/api/skills.ts
index 72ccbcfe5..958808afd 100644
--- a/web/frontend/src/api/skills.ts
+++ b/web/frontend/src/api/skills.ts
@@ -5,22 +5,60 @@ export interface SkillSupportItem {
path: string
source: "workspace" | "global" | "builtin" | string
description: string
+ origin_kind: "builtin" | "third_party" | "manual" | string
+ registry_name?: string
+ registry_url?: string
+ installed_version?: string
+ installed_at?: number
}
export interface SkillDetailResponse extends SkillSupportItem {
content: string
}
+export interface SkillRegistrySearchResult {
+ score: number
+ slug: string
+ display_name: string
+ summary: string
+ version: string
+ registry_name: string
+ url?: string
+ installed: boolean
+ installed_name?: string
+}
+
interface SkillsResponse {
skills: SkillSupportItem[]
}
-interface SkillActionResponse {
+export interface SkillSearchResponse {
+ results: SkillRegistrySearchResult[]
+ limit: number
+ offset: number
+ next_offset?: number
+ has_more: boolean
+}
+
+type SkillActionResponse = Partial & {
status?: string
- name?: string
- path?: string
- source?: string
- description?: string
+}
+
+export interface InstallSkillRequest {
+ slug: string
+ registry: string
+ version?: string
+ force?: boolean
+}
+
+export interface InstallSkillResponse {
+ status: string
+ slug: string
+ registry: string
+ version: string
+ summary?: string
+ is_suspicious?: boolean
+ skill?: SkillSupportItem
}
async function request(path: string, options?: RequestInit): Promise {
@@ -39,6 +77,29 @@ export async function getSkill(name: string): Promise {
return request(`/api/skills/${encodeURIComponent(name)}`)
}
+export async function searchSkills(
+ query: string,
+ limit = 20,
+ offset = 0,
+): Promise {
+ const params = new URLSearchParams({
+ q: query,
+ limit: String(limit),
+ offset: String(offset),
+ })
+ return request(`/api/skills/search?${params.toString()}`)
+}
+
+export async function installSkill(
+ input: InstallSkillRequest,
+): Promise {
+ return request("/api/skills/install", {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify(input),
+ })
+}
+
export async function importSkill(file: File): Promise {
const formData = new FormData()
formData.set("file", file)
@@ -64,15 +125,23 @@ export async function deleteSkill(name: string): Promise {
async function extractErrorMessage(res: Response): Promise {
try {
- const body = (await res.json()) as {
- error?: string
- errors?: string[]
+ const raw = await res.text()
+ if (raw.trim() === "") {
+ return `API error: ${res.status} ${res.statusText}`
}
- if (Array.isArray(body.errors) && body.errors.length > 0) {
- return body.errors.join("; ")
- }
- if (typeof body.error === "string" && body.error.trim() !== "") {
- return body.error
+ try {
+ const body = JSON.parse(raw) as {
+ error?: string
+ errors?: string[]
+ }
+ if (Array.isArray(body.errors) && body.errors.length > 0) {
+ return body.errors.join("; ")
+ }
+ if (typeof body.error === "string" && body.error.trim() !== "") {
+ return body.error
+ }
+ } catch {
+ return raw.trim()
}
} catch {
// ignore invalid body
diff --git a/web/frontend/src/api/system.ts b/web/frontend/src/api/system.ts
index dfc48b6b8..8623c7e78 100644
--- a/web/frontend/src/api/system.ts
+++ b/web/frontend/src/api/system.ts
@@ -11,6 +11,7 @@ export interface LauncherConfig {
port: number
public: boolean
allowed_cidrs: string[]
+ launcher_token: string
}
export interface SystemVersionInfo {
diff --git a/web/frontend/src/components/agent/hub/hub-page.tsx b/web/frontend/src/components/agent/hub/hub-page.tsx
new file mode 100644
index 000000000..69f0be638
--- /dev/null
+++ b/web/frontend/src/components/agent/hub/hub-page.tsx
@@ -0,0 +1,51 @@
+import { useTranslation } from "react-i18next"
+
+import { PageHeader } from "@/components/page-header"
+
+import { ResultsPanel } from "./results-panel"
+import { SearchPanel } from "./search-panel"
+import { useHubMarketplace } from "./use-hub-marketplace"
+
+export function HubPage() {
+ const { t } = useTranslation()
+ const hub = useHubMarketplace()
+
+ return (
+