diff --git a/.env.example b/.env.example index 06d43070c..bc68456d6 100644 --- a/.env.example +++ b/.env.example @@ -17,4 +17,4 @@ # BRAVE_SEARCH_API_KEY=BSA... # ── Timezone ────────────────────────────── -TZ=Asia/Tokyo +TZ=Asia/Shanghai diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index be1c10c52..1e9a7919a 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -24,6 +24,25 @@ jobs: with: version: v2.10.1 + vuln_check: + name: Security Check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run Govulncheck + uses: golang/govulncheck-action@v1 + with: + go-package: ./... + test: name: Tests runs-on: ubuntu-latest diff --git a/LICENSE b/LICENSE index 410acae26..b38d9340d 100644 --- a/LICENSE +++ b/LICENSE @@ -19,7 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - ---- - -PicoClaw is heavily inspired by and based on [nanobot](https://github.com/HKUDS/nanobot) by HKUDS. diff --git a/README.fr.md b/README.fr.md index e537fc13a..320aa9e22 100644 --- a/README.fr.md +++ b/README.fr.md @@ -827,7 +827,7 @@ Le sous-agent a accès aux outils (message, web_search, etc.) et peut communique ### Fournisseurs > [!NOTE] -> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages vocaux Telegram seront automatiquement transcrits. +> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages audio de n'importe quel canal seront automatiquement transcrits au niveau de l'agent. 
| Fournisseur | Utilisation | Obtenir une Clé API | | ------------------------ | ---------------------------------------- | ------------------------------------------------------ | diff --git a/README.ja.md b/README.ja.md index 20ad5033b..ea6bc7e72 100644 --- a/README.ja.md +++ b/README.ja.md @@ -785,7 +785,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る ### プロバイダー > [!NOTE] -> Groq は Whisper による無料の音声文字起こしを提供しています。設定すると、Telegram の音声メッセージが自動的に文字起こしされます。 +> Groq は Whisper による無料の音声文字起こしを提供しています。設定すると、あらゆるチャンネルからの音声メッセージがエージェントレベルで自動的に文字起こしされます。 | プロバイダー | 用途 | API キー取得先 | | --- | --- | --- | diff --git a/README.md b/README.md index 6714ac6eb..7a31f9364 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ ## 📢 News -2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we can’t wait to have you on board! +2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](ROADMAP.md) —we can’t wait to have you on board! 2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. 🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. @@ -216,7 +216,7 @@ docker compose -f docker/docker-compose.yml --profile gateway up -d > [!TIP] > Set your API key in `~/.picoclaw/config.json`. 
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) -> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. +> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. **1. Initialize** @@ -265,6 +265,16 @@ picoclaw onboard "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://your-searxng-instance:8888", + "max_results": 5 } } } @@ -277,7 +287,12 @@ picoclaw onboard **3. Get API Keys** * **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) -* **Web Search** (optional): [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) · [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month) +* **Web Search** (optional): + * [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month) + * [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface + * [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed) + * [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month) + * DuckDuckGo - Built-in fallback (no API key required) > **Note**: See `config.example.json` for a complete configuration template. 
@@ -721,6 +736,20 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa └── USER.md # User preferences ``` +### Skill Sources + +By default, skills are loaded from: + +1. `~/.picoclaw/workspace/skills` (workspace) +2. `~/.picoclaw/skills` (global) +3. `/skills` (builtin) + +For advanced/test setups, you can override the builtin skills root with: + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. @@ -897,7 +926,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate ### Providers > [!NOTE] -> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. +> Groq provides free voice transcription via Whisper. If configured, audio messages from any channel will be automatically transcribed at the agent level. | Provider | Purpose | Get API Key | | -------------------------- | --------------------------------------- | -------------------------------------------------------------------- | @@ -1227,6 +1256,16 @@ picoclaw agent -m "Hello" "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "", + "max_results": 5 + }, + "searxng": { + "enabled": false, + "base_url": "http://localhost:8888", + "max_results": 5 } }, "cron": { @@ -1284,10 +1323,69 @@ discord: This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching. -To enable web search: +#### Search Provider Priority -1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results. -2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required). 
+PicoClaw automatically selects the best available search provider in this order: +1. **Perplexity** (if enabled and API key configured) - AI-powered search with citations +2. **Brave Search** (if enabled and API key configured) - Privacy-focused paid API ($5/1000 queries) +3. **SearXNG** (if enabled and base_url configured) - Self-hosted metasearch aggregating 70+ engines (free) +4. **DuckDuckGo** (if enabled, default fallback) - No API key required (free) + +#### Web Search Configuration Options + +**Option 1 (Best Results)**: Perplexity AI Search +```json +{ + "tools": { + "web": { + "perplexity": { + "enabled": true, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + } + } + } +} +``` + +**Option 2 (Paid API)**: Get an API key at [https://brave.com/search/api](https://brave.com/search/api) ($5/1000 queries, ~$5-6/month) +```json +{ + "tools": { + "web": { + "brave": { + "enabled": true, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} +``` + +**Option 3 (Self-Hosted)**: Deploy your own [SearXNG](https://github.com/searxng/searxng) instance +```json +{ + "tools": { + "web": { + "searxng": { + "enabled": true, + "base_url": "http://your-server:8888", + "max_results": 5 + } + } + } +} +``` + +Benefits of SearXNG: +- **Zero cost**: No API fees or rate limits +- **Privacy-focused**: Self-hosted, no tracking +- **Aggregate results**: Queries 70+ search engines simultaneously +- **Perfect for cloud VMs**: Solves datacenter IP blocking issues (Oracle Cloud, GCP, AWS, Azure) +- **No API key needed**: Just deploy and configure the base URL + +**Option 4 (No Setup Required)**: DuckDuckGo is enabled by default as fallback (no API key needed) Add the key to `~/.picoclaw/config.json` if using Brave: @@ -1303,6 +1401,16 @@ Add the key to `~/.picoclaw/config.json` if using Brave: "duckduckgo": { "enabled": true, "max_results": 5 + }, + "perplexity": { + "enabled": false, + "api_key": "YOUR_PERPLEXITY_API_KEY", + "max_results": 5 + }, + "searxng": 
{ + "enabled": false, + "base_url": "http://your-searxng-instance:8888", + "max_results": 5 } } } @@ -1321,10 +1429,11 @@ This happens when another instance of the bot is running. Make sure only one `pi ## 📝 API Key Comparison -| Service | Free Tier | Use Case | -| ---------------- | ------------------- | ------------------------------------- | -| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | -| **Zhipu** | 200K tokens/month | Best for Chinese users | -| **Brave Search** | 2000 queries/month | Web search functionality | -| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | -| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) | +| Service | Free Tier | Use Case | +| ---------------- | ------------------------ | ------------------------------------- | +| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/month | Best for Chinese users | +| **Brave Search** | Paid ($5/1000 queries) | Web search functionality | +| **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) | +| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) | diff --git a/README.pt-br.md b/README.pt-br.md index bfe655770..67ce9e0d3 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -823,7 +823,7 @@ O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se com ### Provedores > [!NOTE] -> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas. +> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de áudio de qualquer canal serão automaticamente transcritas no nível do agente. 
| Provedor | Finalidade | Obter API Key | | --- | --- | --- | diff --git a/README.vi.md b/README.vi.md index b30659614..5755896ed 100644 --- a/README.vi.md +++ b/README.vi.md @@ -795,7 +795,7 @@ Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và ### Nhà cung cấp (Providers) > [!NOTE] -> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản. +> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn âm thanh từ bất kỳ kênh nào sẽ được tự động chuyển thành văn bản ở cấp độ agent. | Nhà cung cấp | Mục đích | Lấy API Key | | --- | --- | --- | diff --git a/README.zh.md b/README.zh.md index d3a49ee8d..bd90173f9 100644 --- a/README.zh.md +++ b/README.zh.md @@ -362,6 +362,20 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work ``` +### 技能来源 (Skill Sources) + +默认情况下,技能会按以下顺序加载: + +1. `~/.picoclaw/workspace/skills`(工作区) +2. `~/.picoclaw/skills`(全局) +3. 
`/skills`(内置) + +在高级/测试场景下,可通过以下环境变量覆盖内置技能目录: + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### 心跳 / 周期性任务 (Heartbeat) PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: @@ -445,7 +459,7 @@ Agent 读取 HEARTBEAT.md ### 提供商 (Providers) > [!NOTE] -> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。 +> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,任意渠道的音频消息都将在 Agent 层面自动转录为文字。 | 提供商 | 用途 | 获取 API Key | | -------------------- | ---------------------------- | -------------------------------------------------------------------- | diff --git a/assets/wechat.png b/assets/wechat.png index 1c0b88295..32998c122 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw-launcher-tui/internal/ui/model.go b/cmd/picoclaw-launcher-tui/internal/ui/model.go index ba91f5b09..304b4efa7 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/model.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/model.go @@ -335,7 +335,11 @@ func (s *appState) testModel(model *picoclawconfig.ModelConfig) { s.showMessage("Test OK", resp.Status) return } - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 2048)) + if err != nil { + s.showMessage("Test failed", fmt.Sprintf("failed to read response: %v", err)) + return + } s.showMessage( "Test failed", fmt.Sprintf("%s: %s", resp.Status, strings.TrimSpace(string(body))), diff --git a/cmd/picoclaw-launcher/internal/server/auth_handlers.go b/cmd/picoclaw-launcher/internal/server/auth_handlers.go index 1e9b8be0a..3b48f9739 100644 --- a/cmd/picoclaw-launcher/internal/server/auth_handlers.go +++ b/cmd/picoclaw-launcher/internal/server/auth_handlers.go @@ -297,7 +297,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading userinfo response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", 
fmt.Errorf("userinfo request failed: %s", string(body)) } diff --git a/cmd/picoclaw/internal/auth/helpers.go b/cmd/picoclaw/internal/auth/helpers.go index 633ce8740..4dfbc92e7 100644 --- a/cmd/picoclaw/internal/auth/helpers.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -177,7 +177,10 @@ func fetchGoogleUserEmail(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading userinfo response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("userinfo request failed: %s", string(body)) } diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index 747f7d44e..174f5db62 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -36,6 +36,7 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/voice" ) func gatewayCmd(debug bool) error { @@ -134,6 +135,12 @@ func gatewayCmd(debug bool) error { agentLoop.SetChannelManager(channelManager) agentLoop.SetMediaStore(mediaStore) + // Wire up voice transcription if a supported provider is configured. 
+ if transcriber := voice.DetectTranscriber(cfg); transcriber != nil { + agentLoop.SetTranscriber(transcriber) + logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()}) + } + enabledChannels := channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) @@ -223,19 +230,25 @@ func setupCronTool( // Create cron service cronService := cron.NewCronService(cronStorePath, nil) - // Create and register CronTool - cronTool, err := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) - if err != nil { - log.Fatalf("Critical error during CronTool initialization: %v", err) + // Create and register CronTool if enabled + var cronTool *tools.CronTool + if cfg.Tools.IsToolEnabled("cron") { + var err error + cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg) + if err != nil { + log.Fatalf("Critical error during CronTool initialization: %v", err) + } + + agentLoop.RegisterTool(cronTool) } - agentLoop.RegisterTool(cronTool) - - // Set the onJob handler - cronService.SetOnJob(func(job *cron.CronJob) (string, error) { - result := cronTool.ExecuteJob(context.Background(), job) - return result, nil - }) + // Set onJob handler + if cronTool != nil { + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + } return cronService } diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go index 9655d3c08..f81d7013d 100644 --- a/cmd/picoclaw/internal/helpers.go +++ b/cmd/picoclaw/internal/helpers.go @@ -18,12 +18,21 @@ var ( goVersion string ) +// GetPicoclawHome returns the picoclaw home directory. 
+// Priority: $PICOCLAW_HOME > ~/.picoclaw +func GetPicoclawHome() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return home + } + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw") +} + func GetConfigPath() string { if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" { return configPath } - home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw", "config.json") + return filepath.Join(GetPicoclawHome(), "config.json") } func LoadConfig() (*config.Config, error) { diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go index 47e2f8c07..646be1ba1 100644 --- a/cmd/picoclaw/internal/helpers_test.go +++ b/cmd/picoclaw/internal/helpers_test.go @@ -19,6 +19,27 @@ func TestGetConfigPath(t *testing.T) { assert.Equal(t, want, got) } +func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) { + t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv("HOME", "/tmp/home") + + got := GetConfigPath() + want := filepath.Join("/custom/picoclaw", "config.json") + + assert.Equal(t, want, got) +} + +func TestGetConfigPath_WithPICOCLAW_CONFIG(t *testing.T) { + t.Setenv("PICOCLAW_CONFIG", "/custom/config.json") + t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv("HOME", "/tmp/home") + + got := GetConfigPath() + want := "/custom/config.json" + + assert.Equal(t, want, got) +} + func TestFormatVersion_NoGitCommit(t *testing.T) { oldVersion, oldGit := version, gitCommit t.Cleanup(func() { version, gitCommit = oldVersion, oldGit }) diff --git a/cmd/picoclaw/internal/skills/install.go b/cmd/picoclaw/internal/skills/install.go index a30f68632..78bc421db 100644 --- a/cmd/picoclaw/internal/skills/install.go +++ b/cmd/picoclaw/internal/skills/install.go @@ -21,8 +21,8 @@ picoclaw skills install --registry clawhub github `, Args: func(cmd *cobra.Command, args []string) error { if registry != "" { - if len(args) != 2 { - return fmt.Errorf("when --registry is set, exactly 2 arguments are required: 
") + if len(args) != 1 { + return fmt.Errorf("when --registry is set, exactly 1 argument is required: ") } return nil } @@ -45,7 +45,7 @@ picoclaw skills install --registry clawhub github return err } - return skillsInstallFromRegistry(cfg, args[0], args[1]) + return skillsInstallFromRegistry(cfg, registry, args[0]) } return skillsInstallCmd(installer, args[0]) diff --git a/cmd/picoclaw/internal/skills/install_test.go b/cmd/picoclaw/internal/skills/install_test.go index 97787a986..6b362822d 100644 --- a/cmd/picoclaw/internal/skills/install_test.go +++ b/cmd/picoclaw/internal/skills/install_test.go @@ -26,3 +26,72 @@ func TestNewInstallSubcommand(t *testing.T) { assert.Len(t, cmd.Aliases, 0) } + +func TestInstallCommandArgs(t *testing.T) { + tests := []struct { + name string + args []string + registry string + expectError bool + errorMsg string + }{ + { + name: "no registry, one arg", + args: []string{"sipeed/picoclaw-skills/weather"}, + registry: "", + expectError: false, + }, + { + name: "no registry, no args", + args: []string{}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "no registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "", + expectError: true, + errorMsg: "exactly 1 argument is required: ", + }, + { + name: "with registry, one arg", + args: []string{"weather-skill"}, + registry: "clawhub", + expectError: false, + }, + { + name: "with registry, no args", + args: []string{}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + { + name: "with registry, too many args", + args: []string{"arg1", "arg2"}, + registry: "clawhub", + expectError: true, + errorMsg: "when --registry is set, exactly 1 argument is required: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newInstallCommand(nil) + + if tt.registry != "" { + require.NoError(t, cmd.Flags().Set("registry", tt.registry)) 
+ } + + err := cmd.Args(cmd, tt.args) + if tt.expectError { + require.Error(t, err) + assert.Equal(t, tt.errorMsg, err.Error()) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/config/config.example.json b/config/config.example.json index 0c4991a49..2f643d41b 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -6,7 +6,9 @@ "model_name": "gpt4", "max_tokens": 8192, "temperature": 0.7, - "max_tool_iterations": 20 + "max_tool_iterations": 20, + "summarize_message_threshold": 20, + "summarize_token_percent": 75 } }, "model_list": [ @@ -20,7 +22,8 @@ "model_name": "claude-sonnet-4.6", "model": "anthropic/claude-sonnet-4.6", "api_key": "sk-ant-your-key", - "api_base": "https://api.anthropic.com/v1" + "api_base": "https://api.anthropic.com/v1", + "thinking_level": "high" }, { "model_name": "gemini", @@ -49,6 +52,7 @@ "telegram": { "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", + "base_url": "", "proxy": "", "allow_from": [ "YOUR_USER_ID" @@ -58,6 +62,7 @@ "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", + "proxy": "", "allow_from": [], "group_trigger": { "mention_only": false @@ -220,32 +225,66 @@ "mistral": { "api_key": "", "api_base": "https://api.mistral.ai/v1" + }, + "avian": { + "api_key": "", + "api_base": "https://api.avian.io/v1" } }, "tools": { + "allow_read_paths": null, + "allow_write_paths": null, "web": { + "enabled": true, "brave": { "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 }, + "tavily": { + "enabled": false, + "api_key": "", + "base_url": "", + "max_results": 0 + }, "duckduckgo": { "enabled": true, "max_results": 5 }, "perplexity": { "enabled": false, - "api_key": "pplx-xxx", + "api_key": "", "max_results": 5 }, - "proxy": "" + "searxng": { + "enabled": false, + "base_url": "http://localhost:8888", + "max_results": 5 + }, + "glm_search": { + "enabled": false, + "api_key": "", + "base_url": "https://open.bigmodel.cn/api/paas/v4/web_search", + "search_engine": 
"search_std", + "max_results": 5 + }, + "fetch_limit_bytes": 10485760 }, "cron": { + "enabled": true, "exec_timeout_minutes": 5 }, "mcp": { "enabled": false, "servers": { + "context7": { + "enabled": false, + "type": "http", + "url": "https://mcp.context7.com/mcp", + "headers": { + "CONTEXT7_API_KEY": "ctx7sk-xx" + } + }, "filesystem": { "enabled": false, "command": "npx", @@ -301,19 +340,75 @@ } }, "exec": { - "enable_deny_patterns": false, - "custom_deny_patterns": [] + "enabled": true, + "enable_deny_patterns": true, + "custom_deny_patterns": null, + "custom_allow_patterns": null }, "skills": { + "enabled": true, "registries": { "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", - "search_path": "/api/v1/search", - "skills_path": "/api/v1/skills", - "download_path": "/api/v1/download" + "auth_token": "", + "search_path": "", + "skills_path": "", + "download_path": "", + "timeout": 0, + "max_zip_size": 0, + "max_response_size": 0 } + }, + "max_concurrent_searches": 2, + "search_cache": { + "max_size": 50, + "ttl_seconds": 300 } + }, + "media_cleanup": { + "enabled": true, + "max_age_minutes": 30, + "interval_minutes": 5 + }, + "append_file": { + "enabled": true + }, + "edit_file": { + "enabled": true + }, + "find_skills": { + "enabled": true + }, + "i2c": { + "enabled": false + }, + "install_skill": { + "enabled": true + }, + "list_dir": { + "enabled": true + }, + "message": { + "enabled": true + }, + "read_file": { + "enabled": true + }, + "spawn": { + "enabled": true + }, + "spi": { + "enabled": false + }, + "subagent": { + "enabled": true + }, + "web_fetch": { + "enabled": true + }, + "write_file": { + "enabled": true } }, "heartbeat": { @@ -328,4 +423,4 @@ "host": "127.0.0.1", "port": 18790 } -} \ No newline at end of file +} diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index 6204fb0c8..e64a3a107 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -180,6 +180,7 @@ The skills tool configures 
skill discovery and installation via registries like | ---------------------------------- | ------ | -------------------- | ----------------------- | | `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | | `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.auth_token` | string | `""` | Optional Bearer token for higher rate limits | | `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | | `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | | `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | @@ -194,6 +195,7 @@ The skills tool configures skill discovery and installation via registries like "clawhub": { "enabled": true, "base_url": "https://clawhub.ai", + "auth_token": "", "search_path": "/api/v1/search", "skills_path": "/api/v1/skills", "download_path": "/api/v1/download" diff --git a/go.mod b/go.mod index 9f755bbc9..238bd405c 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 @@ -16,6 +17,7 @@ require ( github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 + github.com/rivo/tview v0.42.0 github.com/slack-go/slack v0.17.3 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 @@ -35,7 +37,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect - github.com/gdamore/tcell/v2 v2.13.8 // indirect + github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect 
github.com/mattn/go-colorable v0.1.14 // indirect @@ -44,7 +46,6 @@ require ( github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/tview v0.42.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/go.sum b/go.sum index 9041826a5..060594d06 100644 --- a/go.sum +++ b/go.sum @@ -98,6 +98,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= +github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= +github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 6fccbaf53..d84aea627 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -34,9 +34,17 @@ type ContextBuilder struct { // created (didn't exist at cache time, now exist) or deleted (existed at // cache time, now gone) — both of which should trigger a cache rebuild. existedAtCache map[string]bool + + // skillFilesAtCache snapshots the skill tree file set and mtimes at cache + // build time. This catches nested file creations/deletions/mtime changes + // that may not update the top-level skill root directory mtime. 
+ skillFilesAtCache map[string]time.Time } func getGlobalConfigDir() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return home + } home, err := os.UserHomeDir() if err != nil { return "" @@ -47,8 +55,11 @@ func getGlobalConfigDir() string { func NewContextBuilder(workspace string) *ContextBuilder { // builtin skills: skills directory in current project // Use the skills/ directory under the current working directory - wd, _ := os.Getwd() - builtinSkillsDir := filepath.Join(wd, "skills") + builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS")) + if builtinSkillsDir == "" { + wd, _ := os.Getwd() + builtinSkillsDir = filepath.Join(wd, "skills") + } globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills") return &ContextBuilder{ @@ -148,6 +159,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string { cb.cachedSystemPrompt = prompt cb.cachedAt = baseline.maxMtime cb.existedAtCache = baseline.existed + cb.skillFilesAtCache = baseline.skillFiles logger.DebugCF("agent", "System prompt cached", map[string]any{ @@ -167,14 +179,14 @@ func (cb *ContextBuilder) InvalidateCache() { cb.cachedSystemPrompt = "" cb.cachedAt = time.Time{} cb.existedAtCache = nil + cb.skillFilesAtCache = nil logger.DebugCF("agent", "System prompt cache invalidated", nil) } -// sourcePaths returns the workspace source file paths tracked for cache -// invalidation (bootstrap files + memory). The skills directory is handled -// separately in sourceFilesChangedLocked because it requires both directory- -// level and recursive file-level mtime checks. +// sourcePaths returns non-skill workspace source files tracked for cache +// invalidation (bootstrap files + memory). Skill roots are handled separately +// because they require both directory-level and recursive file-level checks. 
func (cb *ContextBuilder) sourcePaths() []string { return []string{ filepath.Join(cb.workspace, "AGENTS.md"), @@ -185,23 +197,39 @@ func (cb *ContextBuilder) sourcePaths() []string { } } +// skillRoots returns all skill root directories that can affect +// BuildSkillsSummary output (workspace/global/builtin). +func (cb *ContextBuilder) skillRoots() []string { + if cb.skillsLoader == nil { + return []string{filepath.Join(cb.workspace, "skills")} + } + + roots := cb.skillsLoader.SkillRoots() + if len(roots) == 0 { + return []string{filepath.Join(cb.workspace, "skills")} + } + return roots +} + // cacheBaseline holds the file existence snapshot and the latest observed // mtime across all tracked paths. Used as the cache reference point. type cacheBaseline struct { - existed map[string]bool - maxMtime time.Time + existed map[string]bool + skillFiles map[string]time.Time + maxMtime time.Time } // buildCacheBaseline records which tracked paths currently exist and computes // the latest mtime across all tracked files + skills directory contents. // Called under write lock when the cache is built. func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { - skillsDir := filepath.Join(cb.workspace, "skills") + skillRoots := cb.skillRoots() - // All paths whose existence we track: source files + skills dir. - allPaths := append(cb.sourcePaths(), skillsDir) + // All paths whose existence we track: source files + all skill roots. + allPaths := append(cb.sourcePaths(), skillRoots...) existed := make(map[string]bool, len(allPaths)) + skillFiles := make(map[string]time.Time) var maxMtime time.Time for _, p := range allPaths { @@ -212,17 +240,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { } } - // Walk skills files to capture their mtimes too. - // Use os.Stat (not d.Info) to match the stat method used in - // fileChangedSince / skillFilesModifiedSince for consistency. 
- _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) { - maxMtime = info.ModTime() + // Walk all skill roots recursively to snapshot skill files and mtimes. + // Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks. + for _, root := range skillRoots { + _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr == nil && !d.IsDir() { + if info, err := os.Stat(path); err == nil { + skillFiles[path] = info.ModTime() + if info.ModTime().After(maxMtime) { + maxMtime = info.ModTime() + } + } } - } - return nil - }) + return nil + }) + } // If no tracked files exist yet (empty workspace), maxMtime is zero. // Use a very old non-zero time so that: @@ -234,7 +266,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { maxMtime = time.Unix(1, 0) } - return cacheBaseline{existed: existed, maxMtime: maxMtime} + return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime} } // sourceFilesChangedLocked checks whether any workspace source file has been @@ -254,21 +286,17 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool { return true } - // --- Skills directory (handled separately from sourcePaths) --- + // --- Skill roots (workspace/global/builtin) --- // - // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files. - skillsDir := filepath.Join(cb.workspace, "skills") - if cb.fileChangedSince(skillsDir) { - return true + // For each root: + // 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince. + // 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot. + for _, root := range cb.skillRoots() { + if cb.fileChangedSince(root) { + return true + } } - - // 2. 
Structural changes (add/remove entries inside the dir) are reflected - // in the directory's own mtime, which fileChangedSince already checks. - // - // 3. Content-only edits to files inside skills/ do NOT update the parent - // directory mtime on most filesystems, so we recursively walk to check - // individual file mtimes at any nesting depth. - if skillFilesModifiedSince(skillsDir, cb.cachedAt) { + if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) { return true } @@ -309,28 +337,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool { // if the callback returned nil when its err parameter is non-nil. var errWalkStop = errors.New("walk stop") -// skillFilesModifiedSince recursively walks the skills directory and checks -// whether any file was modified after t. This catches content-only edits at -// any nesting depth (e.g. skills/name/docs/extra.md) that don't update -// parent directory mtimes. -func skillFilesModifiedSince(skillsDir string, t time.Time) bool { - changed := false - err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) { - changed = true - return errWalkStop // stop walking - } - } - return nil - }) - // errWalkStop is expected (early exit on first changed file). - // os.IsNotExist means the skills dir doesn't exist yet — not an error. - // Any other error is unexpected and worth logging. - if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { - logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) +// skillFilesChangedSince compares the current recursive skill file tree +// against the cache-time snapshot. Any create/delete/mtime drift invalidates +// the cache. +func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool { + // Defensive: if the snapshot was never initialized, force rebuild. 
+ if filesAtCache == nil { + return true } - return changed + + // Check cached files still exist and keep the same mtime. + for path, cachedMtime := range filesAtCache { + info, err := os.Stat(path) + if err != nil { + // A previously tracked file disappeared (or became inaccessible): + // either way, cached skill summary may now be stale. + return true + } + if !info.ModTime().Equal(cachedMtime) { + return true + } + } + + // Check no new files appeared under any skill root. + changed := false + for _, root := range skillRoots { + if strings.TrimSpace(root) == "" { + continue + } + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + // Treat unexpected walk errors as changed to avoid stale cache. + if !os.IsNotExist(walkErr) { + changed = true + return errWalkStop + } + return nil + } + if d.IsDir() { + return nil + } + if _, ok := filesAtCache[path]; !ok { + changed = true + return errWalkStop + } + return nil + }) + + if changed { + return true + } + if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { + logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) + return true + } + } + + return false } func (cb *ContextBuilder) LoadBootstrapFiles() string { @@ -466,10 +530,14 @@ func (cb *ContextBuilder) BuildMessages( // Add current user message if strings.TrimSpace(currentMessage) != "" { - messages = append(messages, providers.Message{ + msg := providers.Message{ Role: "user", Content: currentMessage, - }) + } + if len(media) > 0 { + msg.Media = media + } + messages = append(messages, msg) } return messages diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go index 0905e8a46..707510820 100644 --- a/pkg/agent/context_cache_test.go +++ b/pkg/agent/context_cache_test.go @@ -383,6 +383,162 @@ Updated content.` } } +// TestGlobalSkillFileContentChange verifies that modifying a global skill +// (~/.picoclaw/skills) invalidates the cached 
system prompt. +func TestGlobalSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: global-skill +description: global-v1 +--- +# Global Skill v1` + if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "global-v1") { + t.Fatal("expected initial prompt to contain global skill description") + } + + v2 := `--- +name: global-skill +description: global-v2 +--- +# Global Skill v2` + if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(globalSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect global skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "global-v2") { + t.Error("rebuilt prompt should contain updated global skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when global skill file content changes") + } +} + +// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill +// invalidates the cached system prompt. 
+func TestBuiltinSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + builtinRoot := t.TempDir() + t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot) + + builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: builtin-skill +description: builtin-v1 +--- +# Builtin Skill v1` + if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "builtin-v1") { + t.Fatal("expected initial prompt to contain builtin skill description") + } + + v2 := `--- +name: builtin-skill +description: builtin-v2 +--- +# Builtin Skill v2` + if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(builtinSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "builtin-v2") { + t.Error("rebuilt prompt should contain updated builtin skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when builtin skill file content changes") + } +} + +// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill +// file invalidates the cached system prompt. 
+func TestSkillFileDeletionInvalidatesCache(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "skills/delete-me/SKILL.md": `--- +name: delete-me +description: delete-me-v1 +--- +# Delete Me`, + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "delete-me-v1") { + t.Fatal("expected initial prompt to contain skill description") + } + + skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md") + if err := os.Remove(skillPath); err != nil { + t.Fatal(err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect deleted skill file") + } + + sp2 := cb.BuildSystemPromptWithCache() + if strings.Contains(sp2, "delete-me-v1") { + t.Error("rebuilt prompt should not contain deleted skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when skill file is deleted") + } +} + // TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines // can safely call BuildSystemPromptWithCache concurrently without producing // empty results, panics, or data races. diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index ec8871e30..97cf0fa05 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -18,22 +18,25 @@ import ( // AgentInstance represents a fully configured agent with its own workspace, // session manager, context builder, and tool registry. 
type AgentInstance struct { - ID string - Name string - Model string - Fallbacks []string - Workspace string - MaxIterations int - MaxTokens int - Temperature float64 - ContextWindow int - Provider providers.LLMProvider - Sessions *session.SessionManager - ContextBuilder *ContextBuilder - Tools *tools.ToolRegistry - Subagents *config.SubagentsConfig - SkillsFilter []string - Candidates []providers.FallbackCandidate + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + MaxTokens int + Temperature float64 + ThinkingLevel ThinkingLevel + ContextWindow int + SummarizeMessageThreshold int + SummarizeTokenPercent int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate // Router is non-nil when model routing is configured and the light model // was successfully resolved. 
It scores each incoming message and decides @@ -65,17 +68,30 @@ func NewAgentInstance( allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths) toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) - execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) - if err != nil { - log.Fatalf("Critical error: unable to initialize exec tool: %v", err) - } - toolsRegistry.Register(execTool) - toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + if cfg.Tools.IsToolEnabled("read_file") { + toolsRegistry.Register(tools.NewReadFileTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("write_file") { + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("list_dir") { + toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths)) + } + if cfg.Tools.IsToolEnabled("exec") { + execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg) + if err != nil { + log.Fatalf("Critical error: unable to initialize exec tool: %v", err) + } + toolsRegistry.Register(execTool) + } + + if cfg.Tools.IsToolEnabled("edit_file") { + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict, allowWritePaths)) + } + if cfg.Tools.IsToolEnabled("append_file") { + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict, allowWritePaths)) + } sessionsDir := filepath.Join(workspace, "sessions") sessionsManager := session.NewSessionManager(sessionsDir) @@ -109,6 +125,22 @@ func NewAgentInstance( temperature = *defaults.Temperature } + var thinkingLevelStr string + if mc, err := 
cfg.GetModelConfig(model); err == nil { + thinkingLevelStr = mc.ThinkingLevel + } + thinkingLevel := parseThinkingLevel(thinkingLevelStr) + + summarizeMessageThreshold := defaults.SummarizeMessageThreshold + if summarizeMessageThreshold == 0 { + summarizeMessageThreshold = 20 + } + + summarizeTokenPercent := defaults.SummarizeTokenPercent + if summarizeTokenPercent == 0 { + summarizeTokenPercent = 75 + } + // Resolve fallback candidates modelCfg := providers.ModelConfig{ Primary: model, @@ -176,24 +208,27 @@ func NewAgentInstance( } return &AgentInstance{ - ID: agentID, - Name: agentName, - Model: model, - Fallbacks: fallbacks, - Workspace: workspace, - MaxIterations: maxIter, - MaxTokens: maxTokens, - Temperature: temperature, - ContextWindow: maxTokens, - Provider: provider, - Sessions: sessionsManager, - ContextBuilder: contextBuilder, - Tools: toolsRegistry, - Subagents: subagents, - SkillsFilter: skillsFilter, - Candidates: candidates, - Router: router, - LightCandidates: lightCandidates, + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + MaxTokens: maxTokens, + Temperature: temperature, + ThinkingLevel: thinkingLevel, + ContextWindow: maxTokens, + SummarizeMessageThreshold: summarizeMessageThreshold, + SummarizeTokenPercent: summarizeTokenPercent, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: skillsFilter, + Candidates: candidates, + Router: router, + LightCandidates: lightCandidates, } } @@ -202,12 +237,13 @@ func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentD if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { return expandHome(strings.TrimSpace(agentCfg.Workspace)) } + // Use the configured default workspace (respects PICOCLAW_HOME) if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == 
"main" { return expandHome(defaults.Workspace) } - home, _ := os.UserHomeDir() + // For named agents without explicit workspace, use default workspace with agent ID suffix id := routing.NormalizeAgentID(agentCfg.ID) - return filepath.Join(home, ".picoclaw", "workspace-"+id) + return filepath.Join(expandHome(defaults.Workspace), "..", "workspace-"+id) } // resolveAgentModel resolves the primary model for an agent. diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8c78c2e89..5e68e4931 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "path/filepath" + "regexp" "strings" "sync" "sync/atomic" @@ -31,6 +32,7 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" ) type AgentLoop struct { @@ -43,18 +45,20 @@ type AgentLoop struct { fallback *providers.FallbackChain channelManager *channels.Manager mediaStore media.MediaStore + transcriber voice.Transcriber } // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger 
summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." @@ -104,71 +108,106 @@ func registerSharedTools( } // Web tools - searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, - TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, - TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, - TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - Proxy: cfg.Tools.Web.Proxy, - }) - if err != nil { - logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) - } else if searchTool != nil { - agent.Tools.Register(searchTool) + if cfg.Tools.IsToolEnabled("web") { + searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey, + TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL, + TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults, + TavilyEnabled: cfg.Tools.Web.Tavily.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + SearXNGBaseURL: 
cfg.Tools.Web.SearXNG.BaseURL, + SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults, + SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled, + GLMSearchAPIKey: cfg.Tools.Web.GLMSearch.APIKey, + GLMSearchBaseURL: cfg.Tools.Web.GLMSearch.BaseURL, + GLMSearchEngine: cfg.Tools.Web.GLMSearch.SearchEngine, + GLMSearchMaxResults: cfg.Tools.Web.GLMSearch.MaxResults, + GLMSearchEnabled: cfg.Tools.Web.GLMSearch.Enabled, + Proxy: cfg.Tools.Web.Proxy, + }) + if err != nil { + logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + } else if searchTool != nil { + agent.Tools.Register(searchTool) + } } - fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) - if err != nil { - logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) - } else { - agent.Tools.Register(fetchTool) + if cfg.Tools.IsToolEnabled("web_fetch") { + fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy, cfg.Tools.Web.FetchLimitBytes) + if err != nil { + logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + } else { + agent.Tools.Register(fetchTool) + } } // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - agent.Tools.Register(tools.NewI2CTool()) - agent.Tools.Register(tools.NewSPITool()) + if cfg.Tools.IsToolEnabled("i2c") { + agent.Tools.Register(tools.NewI2CTool()) + } + if cfg.Tools.IsToolEnabled("spi") { + agent.Tools.Register(tools.NewSPITool()) + } // Message tool - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer pubCancel() - return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, + if cfg.Tools.IsToolEnabled("message") { + messageTool := tools.NewMessageTool() + 
messageTool.SetSendCallback(func(channel, chatID, content string) error { + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) }) - }) - agent.Tools.Register(messageTool) + agent.Tools.Register(messageTool) + } // Skill discovery and installation tools - registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ - MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, - ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), - }) - searchCache := skills.NewSearchCache( - cfg.Tools.Skills.SearchCache.MaxSize, - time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, - ) - agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) - agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + skills_enabled := cfg.Tools.IsToolEnabled("skills") + find_skills_enable := cfg.Tools.IsToolEnabled("find_skills") + install_skills_enable := cfg.Tools.IsToolEnabled("install_skill") + if skills_enabled && (find_skills_enable || install_skills_enable) { + registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{ + MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches, + ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub), + }) + + if find_skills_enable { + searchCache := skills.NewSearchCache( + cfg.Tools.Skills.SearchCache.MaxSize, + time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second, + ) + agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache)) + } + + if install_skills_enable { + agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace)) + } + } // Spawn tool with allowlist checker - subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) - subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) - 
spawnTool := tools.NewSpawnTool(subagentManager) - currentAgentID := agentID - spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { - return registry.CanSpawnSubagent(currentAgentID, targetAgentID) - }) - agent.Tools.Register(spawnTool) + if cfg.Tools.IsToolEnabled("spawn") { + if cfg.Tools.IsToolEnabled("subagent") { + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + } else { + logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil) + } + } } } @@ -176,8 +215,19 @@ func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) // Initialize MCP servers for all agents - if al.cfg.Tools.MCP.Enabled { + if al.cfg.Tools.IsToolEnabled("mcp") { mcpManager := mcp.NewManager() + // Ensure MCP connections are cleaned up on exit, regardless of initialization success + // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails + defer func() { + if err := mcpManager.Close(); err != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": err.Error(), + }) + } + }() + defaultAgent := al.registry.GetDefaultAgent() var workspacePath string if defaultAgent != nil && defaultAgent.Workspace != "" { @@ -192,16 +242,6 @@ func (al *AgentLoop) Run(ctx context.Context) error { "error": err.Error(), }) } else { - // Ensure MCP connections are cleaned up on exit, only if initialization succeeded - defer func() { - if err := mcpManager.Close(); err != nil { - logger.ErrorCF("agent", "Failed to close MCP manager", - map[string]any{ - "error": err.Error(), - }) - } - }() - // Register MCP tools for all agents servers := 
mcpManager.GetServers() uniqueTools := 0 @@ -217,6 +257,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { if !ok { continue } + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) agent.Tools.Register(mcpTool) totalRegistrations++ @@ -332,6 +373,64 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s } +// SetTranscriber injects a voice transcriber for agent-level audio transcription. +func (al *AgentLoop) SetTranscriber(t voice.Transcriber) { + al.transcriber = t +} + +var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) + +// transcribeAudioInMessage resolves audio media refs, transcribes them, and +// replaces audio annotations in msg.Content with the transcribed text. +func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) bus.InboundMessage { + if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 { + return msg + } + + // Transcribe each audio media ref in order. + var transcriptions []string + for _, ref := range msg.Media { + path, meta, err := al.mediaStore.ResolveWithMeta(ref) + if err != nil { + logger.WarnCF("voice", "Failed to resolve media ref", map[string]any{"ref": ref, "error": err}) + continue + } + if !utils.IsAudioFile(meta.Filename, meta.ContentType) { + continue + } + result, err := al.transcriber.Transcribe(ctx, path) + if err != nil { + logger.WarnCF("voice", "Transcription failed", map[string]any{"ref": ref, "error": err}) + transcriptions = append(transcriptions, "") + continue + } + transcriptions = append(transcriptions, result.Text) + } + + if len(transcriptions) == 0 { + return msg + } + + // Replace audio annotations sequentially with transcriptions. 
+ idx := 0 + newContent := audioAnnotationRe.ReplaceAllStringFunc(msg.Content, func(match string) string { + if idx >= len(transcriptions) { + return match + } + text := transcriptions[idx] + idx++ + return "[voice: " + text + "]" + }) + + // Append any remaining transcriptions not matched by an annotation. + for ; idx < len(transcriptions); idx++ { + newContent += "\n[voice: " + transcriptions[idx] + "]" + } + + msg.Content = newContent + return msg +} + // inferMediaType determines the media type ("image", "audio", "video", "file") // from a filename and MIME content type. func inferMediaType(filename, contentType string) string { @@ -443,6 +542,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }, ) + msg = al.transcribeAudioInMessage(ctx, msg) + // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) @@ -473,8 +574,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) // Reset message-tool state for this round so we don't skip publishing due to a previous round. if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(msg.Channel, msg.ChatID) + if resetter, ok := tool.(interface{ ResetSentInRound() }); ok { + resetter.ResetSentInRound() } } @@ -496,6 +597,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, + Media: msg.Media, DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, @@ -588,10 +690,7 @@ func (al *AgentLoop) runAgentLoop( } } - // 1. Update tool contexts - al.updateToolContexts(agent, opts.Channel, opts.ChatID) - - // 2. Build messages (skip history for heartbeat) + // 1. 
Build messages (skip history for heartbeat) var history []providers.Message var summary string if !opts.NoHistory { @@ -602,15 +701,19 @@ func (al *AgentLoop) runAgentLoop( history, summary, opts.UserMessage, - nil, + opts.Media, opts.Channel, opts.ChatID, ) - // 3. Save user message to session + // Resolve media:// refs to base64 data URLs (streaming) + maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize() + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + + // 2. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - // 4. Run LLM iteration loop + // 3. Run LLM iteration loop finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err @@ -619,21 +722,21 @@ func (al *AgentLoop) runAgentLoop( // If last tool had ForUser content and we already sent it, we might not need to send final response // This is controlled by the tool's Silent flag and ForUser content - // 5. Handle empty response + // 4. Handle empty response if finalContent == "" { finalContent = opts.DefaultResponse } - // 6. Save final assistant message to session + // 5. Save final assistant message to session agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) agent.Sessions.Save(opts.SessionKey) - // 7. Optional: summarization + // 6. Optional: summarization if opts.EnableSummary { al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } - // 8. Optional: send response via bus + // 7. Optional: send response via bus if opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, @@ -642,7 +745,7 @@ func (al *AgentLoop) runAgentLoop( }) } - // 9. Log response + // 8. 
Log response responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]any{ @@ -765,23 +868,29 @@ func (al *AgentLoop) runLLMIteration( var response *providers.LLMResponse var err error + llmOpts := map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, + } + // parseThinkingLevel guarantees ThinkingOff for empty/unknown values, + // so checking != ThinkingOff is sufficient. + if agent.ThinkingLevel != ThinkingOff { + if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() { + llmOpts["thinking_level"] = string(agent.ThinkingLevel) + } else { + logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring", + map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)}) + } + } + callLLM := func() (*providers.LLMResponse, error) { if len(activeCandidates) > 1 && al.fallback != nil { fbResult, fbErr := al.fallback.Execute( ctx, activeCandidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat( - ctx, - messages, - providerToolDefs, - model, - map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, - }, - ) + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts) }, ) if fbErr != nil { @@ -797,11 +906,7 @@ func (al *AgentLoop) runLLMIteration( } return fbResult.Response, nil } - return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - "prompt_cache_key": agent.ID, - }) + return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts) } // Retry loop for context/token errors @@ -963,62 +1068,76 @@ func (al *AgentLoop) runLLMIteration( // Save assistant message with tool calls to 
session agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) - // Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "agent_id": agent.ID, - "tool": tc.Name, - "iteration": iteration, - }) + // Execute tool calls in parallel + type indexedAgentResult struct { + result *tools.ToolResult + tc providers.ToolCall + } - // Create async callback for tools that implement AsyncTool - // NOTE: Following openclaw's design, async tools do NOT send results directly to users. - // Instead, they notify the agent via PublishInbound, and the agent decides - // whether to forward the result to the user (in processSystemMessage). - asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { - // Log the async completion but don't send directly to user - // The agent will handle user notification via processSystemMessage - if !result.Silent && result.ForUser != "" { - logger.InfoCF("agent", "Async tool completed, agent will handle notification", - map[string]any{ - "tool": tc.Name, - "content_len": len(result.ForUser), - }) + agentResults := make([]indexedAgentResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + agentResults[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "agent_id": agent.ID, + "tool": tc.Name, + "iteration": iteration, + }) + + // Create async callback for tools that implement AsyncExecutor + asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", 
"Async tool completed, agent will handle notification", + map[string]any{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } } - } - toolResult := agent.Tools.ExecuteWithContext( - ctx, - tc.Name, - tc.Arguments, - opts.Channel, - opts.ChatID, - asyncCallback, - ) + toolResult := agent.Tools.ExecuteWithContext( + ctx, + tc.Name, + tc.Arguments, + opts.Channel, + opts.ChatID, + asyncCallback, + ) + agentResults[idx].result = toolResult + }(i, tc) + } + wg.Wait() + // Process results in original order (send to user, save to session) + for _, r := range agentResults { // Send ForUser content to user immediately if not Silent - if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + if !r.result.Silent && r.result.ForUser != "" && opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: toolResult.ForUser, + Content: r.result.ForUser, }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": tc.Name, - "content_len": len(toolResult.ForUser), + "tool": r.tc.Name, + "content_len": len(r.result.ForUser), }) } // If tool returned media refs, publish them as outbound media - if len(toolResult.Media) > 0 && opts.SendResponse { - parts := make([]bus.MediaPart, 0, len(toolResult.Media)) - for _, ref := range toolResult.Media { + if len(r.result.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(r.result.Media)) + for _, ref := range r.result.Media { part := bus.MediaPart{Ref: ref} - // Populate metadata from MediaStore when available if al.mediaStore != nil { if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { part.Filename = meta.Filename @@ -1036,15 +1155,15 @@ func (al *AgentLoop) runLLMIteration( } // Determine content for LLM based on tool result - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() + contentForLLM := r.result.ForLLM + 
if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, + ToolCallID: r.tc.ID, } messages = append(messages, toolResultMsg) @@ -1087,33 +1206,13 @@ func (al *AgentLoop) selectCandidates( return agent.LightCandidates, agent.Router.LightModel() } -// updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { - // Use ContextualTool interface instead of type assertions - if tool, ok := agent.Tools.Get("message"); ok { - if mt, ok := tool.(tools.ContextualTool); ok { - mt.SetContext(channel, chatID) - } - } - if tool, ok := agent.Tools.Get("spawn"); ok { - if st, ok := tool.(tools.ContextualTool); ok { - st.SetContext(channel, chatID) - } - } - if tool, ok := agent.Tools.Get("subagent"); ok { - if st, ok := tool.(tools.ContextualTool); ok { - st.SetContext(channel, chatID) - } - } -} - // maybeSummarize triggers summarization if the session history exceeds thresholds. 
func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := agent.ContextWindow * 75 / 100 + threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100 - if len(newHistory) > 20 || tokenEstimate > threshold { + if len(newHistory) > agent.SummarizeMessageThreshold || tokenEstimate > threshold { summarizeKey := agent.ID + ":" + sessionKey if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go new file mode 100644 index 000000000..82547a008 --- /dev/null +++ b/pkg/agent/loop_media.go @@ -0,0 +1,122 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "bytes" + "encoding/base64" + "io" + "os" + "strings" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs. +// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding +// both raw bytes and encoded string in memory simultaneously. +// Returns a new slice; original messages are not mutated. 
+func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message { + if store == nil { + return messages + } + + result := make([]providers.Message, len(messages)) + copy(result, messages) + + for i, m := range result { + if len(m.Media) == 0 { + continue + } + + resolved := make([]string, 0, len(m.Media)) + for _, ref := range m.Media { + if !strings.HasPrefix(ref, "media://") { + resolved = append(resolved, ref) + continue + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{ + "ref": ref, + "error": err.Error(), + }) + continue + } + + info, err := os.Stat(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to stat media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + if info.Size() > int64(maxSize) { + logger.WarnCF("agent", "Media file too large, skipping", map[string]any{ + "path": localPath, + "size": info.Size(), + "max_size": maxSize, + }) + continue + } + + // Determine MIME type: prefer metadata, fallback to magic-bytes detection + mime := meta.ContentType + if mime == "" { + kind, ftErr := filetype.MatchFile(localPath) + if ftErr != nil || kind == filetype.Unknown { + logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{ + "path": localPath, + }) + continue + } + mime = kind.MIME.Value + } + + // Streaming base64: open file → base64 encoder → buffer + // Peak memory: ~1.33x file size (buffer only, no raw bytes copy) + f, err := os.Open(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + prefix := "data:" + mime + ";base64," + encodedLen := base64.StdEncoding.EncodedLen(int(info.Size())) + var buf bytes.Buffer + buf.Grow(len(prefix) + encodedLen) + buf.WriteString(prefix) + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + if 
_, err := io.Copy(encoder, f); err != nil { + f.Close() + logger.WarnCF("agent", "Failed to encode media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + encoder.Close() + f.Close() + + resolved = append(resolved, buf.String()) + } + + result[i].Media = resolved + } + + return result +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 3565314fe..aa7d59b5a 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -6,12 +6,14 @@ import ( "os" "path/filepath" "slices" + "strings" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -162,35 +164,21 @@ func TestToolRegistry_ToolRegistration(t *testing.T) { } } -// TestToolContext_Updates verifies tool context is updated with channel/chatID +// TestToolContext_Updates verifies tool context helpers work correctly func TestToolContext_Updates(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agent-test-*") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tmpDir) + ctx := tools.WithToolContext(context.Background(), "telegram", "chat-42") - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, + if got := tools.ToolChannel(ctx); got != "telegram" { + t.Errorf("expected channel 'telegram', got %q", got) + } + if got := tools.ToolChatID(ctx); got != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", got) } - msgBus := bus.NewMessageBus() - provider := &simpleMockProvider{response: "OK"} - _ = NewAgentLoop(cfg, msgBus, provider) - - // Verify that ContextualTool interface is defined and can be implemented - // This test validates the interface contract exists - ctxTool := 
&mockContextualTool{} - - // Verify the tool implements the interface correctly - var _ tools.ContextualTool = ctxTool + // Empty context returns empty strings + if got := tools.ToolChannel(context.Background()); got != "" { + t.Errorf("expected empty channel from bare context, got %q", got) + } } // TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved @@ -239,16 +227,11 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { } defer os.RemoveAll(tmpDir) - cfg := &config.Config{ - Agents: config.AgentsConfig{ - Defaults: config.AgentDefaults{ - Workspace: tmpDir, - Model: "test-model", - MaxTokens: 4096, - MaxToolIterations: 10, - }, - }, - } + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 msgBus := bus.NewMessageBus() provider := &mockProvider{} @@ -357,36 +340,6 @@ func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tool return tools.SilentResult("Custom tool executed") } -// mockContextualTool tracks context updates -type mockContextualTool struct { - lastChannel string - lastChatID string -} - -func (m *mockContextualTool) Name() string { - return "mock_contextual" -} - -func (m *mockContextualTool) Description() string { - return "Mock contextual tool" -} - -func (m *mockContextualTool) Parameters() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{}, - } -} - -func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { - return tools.SilentResult("Contextual tool executed") -} - -func (m *mockContextualTool) SetContext(channel, chatID string) { - m.lastChannel = channel - m.lastChatID = chatID -} - // testHelper executes a message and returns the response type testHelper struct { al *AgentLoop @@ -808,3 +761,142 @@ func TestHandleReasoning(t *testing.T) { } }) } + +func 
TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + // Create a minimal valid PNG (8-byte header is enough for filetype detection) + pngPath := filepath.Join(dir, "test.png") + // PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB + 0x00, 0x00, 0x00, // no interlace + 0x90, 0x77, 0x53, 0xDE, // CRC + } + if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil { + t.Fatal(err) + } + ref, err := store.Store(pngPath, media.MediaMeta{}, "test") + if err != nil { + t.Fatal(err) + } + + messages := []providers.Message{ + {Role: "user", Content: "describe this", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") { + t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40]) + } +} + +func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + bigPath := filepath.Join(dir, "big.png") + // Write PNG header + padding to exceed limit + data := make([]byte, 1024+1) // 1KB + 1 byte + copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) + if err := os.WriteFile(bigPath, data, 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(bigPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + // Use a tiny limit (1KB) so the file is oversized + result := resolveMediaRefs(messages, store, 1024) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (oversized), got 
%d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + txtPath := filepath.Join(dir, "readme.txt") + if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(txtPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) { + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}}, + } + result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" { + t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media) + } +} + +func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + pngPath := filepath.Join(dir, "test.png") + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, + 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, + 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, + } + os.WriteFile(pngPath, pngHeader, 0o644) + ref, _ := store.Store(pngPath, media.MediaMeta{}, "test") + + original := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + originalRef := original[0].Media[0] + + resolveMediaRefs(original, store, config.DefaultMaxMediaSize) + + if original[0].Media[0] != originalRef { + t.Fatal("resolveMediaRefs mutated original message slice") + } +} + +func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) { + store := 
media.NewFileMediaStore() + dir := t.TempDir() + + // File with JPEG content but stored with explicit content type + jpegPath := filepath.Join(dir, "photo") + jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes + os.WriteFile(jpegPath, jpegHeader, 0o644) + ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") { + t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30]) + } +} diff --git a/pkg/agent/thinking.go b/pkg/agent/thinking.go new file mode 100644 index 000000000..015b69282 --- /dev/null +++ b/pkg/agent/thinking.go @@ -0,0 +1,39 @@ +package agent + +import "strings" + +// ThinkingLevel controls how the provider sends thinking parameters. +// +// - "adaptive": sends {thinking: {type: "adaptive"}} + output_config.effort (Claude 4.6+) +// - "low"/"medium"/"high"/"xhigh": sends {thinking: {type: "enabled", budget_tokens: N}} (all models) +// - "off": disables thinking +type ThinkingLevel string + +const ( + ThinkingOff ThinkingLevel = "off" + ThinkingLow ThinkingLevel = "low" + ThinkingMedium ThinkingLevel = "medium" + ThinkingHigh ThinkingLevel = "high" + ThinkingXHigh ThinkingLevel = "xhigh" + ThinkingAdaptive ThinkingLevel = "adaptive" +) + +// parseThinkingLevel normalizes a config string to a ThinkingLevel. +// Case-insensitive and whitespace-tolerant for user-facing config values. +// Returns ThinkingOff for unknown or empty values. 
+func parseThinkingLevel(level string) ThinkingLevel { + switch strings.ToLower(strings.TrimSpace(level)) { + case "adaptive": + return ThinkingAdaptive + case "low": + return ThinkingLow + case "medium": + return ThinkingMedium + case "high": + return ThinkingHigh + case "xhigh": + return ThinkingXHigh + default: + return ThinkingOff + } +} diff --git a/pkg/agent/thinking_test.go b/pkg/agent/thinking_test.go new file mode 100644 index 000000000..be3a68c33 --- /dev/null +++ b/pkg/agent/thinking_test.go @@ -0,0 +1,35 @@ +package agent + +import "testing" + +func TestParseThinkingLevel(t *testing.T) { + tests := []struct { + name string + input string + want ThinkingLevel + }{ + {"off", "off", ThinkingOff}, + {"empty", "", ThinkingOff}, + {"low", "low", ThinkingLow}, + {"medium", "medium", ThinkingMedium}, + {"high", "high", ThinkingHigh}, + {"xhigh", "xhigh", ThinkingXHigh}, + {"adaptive", "adaptive", ThinkingAdaptive}, + {"unknown", "unknown", ThinkingOff}, + // Case-insensitive and whitespace-tolerant + {"upper_Medium", "Medium", ThinkingMedium}, + {"upper_HIGH", "HIGH", ThinkingHigh}, + {"mixed_Adaptive", "Adaptive", ThinkingAdaptive}, + {"leading_space", " high", ThinkingHigh}, + {"trailing_space", "low ", ThinkingLow}, + {"both_spaces", " medium ", ThinkingMedium}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseThinkingLevel(tt.input); got != tt.want { + t.Errorf("parseThinkingLevel(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 91c9e25c5..4667e3d81 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -212,7 +212,10 @@ func RequestDeviceCode(cfg OAuthProviderConfig) (*DeviceCodeInfo, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device code response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("device 
code request failed: %s", string(body)) } @@ -300,7 +303,10 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device code response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("device code request failed: %s", string(body)) } @@ -360,7 +366,10 @@ func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*Au return nil, fmt.Errorf("pending") } - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading device token response: %w", err) + } var tokenResp struct { AuthorizationCode string `json:"authorization_code"` @@ -401,7 +410,10 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading token refresh response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token refresh failed: %s", string(body)) } @@ -494,7 +506,10 @@ func ExchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirect } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading token exchange response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("token exchange failed: %s", string(body)) } diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 283dc6977..2e55d4877 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -39,6 +39,9 @@ func (c *AuthCredential) NeedsRefresh() bool { } func authFilePath() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return filepath.Join(home, "auth.json") + } home, _ := os.UserHomeDir() return filepath.Join(home, ".picoclaw", 
"auth.json") } diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index cd6a2560f..1de910c83 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,12 +3,15 @@ package discord import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -40,6 +43,9 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } + if err := applyDiscordProxy(session, cfg.Proxy); err != nil { + return nil, err + } base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000), channels.WithGroupTrigger(cfg.GroupTrigger), @@ -465,9 +471,43 @@ func (c *DiscordChannel) StartTyping(ctx context.Context, chatID string) (func() func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", + ProxyURL: c.config.Proxy, }) } +func applyDiscordProxy(session *discordgo.Session, proxyAddr string) error { + var proxyFunc func(*http.Request) (*url.URL, error) + if proxyAddr != "" { + proxyURL, err := url.Parse(proxyAddr) + if err != nil { + return fmt.Errorf("invalid discord proxy URL %q: %w", proxyAddr, err) + } + proxyFunc = http.ProxyURL(proxyURL) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + proxyFunc = http.ProxyFromEnvironment + } + + if proxyFunc == nil { + return nil + } + + transport := &http.Transport{Proxy: proxyFunc} + session.Client = &http.Client{ + Timeout: sendTimeout, + Transport: transport, + } + + if session.Dialer != nil { + dialerCopy := *session.Dialer + dialerCopy.Proxy = proxyFunc + session.Dialer = &dialerCopy + } else { + session.Dialer = &websocket.Dialer{Proxy: proxyFunc} + } + + return 
nil +} + // stripBotMention removes the bot mention from the message content. // Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). func (c *DiscordChannel) stripBotMention(text string) string { diff --git a/pkg/channels/discord/discord_test.go b/pkg/channels/discord/discord_test.go new file mode 100644 index 000000000..0cd5328f4 --- /dev/null +++ b/pkg/channels/discord/discord_test.go @@ -0,0 +1,91 @@ +package discord + +import ( + "net/http" + "net/url" + "testing" + + "github.com/bwmarrin/discordgo" +) + +func TestApplyDiscordProxy_CustomProxy(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "http://127.0.0.1:7890"); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + restProxy := session.Client.Transport.(*http.Transport).Proxy + restProxyURL, err := restProxy(req) + if err != nil { + t.Fatalf("rest proxy func error: %v", err) + } + if got, want := restProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("REST proxy = %q, want %q", got, want) + } + + wsProxyURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + if got, want := wsProxyURL.String(), "http://127.0.0.1:7890"; got != want { + t.Fatalf("WS proxy = %q, want %q", got, want) + } +} + +func TestApplyDiscordProxy_FromEnvironment(t *testing.T) { + t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888") + t.Setenv("http_proxy", "http://127.0.0.1:8888") + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + t.Setenv("https_proxy", "http://127.0.0.1:8888") + t.Setenv("ALL_PROXY", "") + t.Setenv("all_proxy", "") + t.Setenv("NO_PROXY", "") + t.Setenv("no_proxy", "") + + session, err := discordgo.New("Bot test-token") + if err != nil { + 
t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, ""); err != nil { + t.Fatalf("applyDiscordProxy() error: %v", err) + } + + req, err := http.NewRequest("GET", "https://discord.com/api/v10/gateway", nil) + if err != nil { + t.Fatalf("http.NewRequest() error: %v", err) + } + + gotURL, err := session.Dialer.Proxy(req) + if err != nil { + t.Fatalf("ws proxy func error: %v", err) + } + + wantURL, err := url.Parse("http://127.0.0.1:8888") + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + if gotURL.String() != wantURL.String() { + t.Fatalf("WS proxy = %q, want %q", gotURL.String(), wantURL.String()) + } +} + +func TestApplyDiscordProxy_InvalidProxyURL(t *testing.T) { + session, err := discordgo.New("Bot test-token") + if err != nil { + t.Fatalf("discordgo.New() error: %v", err) + } + + if err = applyDiscordProxy(session, "://bad-proxy"); err == nil { + t.Fatal("applyDiscordProxy() expected error for invalid proxy URL, got nil") + } +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go index e8a057741..fbe085b73 100644 --- a/pkg/channels/feishu/common.go +++ b/pkg/channels/feishu/common.go @@ -1,5 +1,16 @@ package feishu +import ( + "encoding/json" + "regexp" + "strings" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions. +var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`) + // stringValue safely dereferences a *string pointer. func stringValue(v *string) string { if v == nil { @@ -7,3 +18,69 @@ func stringValue(v *string) string { } return *v } + +// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content. +// JSON 2.0 cards support full CommonMark standard markdown syntax. 
+func buildMarkdownCard(content string) (string, error) { + card := map[string]any{ + "schema": "2.0", + "body": map[string]any{ + "elements": []map[string]any{ + { + "tag": "markdown", + "content": content, + }, + }, + }, + } + data, err := json.Marshal(card) + if err != nil { + return "", err + } + return string(data), nil +} + +// extractJSONStringField unmarshals content as JSON and returns the value of the given string field. +// Returns "" if the content is invalid JSON or the field is missing/empty. +func extractJSONStringField(content, field string) string { + var m map[string]json.RawMessage + if err := json.Unmarshal([]byte(content), &m); err != nil { + return "" + } + raw, ok := m[field] + if !ok { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "" + } + return s +} + +// extractImageKey extracts the image_key from a Feishu image message content JSON. +// Format: {"image_key": "img_xxx"} +func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") } + +// extractFileKey extracts the file_key from a Feishu file/audio message content JSON. +// Format: {"file_key": "file_xxx", "file_name": "...", ...} +func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") } + +// extractFileName extracts the file_name from a Feishu file message content JSON. +func extractFileName(content string) string { return extractJSONStringField(content, "file_name") } + +// stripMentionPlaceholders removes @_user_N placeholders from the text content. +// These are inserted by Feishu when users @mention someone in a message. 
+func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string { + if len(mentions) == 0 { + return content + } + for _, m := range mentions { + if m.Key != nil && *m.Key != "" { + content = strings.ReplaceAll(content, *m.Key, "") + } + } + // Also clean up any remaining @_user_N patterns + content = mentionPlaceholderRegex.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/feishu/common_test.go b/pkg/channels/feishu/common_test.go new file mode 100644 index 000000000..fefc9f7c1 --- /dev/null +++ b/pkg/channels/feishu/common_test.go @@ -0,0 +1,292 @@ +package feishu + +import ( + "encoding/json" + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractJSONStringField(t *testing.T) { + tests := []struct { + name string + content string + field string + want string + }{ + { + name: "valid field", + content: `{"image_key": "img_v2_xxx"}`, + field: "image_key", + want: "img_v2_xxx", + }, + { + name: "missing field", + content: `{"image_key": "img_v2_xxx"}`, + field: "file_key", + want: "", + }, + { + name: "invalid JSON", + content: `not json at all`, + field: "image_key", + want: "", + }, + { + name: "empty content", + content: "", + field: "image_key", + want: "", + }, + { + name: "non-string field value", + content: `{"count": 42}`, + field: "count", + want: "", + }, + { + name: "empty string value", + content: `{"image_key": ""}`, + field: "image_key", + want: "", + }, + { + name: "multiple fields", + content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`, + field: "file_name", + want: "test.pdf", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractJSONStringField(tt.content, tt.field) + if got != tt.want { + t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want) + } + }) + } +} + +func TestExtractImageKey(t *testing.T) { + tests := []struct { + name string + content string 
+ want string + }{ + { + name: "normal", + content: `{"image_key": "img_v2_abc123"}`, + want: "img_v2_abc123", + }, + { + name: "missing key", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{broken`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractImageKey(tt.content) + if got != tt.want { + t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileKey(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`, + want: "file_v2_abc123", + }, + { + name: "missing key", + content: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `not json`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileKey(tt.content) + if got != tt.want { + t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileName(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + }, + { + name: "missing name", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{bad`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileName(tt.content) + if got != tt.want { + t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestBuildMarkdownCard(t *testing.T) { + tests := []struct { + name string + content string + }{ + { + name: "normal content", + content: "Hello **world**", + }, + { + name: "empty content", + content: "", + }, + { + name: "special characters", + content: `Code: "foo" & 'baz'`, + }, + } 
+ + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildMarkdownCard(tt.content) + if err != nil { + t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err) + } + + // Verify valid JSON + var parsed map[string]any + if err := json.Unmarshal([]byte(result), &parsed); err != nil { + t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err) + } + + // Verify schema + if parsed["schema"] != "2.0" { + t.Errorf("schema = %v, want %q", parsed["schema"], "2.0") + } + + // Verify body.elements[0].content == input + body, ok := parsed["body"].(map[string]any) + if !ok { + t.Fatal("missing body in card JSON") + } + elements, ok := body["elements"].([]any) + if !ok || len(elements) == 0 { + t.Fatal("missing or empty elements in card JSON") + } + elem, ok := elements[0].(map[string]any) + if !ok { + t.Fatal("first element is not an object") + } + if elem["tag"] != "markdown" { + t.Errorf("tag = %v, want %q", elem["tag"], "markdown") + } + if elem["content"] != tt.content { + t.Errorf("content = %v, want %q", elem["content"], tt.content) + } + }) + } +} + +func TestStripMentionPlaceholders(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + content string + mentions []*larkim.MentionEvent + want string + }{ + { + name: "no mentions", + content: "Hello world", + mentions: nil, + want: "Hello world", + }, + { + name: "single mention", + content: "@_user_1 hello", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + }, + want: "hello", + }, + { + name: "multiple mentions", + content: "@_user_1 @_user_2 hey", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + {Key: strPtr("@_user_2")}, + }, + want: "hey", + }, + { + name: "empty content", + content: "", + mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}}, + want: "", + }, + { + name: "empty mentions slice", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{}, 
+ want: "@_user_1 test", + }, + { + name: "mention with nil key", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{ + {Key: nil}, + }, + want: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripMentionPlaceholders(tt.content, tt.mentions) + if got != tt.want { + t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go index d0ec758c6..f5e3aa224 100644 --- a/pkg/channels/feishu/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -16,6 +16,8 @@ type FeishuChannel struct { *channels.BaseChannel } +var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures") + // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { return nil, errors.New( @@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan // Start is a stub method to satisfy the Channel interface func (c *FeishuChannel) Start(ctx context.Context) error { - return nil + return errUnsupported } // Stop is a stub method to satisfy the Channel interface func (c *FeishuChannel) Stop(ctx context.Context) error { - return nil + return errUnsupported } // Send is a stub method to satisfy the Channel interface func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - return errors.New("feishu channel is not supported on 32-bit architectures") + return errUnsupported +} + +// EditMessage is a stub method to satisfy MessageEditor +func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return errUnsupported +} + +// SendPlaceholder is a stub method to satisfy PlaceholderCapable +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) 
(string, error) { + return "", errUnsupported +} + +// ReactToMessage is a stub method to satisfy ReactionCapable +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + return func() {}, errUnsupported +} + +// SendMedia is a stub method to satisfy MediaSender +func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + return errUnsupported } diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 1db1bf669..00f73064d 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -6,10 +6,15 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "os" + "path/filepath" "sync" - "time" + "sync/atomic" lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" larkws "github.com/larksuite/oapi-sdk-go/v3/ws" @@ -19,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -28,6 +34,8 @@ type FeishuChannel struct { client *lark.Client wsClient *larkws.Client + botOpenID atomic.Value // stores string; populated lazily for @mention detection + mu sync.Mutex cancel context.CancelFunc } @@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) - return &FeishuChannel{ + ch := &FeishuChannel{ BaseChannel: base, config: cfg, client: lark.NewClient(cfg.AppID, cfg.AppSecret), - }, nil + } + ch.SetOwner(ch) + return ch, nil } func (c *FeishuChannel) Start(ctx context.Context) error { @@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error { return fmt.Errorf("feishu app_id or 
app_secret is empty") } + // Fetch bot open_id via API for reliable @mention detection. + if err := c.fetchBotOpenID(ctx); err != nil { + logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{ + "error": err.Error(), + }) + } + dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). OnP2MessageReceiveV1(c.handleMessageReceive) @@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { return nil } +// Send sends a message using Interactive Card format for markdown rendering. func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning } if msg.ChatID == "" { - return fmt.Errorf("chat ID is empty") + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) } - payload, err := json.Marshal(map[string]string{"text": msg.Content}) + // Build interactive card with markdown content + cardContent, err := buildMarkdownCard(msg.Content) if err != nil { - return fmt.Errorf("failed to marshal feishu content: %w", err) + return fmt.Errorf("feishu send: card build failed: %w", err) + } + return c.sendCard(ctx, msg.ChatID, cardContent) +} + +// EditMessage implements channels.MessageEditor. +// Uses Message.Patch to update an interactive card message. +func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + cardContent, err := buildMarkdownCard(content) + if err != nil { + return fmt.Errorf("feishu edit: card build failed: %w", err) + } + + req := larkim.NewPatchMessageReqBuilder(). + MessageId(messageID). + Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Patch(ctx, req) + if err != nil { + return fmt.Errorf("feishu edit: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// Sends an interactive card with placeholder text and returns its message ID. +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + if !c.config.Placeholder.Enabled { + logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{ + "chat_id": chatID, + }) + return "", nil + } + + text := c.config.Placeholder.Text + if text == "" { + text = "Thinking..." + } + + cardContent, err := buildMarkdownCard(text) + if err != nil { + return "", fmt.Errorf("feishu placeholder: card build failed: %w", err) } req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(larkim.ReceiveIdTypeChatId). Body(larkim.NewCreateMessageReqBodyBuilder(). - ReceiveId(msg.ChatID). - MsgType(larkim.MsgTypeText). - Content(string(payload)). - Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). + ReceiveId(chatID). + MsgType(larkim.MsgTypeInteractive). + Content(cardContent). Build()). Build() resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("feishu send: %w", channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder send: %w", err) } - if !resp.Success() { - return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg) } - logger.DebugCF("feishu", "Feishu message sent", map[string]any{ - "chat_id": msg.ChatID, - }) + if resp.Data != nil && resp.Data.MessageId != nil { + return *resp.Data.MessageId, nil + } + return "", nil +} + +// ReactToMessage implements channels.ReactionCapable. 
+// Adds an "Pin" reaction and returns an undo function to remove it. +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + req := larkim.NewCreateMessageReactionReqBuilder(). + MessageId(messageID). + Body(larkim.NewCreateMessageReactionReqBodyBuilder(). + ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()). + Build()). + Build() + + resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{ + "message_id": messageID, + "error": err.Error(), + }) + return func() {}, fmt.Errorf("feishu react: %w", err) + } + if !resp.Success() { + logger.ErrorCF("feishu", "Reaction API error", map[string]any{ + "message_id": messageID, + "code": resp.Code, + "msg": resp.Msg, + }) + return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + + var reactionID string + if resp.Data != nil && resp.Data.ReactionId != nil { + reactionID = *resp.Data.ReactionId + } + if reactionID == "" { + return func() {}, nil + } + + var undone atomic.Bool + undo := func() { + if !undone.CompareAndSwap(false, true) { + return + } + delReq := larkim.NewDeleteMessageReactionReqBuilder(). + MessageId(messageID). + ReactionId(reactionID). + Build() + _, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq) + } + return undo, nil +} + +// SendMedia implements channels.MediaSender. +// Uploads images/files via Feishu API then sends as messages. 
+func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + if msg.ChatID == "" { + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil { + return err + } + } return nil } +// sendMediaPart resolves and sends a single media part. +func (c *FeishuChannel) sendMediaPart( + ctx context.Context, + chatID string, + part bus.MediaPart, + store media.MediaStore, +) error { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + return nil // skip this part + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + return nil // skip this part + } + defer file.Close() + + switch part.Type { + case "image": + err = c.sendImage(ctx, chatID, file) + default: + filename := part.Filename + if filename == "" { + filename = "file" + } + err = c.sendFile(ctx, chatID, file, filename, part.Type) + } + + if err != nil { + logger.ErrorCF("feishu", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("feishu send media: %w", channels.ErrTemporary) + } + return nil +} + +// --- Inbound message handling --- + func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || event.Event.Message == nil { return nil @@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. 
senderID = "unknown" } - content := extractFeishuMessageContent(message) + messageType := stringValue(message.MessageType) + messageID := stringValue(message.MessageId) + rawContent := stringValue(message.Content) + + // Check allowlist early to avoid downloading media for rejected senders. + // BaseChannel.HandleMessage will check again, but this avoids wasted network I/O. + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + if !c.IsAllowedSender(senderInfo) { + return nil + } + + // Extract content based on message type + content := extractContent(messageType, rawContent) + + // Handle media messages (download and store) + var mediaRefs []string + if store := c.GetMediaStore(); store != nil && messageID != "" { + mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store) + } + + // Append media tags to content (like Telegram does) + content = appendMediaTags(content, messageType, mediaRefs) + if content == "" { content = "[empty message]" } metadata := map[string]string{} - messageID := "" - if mid := stringValue(message.MessageId); mid != "" { - messageID = mid + if messageID != "" { + metadata["message_id"] = messageID } - if messageType := stringValue(message.MessageType); messageType != "" { + if messageType != "" { metadata["message_type"] = messageType } - if chatType := stringValue(message.ChatType); chatType != "" { + chatType := stringValue(message.ChatType) + if chatType != "" { metadata["chat_type"] = chatType } if sender != nil && sender.TenantKey != nil { metadata["tenant_key"] = *sender.TenantKey } - chatType := stringValue(message.ChatType) var peer bus.Peer if chatType == "p2p" { peer = bus.Peer{Kind: "direct", ID: senderID} } else { peer = bus.Peer{Kind: "group", ID: chatID} + + // Check if bot was mentioned + isMentioned := c.isBotMentioned(message) + + // Strip mention placeholders from content before group trigger check 
+ if len(message.Mentions) > 0 { + content = stripMentionPlaceholders(content, message.Mentions) + } + // In group chats, apply unified group trigger filtering - respond, cleaned := c.ShouldRespondInGroup(false, content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) if !respond { return nil } @@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. } logger.InfoCF("feishu", "Feishu message received", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 80), + "sender_id": senderID, + "chat_id": chatID, + "message_id": messageID, + "preview": utils.Truncate(content, 80), }) - senderInfo := bus.SenderInfo{ - Platform: "feishu", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("feishu", senderID), + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo) + return nil +} + +// --- Internal helpers --- + +// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id. 
+func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error { + resp, err := c.client.Do(ctx, &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/bot/v3/info", + SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant}, + }) + if err != nil { + return fmt.Errorf("bot info request: %w", err) } - if !c.IsAllowedSender(senderInfo) { - return nil + var result struct { + Code int `json:"code"` + Bot struct { + OpenID string `json:"open_id"` + } `json:"bot"` + } + if err := json.Unmarshal(resp.RawBody, &result); err != nil { + return fmt.Errorf("bot info parse: %w", err) + } + if result.Code != 0 { + return fmt.Errorf("bot info api error (code=%d)", result.Code) + } + if result.Bot.OpenID == "" { + return fmt.Errorf("bot info: empty open_id") } - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) + c.botOpenID.Store(result.Bot.OpenID) + logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{ + "open_id": result.Bot.OpenID, + }) + return nil +} + +// isBotMentioned checks if the bot was @mentioned in the message. +func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool { + if message.Mentions == nil { + return false + } + + knownID, _ := c.botOpenID.Load().(string) + if knownID == "" { + logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil) + return false + } + + for _, m := range message.Mentions { + if m.Id == nil { + continue + } + if m.Id.OpenId != nil && *m.Id.OpenId == knownID { + return true + } + } + return false +} + +// extractContent extracts text content from different message types. 
+func extractContent(messageType, rawContent string) string { + if rawContent == "" { + return "" + } + + switch messageType { + case larkim.MsgTypeText: + var textPayload struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil { + return textPayload.Text + } + return rawContent + + case larkim.MsgTypePost: + // Pass raw JSON to LLM — structured rich text is more informative than flattened plain text + return rawContent + + case larkim.MsgTypeImage: + // Image messages don't have text content + return "" + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + // File/audio/video messages may have a filename + name := extractFileName(rawContent) + if name != "" { + return name + } + return "" + + default: + return rawContent + } +} + +// downloadInboundMedia downloads media from inbound messages and stores in MediaStore. +func (c *FeishuChannel) downloadInboundMedia( + ctx context.Context, + chatID, messageID, messageType, rawContent string, + store media.MediaStore, +) []string { + var refs []string + scope := channels.BuildMediaScope("feishu", chatID, messageID) + + switch messageType { + case larkim.MsgTypeImage: + imageKey := extractImageKey(rawContent) + if imageKey == "" { + return nil + } + ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope) + if ref != "" { + refs = append(refs, ref) + } + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + fileKey := extractFileKey(rawContent) + if fileKey == "" { + return nil + } + // Derive a fallback extension from the message type. 
+ var ext string + switch messageType { + case larkim.MsgTypeAudio: + ext = ".ogg" + case larkim.MsgTypeMedia: + ext = ".mp4" + default: + ext = "" // generic file — rely on resp.FileName + } + ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope) + if ref != "" { + refs = append(refs, ref) + } + } + + return refs +} + +// downloadResource downloads a message resource (image/file) from Feishu, +// writes it to the project media directory, and stores the reference in MediaStore. +// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension. +func (c *FeishuChannel) downloadResource( + ctx context.Context, + messageID, fileKey, resourceType, fallbackExt string, + store media.MediaStore, + scope string, +) string { + req := larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(fileKey). + Type(resourceType). + Build() + + resp, err := c.client.Im.V1.MessageResource.Get(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to download resource", map[string]any{ + "message_id": messageID, + "file_key": fileKey, + "error": err.Error(), + }) + return "" + } + if !resp.Success() { + logger.ErrorCF("feishu", "Resource download api error", map[string]any{ + "code": resp.Code, + "msg": resp.Msg, + }) + return "" + } + + if resp.File == nil { + return "" + } + // Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body). + if closer, ok := resp.File.(io.Closer); ok { + defer closer.Close() + } + + filename := resp.FileName + if filename == "" { + filename = fileKey + } + // If filename still has no extension, append the fallback (like Telegram's ext parameter). + if filepath.Ext(filename) == "" && fallbackExt != "" { + filename += fallbackExt + } + + // Write to the shared picoclaw_media directory using a unique name to avoid collisions. 
+ mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil { + logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{ + "error": mkdirErr.Error(), + }) + return "" + } + ext := filepath.Ext(filename) + localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext)) + + out, err := os.Create(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{ + "error": err.Error(), + }) + return "" + } + + if _, copyErr := io.Copy(out, resp.File); copyErr != nil { + out.Close() + os.Remove(localPath) + logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{ + "error": copyErr.Error(), + }) + return "" + } + out.Close() + + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "feishu", + }, scope) + if err != nil { + logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{ + "file_key": fileKey, + "error": err.Error(), + }) + os.Remove(localPath) + return "" + } + + return ref +} + +// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]"). +func appendMediaTags(content, messageType string, mediaRefs []string) string { + if len(mediaRefs) == 0 { + return content + } + + var tag string + switch messageType { + case larkim.MsgTypeImage: + tag = "[image: photo]" + case larkim.MsgTypeAudio: + tag = "[audio]" + case larkim.MsgTypeMedia: + tag = "[video]" + case larkim.MsgTypeFile: + tag = "[file]" + default: + tag = "[attachment]" + } + + if content == "" { + return tag + } + return content + " " + tag +} + +// sendCard sends an interactive card message to a chat. +func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error { + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). 
+ MsgType(larkim.MsgTypeInteractive). + Content(cardContent). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu send card: %w", channels.ErrTemporary) + } + + if !resp.Success() { + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + } + + logger.DebugCF("feishu", "Feishu card message sent", map[string]any{ + "chat_id": chatID, + }) + + return nil +} + +// sendImage uploads an image and sends it as a message. +func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error { + // Upload image to get image_key + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType("message"). + Image(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu image upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil { + return fmt.Errorf("feishu image upload: no image_key returned") + } + + imageKey := *uploadResp.Data.ImageKey + + // Send image message + content, _ := json.Marshal(map[string]string{"image_key": imageKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeImage). + Content(string(content)). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu image send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// sendFile uploads a file and sends it as a message. 
+func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error { + // Map part type to Feishu file type + feishuFileType := "stream" + switch fileType { + case "audio": + feishuFileType = "opus" + case "video": + feishuFileType = "mp4" + } + + // Upload file to get file_key + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(feishuFileType). + FileName(filename). + File(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu file upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.FileKey == nil { + return fmt.Errorf("feishu file upload: no file_key returned") + } + + fileKey := *uploadResp.Data.FileKey + + // Send file message + content, _ := json.Marshal(map[string]string{"file_key": fileKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeFile). + Content(string(content)). + Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu file send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } return nil } @@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string { return "" } - -func extractFeishuMessageContent(message *larkim.EventMessage) string { - if message == nil || message.Content == nil || *message.Content == "" { - return "" - } - - if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { - var textPayload struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { - return textPayload.Text - } - } - - return *message.Content -} diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go new file mode 100644 index 000000000..dc3eab2e7 --- /dev/null +++ b/pkg/channels/feishu/feishu_64_test.go @@ -0,0 +1,256 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + +package feishu + +import ( + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractContent(t *testing.T) { + tests := []struct { + name string + messageType string + rawContent string + want string + }{ + { + name: "text message", + messageType: "text", + rawContent: `{"text": "hello world"}`, + want: "hello world", + }, + { + name: "text message invalid JSON", + messageType: "text", + rawContent: `not json`, + want: "not json", + }, + { + name: "post message returns raw JSON", + messageType: "post", + rawContent: `{"title": "test post"}`, + want: `{"title": "test post"}`, + }, + { + name: "image message returns empty", + messageType: "image", + rawContent: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "file message with filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + 
}, + { + name: "file message without filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "audio message with filename", + messageType: "audio", + rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`, + want: "recording.ogg", + }, + { + name: "media message with filename", + messageType: "media", + rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`, + want: "video.mp4", + }, + { + name: "unknown message type returns raw", + messageType: "sticker", + rawContent: `{"sticker_id": "sticker_xxx"}`, + want: `{"sticker_id": "sticker_xxx"}`, + }, + { + name: "empty raw content", + messageType: "text", + rawContent: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractContent(tt.messageType, tt.rawContent) + if got != tt.want { + t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want) + } + }) + } +} + +func TestAppendMediaTags(t *testing.T) { + tests := []struct { + name string + content string + messageType string + mediaRefs []string + want string + }{ + { + name: "no refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: nil, + want: "hello", + }, + { + name: "empty refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: []string{}, + want: "hello", + }, + { + name: "image with content", + content: "check this", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "check this [image: photo]", + }, + { + name: "image empty content", + content: "", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "[image: photo]", + }, + { + name: "audio", + content: "listen", + messageType: "audio", + mediaRefs: []string{"ref1"}, + want: "listen [audio]", + }, + { + name: "media/video", + content: "watch", + messageType: "media", + mediaRefs: []string{"ref1"}, + want: "watch [video]", + }, + { + name: "file", + content: 
"report.pdf", + messageType: "file", + mediaRefs: []string{"ref1"}, + want: "report.pdf [file]", + }, + { + name: "unknown type", + content: "something", + messageType: "sticker", + mediaRefs: []string{"ref1"}, + want: "something [attachment]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs) + if got != tt.want { + t.Errorf( + "appendMediaTags(%q, %q, %v) = %q, want %q", + tt.content, + tt.messageType, + tt.mediaRefs, + got, + tt.want, + ) + } + }) + } +} + +func TestExtractFeishuSenderID(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + sender *larkim.EventSender + want string + }{ + { + name: "nil sender", + sender: nil, + want: "", + }, + { + name: "nil sender ID", + sender: &larkim.EventSender{SenderId: nil}, + want: "", + }, + { + name: "userId preferred", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr("u_abc123"), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "u_abc123", + }, + { + name: "openId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "ou_def456", + }, + { + name: "unionId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "on_ghi789", + }, + { + name: "all empty strings", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr(""), + }, + }, + want: "", + }, + { + name: "nil userId pointer falls through", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: nil, + OpenId: strPtr("ou_def456"), + UnionId: nil, + }, + }, + want: "ou_def456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := 
extractFeishuSenderID(tt.sender) + if got != tt.want { + t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 398f12e6b..b36350a06 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -654,7 +654,10 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("reading LINE API error response: %w", err)) + } return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 7feb706aa..f328f32b8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann })) } + if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" { + opts = append(opts, telego.WithAPIServer(baseURL)) + } + bot, err := telego.NewBot(telegramCfg.Token, opts...) 
if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) diff --git a/pkg/channels/wecom/aibot.go b/pkg/channels/wecom/aibot.go index 6c5aca40b..93fe8c36d 100644 --- a/pkg/channels/wecom/aibot.go +++ b/pkg/channels/wecom/aibot.go @@ -793,7 +793,10 @@ func (c *WeComAIBotChannel) sendViaResponseURL(responseURL, content string) erro return nil } - respBody, _ := io.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("reading response_url body: %w: %w", channels.ErrTemporary, err) + } switch { case resp.StatusCode == http.StatusTooManyRequests: return fmt.Errorf("response_url rate limited (%d): %s: %w", diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 717815b9f..2098fcd4e 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -321,8 +321,17 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody))) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return "", channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading wecom upload error response: %w", readErr), + ) + } + return "", channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("wecom upload error: %s", string(respBody)), + ) } var result struct { @@ -371,8 +380,17 @@ func (c *WeComAppChannel) sendWeComMessage(ctx context.Context, accessToken stri defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody))) + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading wecom_app error response: %w", 
readErr), + ) + } + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("wecom_app API error: %s", string(respBody)), + ) } respBody, err := io.ReadAll(resp.Body) diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 9126a847d..96d5a961f 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -453,8 +453,17 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("reading webhook error response: %w", readErr), + ) + } + return channels.ClassifySendError( + resp.StatusCode, + fmt.Errorf("webhook API error: %s", string(body)), + ) } body, err := io.ReadAll(resp.Body) diff --git a/pkg/config/config.go b/pkg/config/config.go index 75b7539cb..23dca8cb8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -192,9 +192,21 @@ type AgentDefaults struct { MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` Routing *RoutingConfig `json:"routing,omitempty"` } +const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB + +func (d *AgentDefaults) GetMaxMediaSize() int { + if d.MaxMediaSize > 0 { + return 
d.MaxMediaSize + } + return DefaultMaxMediaSize +} + // GetModelName returns the effective model name for the agent defaults. // It prefers the new "model_name" field but falls back to "model" for backward compatibility. func (d *AgentDefaults) GetModelName() string { @@ -250,6 +262,7 @@ type WhatsAppConfig struct { type TelegramConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"` Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -266,12 +279,14 @@ type FeishuConfig struct { VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } type DiscordConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_DISCORD_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -429,6 +444,7 @@ type ProvidersConfig struct { Antigravity ProviderConfig `json:"antigravity"` Qwen ProviderConfig `json:"qwen"` Mistral ProviderConfig `json:"mistral"` + Avian ProviderConfig `json:"avian"` } // IsEmpty checks if all provider configs 
are empty (no API keys or API bases set) @@ -452,7 +468,8 @@ func (p ProvidersConfig) IsEmpty() bool { p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" && p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" && p.Qwen.APIKey == "" && p.Qwen.APIBase == "" && - p.Mistral.APIKey == "" && p.Mistral.APIBase == "" + p.Mistral.APIKey == "" && p.Mistral.APIBase == "" && + p.Avian.APIKey == "" && p.Avian.APIBase == "" } // MarshalJSON implements custom JSON marshaling for ProvidersConfig @@ -503,6 +520,7 @@ type ModelConfig struct { RPM int `json:"rpm,omitempty"` // Requests per minute limit MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens") RequestTimeout int `json:"request_timeout,omitempty"` + ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive } // Validate checks if the ModelConfig has all required fields. @@ -521,6 +539,10 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } +type ToolConfig struct { + Enabled bool `json:"enabled" env:"ENABLED"` +} + type BraveConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` @@ -545,11 +567,30 @@ type PerplexityConfig struct { MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"` } +type SearXNGConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_SEARXNG_ENABLED"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_SEARXNG_BASE_URL"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARXNG_MAX_RESULTS"` +} + +type GLMSearchConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_GLM_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_GLM_API_KEY"` + BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_GLM_BASE_URL"` + // SearchEngine specifies the search 
backend: "search_std" (default), + // "search_pro", "search_pro_sogou", or "search_pro_quark". + SearchEngine string `json:"search_engine" env:"PICOCLAW_TOOLS_WEB_GLM_SEARCH_ENGINE"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_GLM_MAX_RESULTS"` +} + type WebToolsConfig struct { - Brave BraveConfig `json:"brave"` - Tavily TavilyConfig `json:"tavily"` - DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` - Perplexity PerplexityConfig `json:"perplexity"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_WEB_"` + Brave BraveConfig ` json:"brave"` + Tavily TavilyConfig ` json:"tavily"` + DuckDuckGo DuckDuckGoConfig ` json:"duckduckgo"` + Perplexity PerplexityConfig ` json:"perplexity"` + SearXNG SearXNGConfig ` json:"searxng"` + GLMSearch GLMSearchConfig ` json:"glm_search"` // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h). // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config. Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"` @@ -557,19 +598,28 @@ type WebToolsConfig struct { } type CronToolsConfig struct { - ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"` + ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout } type ExecConfig struct { - EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` - CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` - CustomAllowPatterns []string `json:"custom_allow_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS"` + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"` + EnableDenyPatterns bool ` env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` + CustomDenyPatterns []string ` 
env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` + CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` +} + +type SkillsToolsConfig struct { + ToolConfig ` envPrefix:"PICOCLAW_TOOLS_SKILLS_"` + Registries SkillsRegistriesConfig ` json:"registries"` + MaxConcurrentSearches int ` json:"max_concurrent_searches" env:"PICOCLAW_TOOLS_SKILLS_MAX_CONCURRENT_SEARCHES"` + SearchCache SearchCacheConfig ` json:"search_cache"` } type MediaCleanupConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_MEDIA_CLEANUP_ENABLED"` - MaxAge int `json:"max_age_minutes" env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE"` - Interval int `json:"interval_minutes" env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL"` + ToolConfig ` envPrefix:"PICOCLAW_MEDIA_CLEANUP_"` + MaxAge int ` env:"PICOCLAW_MEDIA_CLEANUP_MAX_AGE" json:"max_age_minutes"` + Interval int ` env:"PICOCLAW_MEDIA_CLEANUP_INTERVAL" json:"interval_minutes"` } type ToolsConfig struct { @@ -581,12 +631,19 @@ type ToolsConfig struct { Skills SkillsToolsConfig `json:"skills"` MediaCleanup MediaCleanupConfig `json:"media_cleanup"` MCP MCPConfig `json:"mcp"` -} - -type SkillsToolsConfig struct { - Registries SkillsRegistriesConfig `json:"registries"` - MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"` - SearchCache SearchCacheConfig `json:"search_cache"` + AppendFile ToolConfig `json:"append_file" envPrefix:"PICOCLAW_TOOLS_APPEND_FILE_"` + EditFile ToolConfig `json:"edit_file" envPrefix:"PICOCLAW_TOOLS_EDIT_FILE_"` + FindSkills ToolConfig `json:"find_skills" envPrefix:"PICOCLAW_TOOLS_FIND_SKILLS_"` + I2C ToolConfig `json:"i2c" envPrefix:"PICOCLAW_TOOLS_I2C_"` + InstallSkill ToolConfig `json:"install_skill" envPrefix:"PICOCLAW_TOOLS_INSTALL_SKILL_"` + ListDir ToolConfig `json:"list_dir" envPrefix:"PICOCLAW_TOOLS_LIST_DIR_"` + Message ToolConfig `json:"message" envPrefix:"PICOCLAW_TOOLS_MESSAGE_"` + ReadFile ToolConfig 
`json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"` + Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"` + SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"` + Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"` + WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"` + WriteFile ToolConfig `json:"write_file" envPrefix:"PICOCLAW_TOOLS_WRITE_FILE_"` } type SearchCacheConfig struct { @@ -632,8 +689,7 @@ type MCPServerConfig struct { // MCPConfig defines configuration for all MCP servers type MCPConfig struct { - // Enabled globally enables/disables MCP integration - Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + ToolConfig `envPrefix:"PICOCLAW_TOOLS_MCP_"` // Servers is a map of server name to server configuration Servers map[string]MCPServerConfig `json:"servers,omitempty"` } @@ -819,3 +875,48 @@ func (c *Config) ValidateModelList() error { } return nil } + +func (t *ToolsConfig) IsToolEnabled(name string) bool { + switch name { + case "web": + return t.Web.Enabled + case "cron": + return t.Cron.Enabled + case "exec": + return t.Exec.Enabled + case "skills": + return t.Skills.Enabled + case "media_cleanup": + return t.MediaCleanup.Enabled + case "append_file": + return t.AppendFile.Enabled + case "edit_file": + return t.EditFile.Enabled + case "find_skills": + return t.FindSkills.Enabled + case "i2c": + return t.I2C.Enabled + case "install_skill": + return t.InstallSkill.Enabled + case "list_dir": + return t.ListDir.Enabled + case "message": + return t.Message.Enabled + case "read_file": + return t.ReadFile.Enabled + case "spawn": + return t.Spawn.Enabled + case "spi": + return t.SPI.Enabled + case "subagent": + return t.Subagent.Enabled + case "web_fetch": + return t.WebFetch.Enabled + case "write_file": + return t.WriteFile.Enabled + case "mcp": + return t.MCP.Enabled + default: + return true + } +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 
6af7c209e..10ebc7c90 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -435,6 +435,18 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) { } // TestDefaultConfig_DMScope verifies the default dm_scope value +// TestDefaultConfig_SummarizationThresholds verifies summarization defaults +func TestDefaultConfig_SummarizationThresholds(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.SummarizeMessageThreshold != 20 { + t.Errorf("SummarizeMessageThreshold = %d, want 20", cfg.Agents.Defaults.SummarizeMessageThreshold) + } + if cfg.Agents.Defaults.SummarizeTokenPercent != 75 { + t.Errorf("SummarizeTokenPercent = %d, want 75", cfg.Agents.Defaults.SummarizeTokenPercent) + } +} + func TestDefaultConfig_DMScope(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 9fc09c5f1..e87d7aa0a 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -26,13 +26,15 @@ func DefaultConfig() *Config { return &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ - Workspace: workspacePath, - RestrictToWorkspace: true, - Provider: "", - Model: "", - MaxTokens: 32768, - Temperature: nil, // nil means use provider default - MaxToolIterations: 50, + Workspace: workspacePath, + RestrictToWorkspace: true, + Provider: "", + Model: "", + MaxTokens: 32768, + Temperature: nil, // nil means use provider default + MaxToolIterations: 50, + SummarizeMessageThreshold: 20, + SummarizeTokenPercent: 75, }, }, Bindings: []AgentBinding{}, @@ -306,6 +308,20 @@ func DefaultConfig() *Config { APIKey: "", }, + // Avian - https://avian.io + { + ModelName: "deepseek-v3.2", + Model: "avian/deepseek/deepseek-v3.2", + APIBase: "https://api.avian.io/v1", + APIKey: "", + }, + { + ModelName: "kimi-k2.5", + Model: "avian/moonshotai/kimi-k2.5", + APIBase: "https://api.avian.io/v1", + APIKey: "", + }, + // VLLM (local) - http://localhost:8000 { ModelName: "local-model", @@ -320,11 +336,16 @@ func DefaultConfig() *Config 
{ }, Tools: ToolsConfig{ MediaCleanup: MediaCleanupConfig{ - Enabled: true, + ToolConfig: ToolConfig{ + Enabled: true, + }, MaxAge: 30, Interval: 5, }, Web: WebToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Proxy: "", FetchLimitBytes: 10 * 1024 * 1024, // 10MB by default Brave: BraveConfig{ @@ -341,14 +362,35 @@ func DefaultConfig() *Config { APIKey: "", MaxResults: 5, }, + SearXNG: SearXNGConfig{ + Enabled: false, + BaseURL: "", + MaxResults: 5, + }, + GLMSearch: GLMSearchConfig{ + Enabled: false, + APIKey: "", + BaseURL: "https://open.bigmodel.cn/api/paas/v4/web_search", + SearchEngine: "search_std", + MaxResults: 5, + }, }, Cron: CronToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, ExecTimeoutMinutes: 5, }, Exec: ExecConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, EnableDenyPatterns: true, }, Skills: SkillsToolsConfig{ + ToolConfig: ToolConfig{ + Enabled: true, + }, Registries: SkillsRegistriesConfig{ ClawHub: ClawHubRegistryConfig{ Enabled: true, @@ -362,9 +404,50 @@ func DefaultConfig() *Config { }, }, MCP: MCPConfig{ - Enabled: false, + ToolConfig: ToolConfig{ + Enabled: false, + }, Servers: map[string]MCPServerConfig{}, }, + AppendFile: ToolConfig{ + Enabled: true, + }, + EditFile: ToolConfig{ + Enabled: true, + }, + FindSkills: ToolConfig{ + Enabled: true, + }, + I2C: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + InstallSkill: ToolConfig{ + Enabled: true, + }, + ListDir: ToolConfig{ + Enabled: true, + }, + Message: ToolConfig{ + Enabled: true, + }, + ReadFile: ToolConfig{ + Enabled: true, + }, + Spawn: ToolConfig{ + Enabled: true, + }, + SPI: ToolConfig{ + Enabled: false, // Hardware tool - Linux only + }, + Subagent: ToolConfig{ + Enabled: true, + }, + WebFetch: ToolConfig{ + Enabled: true, + }, + WriteFile: ToolConfig{ + Enabled: true, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 772f714fd..4a17dd6c9 100644 --- 
a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -373,6 +373,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"avian"}, + protocol: "avian", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.Avian.APIKey == "" && p.Avian.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "avian", + Model: "avian/deepseek/deepseek-v3.2", + APIKey: p.Avian.APIKey, + APIBase: p.Avian.APIBase, + Proxy: p.Avian.Proxy, + RequestTimeout: p.Avian.RequestTimeout, + }, true + }, + }, } // Process each provider migration diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index e24e9fa1d..67ad73db9 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -160,14 +160,15 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { Antigravity: ProviderConfig{AuthMethod: "oauth"}, Qwen: ProviderConfig{APIKey: "key17"}, Mistral: ProviderConfig{APIKey: "key18"}, + Avian: ProviderConfig{APIKey: "key19"}, }, } result := ConvertProvidersToModelList(cfg) - // All 19 providers should be converted - if len(result) != 19 { - t.Errorf("len(result) = %d, want 19", len(result)) + // All 20 providers should be converted + if len(result) != 20 { + t.Errorf("len(result) = %d, want 20", len(result)) } } diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go index 8b6d6d9aa..7b63cc979 100644 --- a/pkg/mcp/manager.go +++ b/pkg/mcp/manager.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -108,7 +109,7 @@ type ServerConnection struct { type Manager struct { servers map[string]*ServerConnection mu sync.RWMutex - closed bool + closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race wg sync.WaitGroup // tracks in-flight CallTool calls } @@ -440,14 +441,20 @@ func (m *Manager) CallTool( serverName, toolName string, arguments map[string]any, ) 
(*mcp.CallToolResult, error) { + // Check if closed before acquiring lock (fast path) + if m.closed.Load() { + return nil, fmt.Errorf("manager is closed") + } + m.mu.RLock() - if m.closed { + // Double-check after acquiring lock to prevent TOCTOU race + if m.closed.Load() { m.mu.RUnlock() return nil, fmt.Errorf("manager is closed") } conn, ok := m.servers[serverName] if ok { - m.wg.Add(1) + m.wg.Add(1) // Add to WaitGroup while holding the lock } m.mu.RUnlock() @@ -471,15 +478,14 @@ func (m *Manager) CallTool( // Close closes all server connections func (m *Manager) Close() error { - m.mu.Lock() - if m.closed { - m.mu.Unlock() - return nil + // Use Swap to atomically set closed=true and get the previous value + // This prevents TOCTOU race with CallTool's closed check + if m.closed.Swap(true) { + return nil // already closed } - m.closed = true - m.mu.Unlock() // Wait for all in-flight CallTool calls to finish before closing sessions + // After closed=true is set, no new CallTool can start (they check closed first) m.wg.Wait() m.mu.Lock() diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index 6dd71a3c2..f353942ab 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -194,7 +194,9 @@ func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { mgr := NewManager() mcpCfg := config.MCPConfig{ - Enabled: true, + ToolConfig: config.ToolConfig{ + Enabled: true, + }, Servers: map[string]config.MCPServerConfig{ "test-server": { Enabled: true, @@ -228,12 +230,20 @@ func TestNewManager_InitialState(t *testing.T) { func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { mgr := NewManager() - err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp") + err := mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: false}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error when MCP disabled, got: %v", err) } - err = 
mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp") + err = mgr.LoadFromMCPConfig( + context.Background(), + config.MCPConfig{ToolConfig: config.ToolConfig{Enabled: true}}, + "/tmp", + ) if err != nil { t.Fatalf("expected nil error when no servers configured, got: %v", err) } @@ -268,7 +278,7 @@ func TestGetAllTools_FiltersEmptyTools(t *testing.T) { func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { t.Run("manager closed", func(t *testing.T) { mgr := NewManager() - mgr.closed = true + mgr.closed.Store(true) _, err := mgr.CallTool(context.Background(), "s1", "tool", nil) if err == nil || !strings.Contains(err.Error(), "manager is closed") { diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go new file mode 100644 index 000000000..e12e2c5ab --- /dev/null +++ b/pkg/memory/jsonl.go @@ -0,0 +1,460 @@ +package memory + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "hash/fnv" + "log" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/fileutil" + "github.com/sipeed/picoclaw/pkg/providers" +) + +const ( + // numLockShards is the fixed number of mutexes used to serialize + // per-session access. Using a sharded array instead of a map keeps + // memory bounded regardless of how many sessions are created over + // the lifetime of the process — important for a long-running daemon. + numLockShards = 64 + + // maxLineSize is the maximum size of a single JSON line in a .jsonl + // file. Tool results (read_file, web search, etc.) can be large, so + // we set a generous limit. The scanner starts at 64 KB and grows + // only as needed up to this cap. + maxLineSize = 10 * 1024 * 1024 // 10 MB +) + +// sessionMeta holds per-session metadata stored in a .meta.json file. 
+type sessionMeta struct { + Key string `json:"key"` + Summary string `json:"summary"` + Skip int `json:"skip"` + Count int `json:"count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// JSONLStore implements Store using append-only JSONL files. +// +// Each session is stored as two files: +// +// {sanitized_key}.jsonl — one JSON-encoded message per line, append-only +// {sanitized_key}.meta.json — session metadata (summary, logical truncation offset) +// +// Messages are never physically deleted from the JSONL file. Instead, +// TruncateHistory records a "skip" offset in the metadata file and +// GetHistory ignores lines before that offset. This keeps all writes +// append-only, which is both fast and crash-safe. +type JSONLStore struct { + dir string + locks [numLockShards]sync.Mutex +} + +// NewJSONLStore creates a new JSONL-backed store rooted at dir. +func NewJSONLStore(dir string) (*JSONLStore, error) { + err := os.MkdirAll(dir, 0o755) + if err != nil { + return nil, fmt.Errorf("memory: create directory: %w", err) + } + return &JSONLStore{dir: dir}, nil +} + +// sessionLock returns a mutex for the given session key. +// Keys are mapped to a fixed pool of shards via FNV hash, so +// memory usage is O(1) regardless of total session count. +func (s *JSONLStore) sessionLock(key string) *sync.Mutex { + h := fnv.New32a() + h.Write([]byte(key)) + return &s.locks[h.Sum32()%numLockShards] +} + +func (s *JSONLStore) jsonlPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".jsonl") +} + +func (s *JSONLStore) metaPath(key string) string { + return filepath.Join(s.dir, sanitizeKey(key)+".meta.json") +} + +// sanitizeKey converts a session key to a safe filename component. +// Mirrors pkg/session.sanitizeFilename so that migration paths match. +// +// Note: this is a lossy mapping — "telegram:123" and "telegram_123" +// both produce the same filename. 
This is an intentional tradeoff: +// keys with colons (e.g. from channels) are by far the common case, +// and a bidirectional encoding (like URL-encoding) would complicate +// file listings and debugging. +func sanitizeKey(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + +// readMeta loads the metadata file for a session. +// Returns a zero-value sessionMeta if the file does not exist. +func (s *JSONLStore) readMeta(key string) (sessionMeta, error) { + data, err := os.ReadFile(s.metaPath(key)) + if os.IsNotExist(err) { + return sessionMeta{Key: key}, nil + } + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err) + } + var meta sessionMeta + err = json.Unmarshal(data, &meta) + if err != nil { + return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err) + } + return meta, nil +} + +// writeMeta atomically writes the metadata file using the project's +// standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error { + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return fmt.Errorf("memory: encode meta: %w", err) + } + return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644) +} + +// readMessages reads valid JSON lines from a .jsonl file, skipping +// the first `skip` lines without unmarshaling them. This avoids the +// cost of json.Unmarshal on logically truncated messages. +// Malformed trailing lines (e.g. from a crash) are silently skipped. +func readMessages(path string, skip int) ([]providers.Message, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return []providers.Message{}, nil + } + if err != nil { + return nil, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + var msgs []providers.Message + scanner := bufio.NewScanner(f) + // Allow large lines for tool results (read_file, web search, etc.). 
+ scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + lineNum := 0 + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + lineNum++ + if lineNum <= skip { + continue + } + var msg providers.Message + if err := json.Unmarshal(line, &msg); err != nil { + // Corrupt line — likely a partial write from a crash. + // Log so operators know data was skipped, but don't + // fail the entire read; this is the standard JSONL + // recovery pattern. + log.Printf("memory: skipping corrupt line %d in %s: %v", + lineNum, filepath.Base(path), err) + continue + } + msgs = append(msgs, msg) + } + if scanner.Err() != nil { + return nil, fmt.Errorf("memory: scan jsonl: %w", scanner.Err()) + } + + if msgs == nil { + msgs = []providers.Message{} + } + return msgs, nil +} + +// countLines counts the total number of non-empty lines in a .jsonl file. +// Used by TruncateHistory to reconcile a stale meta.Count without +// the overhead of unmarshaling every message. +func countLines(path string) (int, error) { + f, err := os.Open(path) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: open jsonl: %w", err) + } + defer f.Close() + + n := 0 + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + for scanner.Scan() { + if len(scanner.Bytes()) > 0 { + n++ + } + } + return n, scanner.Err() +} + +func (s *JSONLStore) AddMessage( + _ context.Context, sessionKey, role, content string, +) error { + return s.addMsg(sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +func (s *JSONLStore) AddFullMessage( + _ context.Context, sessionKey string, msg providers.Message, +) error { + return s.addMsg(sessionKey, msg) +} + +// addMsg is the shared implementation for AddMessage and AddFullMessage. 
+func (s *JSONLStore) addMsg(sessionKey string, msg providers.Message) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + // Append the message as a single JSON line. + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message: %w", err) + } + line = append(line, '\n') + + f, err := os.OpenFile( + s.jsonlPath(sessionKey), + os.O_CREATE|os.O_WRONLY|os.O_APPEND, + 0o644, + ) + if err != nil { + return fmt.Errorf("memory: open jsonl for append: %w", err) + } + _, writeErr := f.Write(line) + if writeErr != nil { + f.Close() + return fmt.Errorf("memory: append message: %w", writeErr) + } + // Flush to physical storage before closing. This matches the + // durability guarantee of writeMeta and rewriteJSONL (which use + // WriteFileAtomic with fsync). Without Sync, a power loss could + // leave the append in the kernel page cache only — lost on reboot. + if syncErr := f.Sync(); syncErr != nil { + f.Close() + return fmt.Errorf("memory: sync jsonl: %w", syncErr) + } + if closeErr := f.Close(); closeErr != nil { + return fmt.Errorf("memory: close jsonl: %w", closeErr) + } + + // Update metadata. + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.Count == 0 && meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Count++ + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) GetHistory( + _ context.Context, sessionKey string, +) ([]providers.Message, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return nil, err + } + + // Pass meta.Skip so readMessages skips those lines without + // unmarshaling them — avoids wasted CPU on truncated messages. 
+ msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return nil, err + } + + return msgs, nil +} + +func (s *JSONLStore) GetSummary( + _ context.Context, sessionKey string, +) (string, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return "", err + } + return meta.Summary, nil +} + +func (s *JSONLStore) SetSummary( + _ context.Context, sessionKey, summary string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Summary = summary + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) TruncateHistory( + _ context.Context, sessionKey string, keepLast int, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + + // Always reconcile meta.Count with the actual line count on disk. + // A crash between the JSONL append and the meta update in addMsg + // leaves meta.Count stale (e.g. file has 101 lines but meta says + // 100). Counting lines is cheap — no unmarshal, just a scan — and + // TruncateHistory is not a hot path, so always re-count. 
+ n, countErr := countLines(s.jsonlPath(sessionKey)) + if countErr != nil { + return countErr + } + meta.Count = n + + if keepLast <= 0 { + meta.Skip = meta.Count + } else { + effective := meta.Count - meta.Skip + if keepLast < effective { + meta.Skip = meta.Count - keepLast + } + } + meta.UpdatedAt = time.Now() + + return s.writeMeta(sessionKey, meta) +} + +func (s *JSONLStore) SetHistory( + _ context.Context, + sessionKey string, + history []providers.Message, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.Skip = 0 + meta.Count = len(history) + meta.UpdatedAt = now + + // Write meta BEFORE rewriting the JSONL file. If we crash between + // the two writes, meta has Skip=0 and the old file is still intact, + // so GetHistory reads from line 1 — returning "too many" messages + // rather than losing data. The next SetHistory call corrects this. + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, history) +} + +// Compact physically rewrites the JSONL file, dropping all logically +// skipped lines. This reclaims disk space that accumulates after +// repeated TruncateHistory calls. +// +// It is safe to call at any time; if there is nothing to compact +// (skip == 0) the method returns immediately. +func (s *JSONLStore) Compact( + _ context.Context, sessionKey string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + if meta.Skip == 0 { + return nil + } + + // Read only the active messages, skipping truncated lines + // without unmarshaling them. + active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return err + } + + // Write meta BEFORE rewriting the JSONL file. 
If the process + // crashes between the two writes, meta has Skip=0 and the old + // (uncompacted) file is still intact, so GetHistory reads from + // line 1 — returning previously-truncated messages rather than + // losing data. The next Compact or TruncateHistory corrects this. + meta.Skip = 0 + meta.Count = len(active) + meta.UpdatedAt = time.Now() + + err = s.writeMeta(sessionKey, meta) + if err != nil { + return err + } + + return s.rewriteJSONL(sessionKey, active) +} + +// rewriteJSONL atomically replaces the JSONL file with the given messages +// using the project's standard WriteFileAtomic (temp + fsync + rename). +func (s *JSONLStore) rewriteJSONL( + sessionKey string, msgs []providers.Message, +) error { + var buf bytes.Buffer + for i, msg := range msgs { + line, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("memory: marshal message %d: %w", i, err) + } + buf.Write(line) + buf.WriteByte('\n') + } + return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644) +} + +func (s *JSONLStore) Close() error { + return nil +} diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go new file mode 100644 index 000000000..356ff14ff --- /dev/null +++ b/pkg/memory/jsonl_test.go @@ -0,0 +1,835 @@ +package memory + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func newTestStore(t *testing.T) *JSONLStore { + t.Helper() + store, err := NewJSONLStore(t.TempDir()) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + return store +} + +func TestNewJSONLStore_CreatesDirectory(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested", "sessions") + store, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + + info, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if !info.IsDir() { + t.Errorf("expected directory, got file") + } +} + +func TestAddMessage_BasicRoundtrip(t 
*testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "hello") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s1", "assistant", "hi there") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Role != "user" || history[0].Content != "hello" { + t.Errorf("msg[0] = %+v", history[0]) + } + if history[1].Role != "assistant" || history[1].Content != "hi there" { + t.Errorf("msg[1] = %+v", history[1]) + } +} + +func TestAddMessage_AutoCreatesSession(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Adding a message to a non-existent session should work. + err := store.AddMessage(ctx, "new-session", "user", "first message") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "new-session") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } +} + +func TestAddFullMessage_WithToolCalls(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "assistant", + Content: "Let me search that.", + ToolCalls: []providers.ToolCall{ + { + ID: "call_abc", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"golang jsonl"}`, + }, + }, + }, + } + + err := store.AddFullMessage(ctx, "tc", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", 
len(history[0].ToolCalls)) + } + tc := history[0].ToolCalls[0] + if tc.ID != "call_abc" { + t.Errorf("tool call ID = %q", tc.ID) + } + if tc.Function == nil || tc.Function.Name != "web_search" { + t.Errorf("tool call function = %+v", tc.Function) + } +} + +func TestAddFullMessage_ToolCallID(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + msg := providers.Message{ + Role: "tool", + Content: "search results here", + ToolCallID: "call_abc", + } + + err := store.AddFullMessage(ctx, "tr", msg) + if err != nil { + t.Fatalf("AddFullMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "tr") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].ToolCallID != "call_abc" { + t.Errorf("ToolCallID = %q", history[0].ToolCallID) + } +} + +func TestGetHistory_EmptySession(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + history, err := store.GetHistory(ctx, "nonexistent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if history == nil { + t.Fatal("expected non-nil empty slice") + } + if len(history) != 0 { + t.Errorf("expected 0 messages, got %d", len(history)) + } +} + +func TestGetHistory_Ordering(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage( + ctx, "order", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage(%d): %v", i, err) + } + } + + history, err := store.GetHistory(ctx, "order") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Fatalf("expected 5, got %d", len(history)) + } + for i := 0; i < 5; i++ { + expected := string(rune('a' + i)) + if history[i].Content != expected { + t.Errorf("msg[%d].Content = %q, want %q", i, history[i].Content, expected) + } + } +} + +func TestSetSummary_GetSummary(t *testing.T) { + store := newTestStore(t) + ctx := 
context.Background() + + // No summary yet. + summary, err := store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "" { + t.Errorf("expected empty, got %q", summary) + } + + // Set a summary. + err = store.SetSummary(ctx, "s1", "talked about Go") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "talked about Go" { + t.Errorf("summary = %q", summary) + } + + // Update summary. + err = store.SetSummary(ctx, "s1", "updated summary") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + + summary, err = store.GetSummary(ctx, "s1") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "updated summary" { + t.Errorf("summary = %q", summary) + } +} + +func TestTruncateHistory_KeepLast(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 10; i++ { + err := store.AddMessage( + ctx, "trunc", + "user", + string(rune('a'+i)), + ) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "trunc", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "trunc") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Should be the last 4: g, h, i, j + if history[0].Content != "g" { + t.Errorf("first kept = %q, want 'g'", history[0].Content) + } + if history[3].Content != "j" { + t.Errorf("last kept = %q, want 'j'", history[3].Content) + } +} + +func TestTruncateHistory_KeepZero(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "empty", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "empty", 0) + if err != nil { + 
t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "empty") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 0 { + t.Errorf("expected 0, got %d", len(history)) + } +} + +func TestTruncateHistory_KeepMoreThanExists(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + err := store.AddMessage(ctx, "few", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Keep 100, but only 3 exist — should keep all. + err := store.TruncateHistory(ctx, "few", 100) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "few") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Errorf("expected 3, got %d", len(history)) + } +} + +func TestSetHistory_ReplacesAll(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add some initial messages. + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "replace", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Replace with new history. + newHistory := []providers.Message{ + {Role: "user", Content: "new1"}, + {Role: "assistant", Content: "new2"}, + } + err := store.SetHistory(ctx, "replace", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "replace") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2, got %d", len(history)) + } + if history[0].Content != "new1" || history[1].Content != "new2" { + t.Errorf("history = %+v", history) + } +} + +func TestSetHistory_ResetsSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Add messages and truncate. 
+ for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "skip-reset", "user", "old") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "skip-reset", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // SetHistory should reset skip to 0. + newHistory := []providers.Message{ + {Role: "user", Content: "fresh"}, + } + err = store.SetHistory(ctx, "skip-reset", newHistory) + if err != nil { + t.Fatalf("SetHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "skip-reset") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + if history[0].Content != "fresh" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestColonInKey(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "telegram:123", "user", "hi") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1, got %d", len(history)) + } + + // Verify the file is named with underscore. + jsonlFile := filepath.Join(store.dir, "telegram_123.jsonl") + if _, statErr := os.Stat(jsonlFile); statErr != nil { + t.Errorf("expected file %s to exist: %v", jsonlFile, statErr) + } +} + +func TestCompact_RemovesSkippedMessages(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages, then truncate to keep last 3. + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "compact", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + err := store.TruncateHistory(ctx, "compact", 3) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + // Before compact: file still has 10 lines. 
+ allOnDisk, err := readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 10 { + t.Fatalf("before compact: expected 10 on disk, got %d", len(allOnDisk)) + } + + // Compact. + err = store.Compact(ctx, "compact") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // After compact: file should have only 3 lines. + allOnDisk, err = readMessages(store.jsonlPath("compact"), 0) + if err != nil { + t.Fatalf("readMessages: %v", err) + } + if len(allOnDisk) != 3 { + t.Fatalf("after compact: expected 3 on disk, got %d", len(allOnDisk)) + } + + // GetHistory should still return the same 3 messages. + history, err := store.GetHistory(ctx, "compact") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + if history[0].Content != "h" || history[2].Content != "j" { + t.Errorf("wrong content: %+v", history) + } +} + +func TestCompact_NoOpWhenNoSkip(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 5; i++ { + err := store.AddMessage(ctx, "noop", "user", "msg") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Compact without prior truncation — should be a no-op. 
+ err := store.Compact(ctx, "noop") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + history, err := store.GetHistory(ctx, "noop") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 5 { + t.Errorf("expected 5, got %d", len(history)) + } +} + +func TestCompact_ThenAppend(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 8; i++ { + err := store.AddMessage(ctx, "cap", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + err := store.TruncateHistory(ctx, "cap", 2) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + err = store.Compact(ctx, "cap") + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Append after compaction should work correctly. + err = store.AddMessage(ctx, "cap", "user", "new") + if err != nil { + t.Fatalf("AddMessage after compact: %v", err) + } + + history, err := store.GetHistory(ctx, "cap") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("expected 3, got %d", len(history)) + } + // g, h (kept from truncation), new (appended after compaction). + if history[0].Content != "g" { + t.Errorf("first = %q, want 'g'", history[0].Content) + } + if history[2].Content != "new" { + t.Errorf("last = %q, want 'new'", history[2].Content) + } +} + +func TestTruncateHistory_StaleMetaCount(t *testing.T) { + // Simulates a crash between JSONL append and meta update in addMsg: + // file has N+1 lines but meta.Count is still N. TruncateHistory must + // reconcile with the real line count so that keepLast is accurate. + store := newTestStore(t) + ctx := context.Background() + + // Write 10 messages normally (meta.Count = 10). + for i := 0; i < 10; i++ { + err := store.AddMessage(ctx, "stale", "user", string(rune('a'+i))) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + // Simulate crash: append a line to JSONL but do NOT update meta. 
+ // This leaves meta.Count = 10 while the file has 11 lines. + jsonlPath := store.jsonlPath("stale") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"orphan"}` + "\n") + if err != nil { + t.Fatalf("write orphan: %v", err) + } + f.Close() + + // TruncateHistory(keepLast=4) should keep the last 4 of 11 lines, + // not the last 4 of 10. + err = store.TruncateHistory(ctx, "stale", 4) + if err != nil { + t.Fatalf("TruncateHistory: %v", err) + } + + history, err := store.GetHistory(ctx, "stale") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 4 { + t.Fatalf("expected 4, got %d", len(history)) + } + // Last 4 of [a,b,c,d,e,f,g,h,i,j,orphan] = [h,i,j,orphan] + if history[0].Content != "h" { + t.Errorf("first kept = %q, want 'h'", history[0].Content) + } + if history[3].Content != "orphan" { + t.Errorf("last kept = %q, want 'orphan'", history[3].Content) + } +} + +func TestCrashRecovery_PartialLine(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + // Write a valid message first. + err := store.AddMessage(ctx, "crash", "user", "valid") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + // Simulate a crash by appending a partial JSON line directly. + jsonlPath := store.jsonlPath("crash") + f, err := os.OpenFile(jsonlPath, os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + t.Fatalf("open for append: %v", err) + } + _, err = f.WriteString(`{"role":"user","content":"incomple`) + if err != nil { + t.Fatalf("write partial: %v", err) + } + f.Close() + + // GetHistory should return only the valid message. 
+ history, err := store.GetHistory(ctx, "crash") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 valid message, got %d", len(history)) + } + if history[0].Content != "valid" { + t.Errorf("content = %q", history[0].Content) + } +} + +func TestPersistence_AcrossInstances(t *testing.T) { + dir := t.TempDir() + ctx := context.Background() + + // Write with first instance. + store1, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + err = store1.AddMessage(ctx, "persist", "user", "remember me") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store1.SetSummary(ctx, "persist", "a test session") + if err != nil { + t.Fatalf("SetSummary: %v", err) + } + store1.Close() + + // Read with second instance. + store2, err := NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + defer store2.Close() + + history, err := store2.GetHistory(ctx, "persist") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 || history[0].Content != "remember me" { + t.Errorf("history = %+v", history) + } + + summary, err := store2.GetSummary(ctx, "persist") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "a test session" { + t.Errorf("summary = %q", summary) + } +} + +func TestConcurrent_AddAndRead(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + var wg sync.WaitGroup + const goroutines = 10 + const msgsPerGoroutine = 20 + + // Concurrent writes. 
+ for g := 0; g < goroutines; g++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < msgsPerGoroutine; i++ { + _ = store.AddMessage(ctx, "concurrent", "user", "msg") + } + }() + } + wg.Wait() + + history, err := store.GetHistory(ctx, "concurrent") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + expected := goroutines * msgsPerGoroutine + if len(history) != expected { + t.Errorf("expected %d messages, got %d", expected, len(history)) + } +} + +func TestConcurrent_SummarizeRace(t *testing.T) { + // Simulates the #704 race: one goroutine adds messages while + // another truncates + sets summary — like summarizeSession(). + store := newTestStore(t) + ctx := context.Background() + + // Seed with some messages. + for i := 0; i < 20; i++ { + err := store.AddMessage(ctx, "race", "user", "seed") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + } + + var wg sync.WaitGroup + + // Writer goroutine (main agent loop). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 50; i++ { + _ = store.AddMessage(ctx, "race", "user", "new") + } + }() + + // Summarizer goroutine (background task). + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + _ = store.SetSummary(ctx, "race", "summary") + _ = store.TruncateHistory(ctx, "race", 5) + } + }() + + wg.Wait() + + // Verify the store is still in a consistent state. 
+ _, err := store.GetHistory(ctx, "race") + if err != nil { + t.Fatalf("GetHistory after race: %v", err) + } + _, err = store.GetSummary(ctx, "race") + if err != nil { + t.Fatalf("GetSummary after race: %v", err) + } +} + +func TestMultipleSessions_Isolation(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + err := store.AddMessage(ctx, "s1", "user", "msg for s1") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + err = store.AddMessage(ctx, "s2", "user", "msg for s2") + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + h1, err := store.GetHistory(ctx, "s1") + if err != nil { + t.Fatalf("GetHistory s1: %v", err) + } + h2, err := store.GetHistory(ctx, "s2") + if err != nil { + t.Fatalf("GetHistory s2: %v", err) + } + + if len(h1) != 1 || h1[0].Content != "msg for s1" { + t.Errorf("s1 history = %+v", h1) + } + if len(h2) != 1 || h2[0].Content != "msg for s2" { + t.Errorf("s2 history = %+v", h2) + } +} + +func BenchmarkAddMessage(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = store.AddMessage(ctx, "bench", "user", "benchmark message content") + } +} + +func BenchmarkGetHistory_100(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 100; i++ { + _ = store.AddMessage(ctx, "bench", "user", "message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} + +func BenchmarkGetHistory_1000(b *testing.B) { + dir := b.TempDir() + store, err := NewJSONLStore(dir) + if err != nil { + b.Fatalf("NewJSONLStore: %v", err) + } + defer store.Close() + ctx := context.Background() + + for i := 0; i < 1000; i++ { + _ = store.AddMessage(ctx, "bench", "user", 
"message content") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = store.GetHistory(ctx, "bench") + } +} diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go new file mode 100644 index 000000000..c9d5176ab --- /dev/null +++ b/pkg/memory/migration.go @@ -0,0 +1,108 @@ +package memory + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// jsonSession mirrors pkg/session.Session for migration purposes. +type jsonSession struct { + Key string `json:"key"` + Messages []providers.Message `json:"messages"` + Summary string `json:"summary,omitempty"` + Created time.Time `json:"created"` + Updated time.Time `json:"updated"` +} + +// MigrateFromJSON reads legacy sessions/*.json files from sessionsDir, +// writes them into the Store, and renames each migrated file to +// .json.migrated as a backup. Returns the number of sessions migrated. +// +// Files that fail to parse are logged and skipped. Already-migrated +// files (.json.migrated) are ignored, making the function idempotent. +func MigrateFromJSON( + ctx context.Context, sessionsDir string, store Store, +) (int, error) { + entries, err := os.ReadDir(sessionsDir) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("memory: read sessions dir: %w", err) + } + + migrated := 0 + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".json") { + continue + } + // Skip already-migrated files. 
+ if strings.HasSuffix(name, ".migrated") { + continue + } + + srcPath := filepath.Join(sessionsDir, name) + + data, readErr := os.ReadFile(srcPath) + if readErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, readErr) + continue + } + + var sess jsonSession + if parseErr := json.Unmarshal(data, &sess); parseErr != nil { + log.Printf("memory: migrate: skip %s: %v", name, parseErr) + continue + } + + // Use the key from the JSON content, not the filename. + // Filenames are sanitized (":" → "_") but keys are not. + key := sess.Key + if key == "" { + key = strings.TrimSuffix(name, ".json") + } + + // Use SetHistory (atomic replace) instead of per-message + // AddFullMessage. This makes migration idempotent: if the + // process crashes after writing messages but before the + // rename below, a retry replaces the partial data cleanly + // instead of duplicating messages. + if setErr := store.SetHistory(ctx, key, sess.Messages); setErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set history: %w", + name, setErr, + ) + } + + if sess.Summary != "" { + if sumErr := store.SetSummary(ctx, key, sess.Summary); sumErr != nil { + return migrated, fmt.Errorf( + "memory: migrate %s: set summary: %w", + name, sumErr, + ) + } + } + + // Rename to .migrated as backup (not delete). 
+ renameErr := os.Rename(srcPath, srcPath+".migrated") + if renameErr != nil { + log.Printf("memory: migrate: rename %s: %v", name, renameErr) + } + + migrated++ + } + + return migrated, nil +} diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go new file mode 100644 index 000000000..3170758b7 --- /dev/null +++ b/pkg/memory/migration_test.go @@ -0,0 +1,384 @@ +package memory + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func writeJSONSession( + t *testing.T, dir string, filename string, sess jsonSession, +) { + t.Helper() + data, err := json.MarshalIndent(sess, "", " ") + if err != nil { + t.Fatalf("marshal session: %v", err) + } + err = os.WriteFile(filepath.Join(dir, filename), data, 0o644) + if err != nil { + t.Fatalf("write session file: %v", err) + } +} + +func TestMigrateFromJSON_Basic(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "test.json", jsonSession{ + Key: "test", + Messages: []providers.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + }, + Summary: "A greeting.", + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 migrated, got %d", count) + } + + history, err := store.GetHistory(ctx, "test") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if history[0].Content != "hello" || history[1].Content != "hi" { + t.Errorf("unexpected messages: %+v", history) + } + + summary, err := store.GetSummary(ctx, "test") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "A greeting." 
{ + t.Errorf("summary = %q", summary) + } +} + +func TestMigrateFromJSON_WithToolCalls(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "tools.json", jsonSession{ + Key: "tools", + Messages: []providers.Message{ + { + Role: "assistant", + Content: "Searching...", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "web_search", + Arguments: `{"q":"test"}`, + }, + }, + }, + }, + { + Role: "tool", + Content: "result", + ToolCallID: "call_1", + }, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "tools") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 2 { + t.Fatalf("expected 2 messages, got %d", len(history)) + } + if len(history[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(history[0].ToolCalls)) + } + if history[0].ToolCalls[0].Function.Name != "web_search" { + t.Errorf("function = %q", history[0].ToolCalls[0].Function.Name) + } + if history[1].ToolCallID != "call_1" { + t.Errorf("ToolCallID = %q", history[1].ToolCallID) + } +} + +func TestMigrateFromJSON_MultipleFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + for i := 0; i < 3; i++ { + key := string(rune('a' + i)) + writeJSONSession(t, sessionsDir, key+".json", jsonSession{ + Key: key, + Messages: []providers.Message{{Role: "user", Content: "msg " + key}}, + Created: time.Now(), + Updated: time.Now(), + }) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 3 { + t.Errorf("expected 3, got %d", count) + } + + for i := 0; i < 3; i++ { + 
key := string(rune('a' + i)) + history, histErr := store.GetHistory(ctx, key) + if histErr != nil { + t.Fatalf("GetHistory(%q): %v", key, histErr) + } + if len(history) != 1 { + t.Errorf("session %q: expected 1 msg, got %d", key, len(history)) + } + } +} + +func TestMigrateFromJSON_InvalidJSON(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // One valid, one invalid. + writeJSONSession(t, sessionsDir, "good.json", jsonSession{ + Key: "good", + Messages: []providers.Message{{Role: "user", Content: "ok"}}, + Created: time.Now(), + Updated: time.Now(), + }) + err := os.WriteFile( + filepath.Join(sessionsDir, "bad.json"), + []byte("{invalid json"), + 0o644, + ) + if err != nil { + t.Fatalf("write bad file: %v", err) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1 (bad file skipped), got %d", count) + } + + history, err := store.GetHistory(ctx, "good") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_RenamesFiles(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "rename.json", jsonSession{ + Key: "rename", + Messages: []providers.Message{{Role: "user", Content: "hi"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + _, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + + // Original .json should not exist. + _, statErr := os.Stat(filepath.Join(sessionsDir, "rename.json")) + if !os.IsNotExist(statErr) { + t.Error("rename.json should have been renamed") + } + // .json.migrated should exist. 
+ _, statErr = os.Stat( + filepath.Join(sessionsDir, "rename.json.migrated"), + ) + if statErr != nil { + t.Errorf("rename.json.migrated should exist: %v", statErr) + } +} + +func TestMigrateFromJSON_Idempotent(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "idem.json", jsonSession{ + Key: "idem", + Messages: []providers.Message{{Role: "user", Content: "once"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count1, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count1 != 1 { + t.Errorf("first run: expected 1, got %d", count1) + } + + // Second run should find only .migrated files, skip them. + count2, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count2 != 0 { + t.Errorf("second run: expected 0, got %d", count2) + } + + history, err := store.GetHistory(ctx, "idem") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Errorf("expected 1 message, got %d", len(history)) + } +} + +func TestMigrateFromJSON_ColonInKey(t *testing.T) { + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + // File is named telegram_123 (sanitized), but the key inside is telegram:123. + writeJSONSession(t, sessionsDir, "telegram_123.json", jsonSession{ + Key: "telegram:123", + Messages: []providers.Message{{Role: "user", Content: "from telegram"}}, + Created: time.Now(), + Updated: time.Now(), + }) + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 1 { + t.Errorf("expected 1, got %d", count) + } + + // Accessible via the original key "telegram:123". 
+ history, err := store.GetHistory(ctx, "telegram:123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 { + t.Fatalf("expected 1 message, got %d", len(history)) + } + if history[0].Content != "from telegram" { + t.Errorf("content = %q", history[0].Content) + } + + // In the file-based store, "telegram:123" and "telegram_123" both + // sanitize to the same filename, so they share storage. This is + // expected — the colon-to-underscore mapping is a one-way function. + history2, err := store.GetHistory(ctx, "telegram_123") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history2) != 1 { + t.Errorf("expected 1 (same file), got %d", len(history2)) + } +} + +func TestMigrateFromJSON_RetryAfterCrash(t *testing.T) { + // Simulates a crash during migration: first run writes messages + // but doesn't rename the .json file. Second run must replace + // (not duplicate) the messages thanks to SetHistory semantics. + sessionsDir := t.TempDir() + store := newTestStore(t) + ctx := context.Background() + + writeJSONSession(t, sessionsDir, "retry.json", jsonSession{ + Key: "retry", + Messages: []providers.Message{ + {Role: "user", Content: "one"}, + {Role: "assistant", Content: "two"}, + }, + Created: time.Now(), + Updated: time.Now(), + }) + + // First migration succeeds — writes messages and renames file. + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("first migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + // Simulate "crash before rename": restore the .json file. + src := filepath.Join(sessionsDir, "retry.json.migrated") + dst := filepath.Join(sessionsDir, "retry.json") + if renameErr := os.Rename(src, dst); renameErr != nil { + t.Fatalf("restore .json: %v", renameErr) + } + + // Second migration should re-import without duplicating messages. 
+ count, err = MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("second migration: %v", err) + } + if count != 1 { + t.Fatalf("expected 1, got %d", count) + } + + history, err := store.GetHistory(ctx, "retry") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + // Must be exactly 2 messages (not 4 from duplication). + if len(history) != 2 { + t.Fatalf("expected 2 messages (no duplicates), got %d", len(history)) + } + if history[0].Content != "one" || history[1].Content != "two" { + t.Errorf("unexpected messages: %+v", history) + } +} + +func TestMigrateFromJSON_NonexistentDir(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + count, err := MigrateFromJSON(ctx, "/nonexistent/path", store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 0 { + t.Errorf("expected 0, got %d", count) + } +} diff --git a/pkg/memory/store.go b/pkg/memory/store.go new file mode 100644 index 000000000..b6e11707d --- /dev/null +++ b/pkg/memory/store.go @@ -0,0 +1,42 @@ +package memory + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// Store defines an interface for persistent session storage. +// Each method is an atomic operation — there is no separate Save() call. +type Store interface { + // AddMessage appends a simple text message to a session. + AddMessage(ctx context.Context, sessionKey, role, content string) error + + // AddFullMessage appends a complete message (with tool calls, etc.) to a session. + AddFullMessage(ctx context.Context, sessionKey string, msg providers.Message) error + + // GetHistory returns all messages for a session in insertion order. + // Returns an empty slice (not nil) if the session does not exist. + GetHistory(ctx context.Context, sessionKey string) ([]providers.Message, error) + + // GetSummary returns the conversation summary for a session. + // Returns an empty string if no summary exists. 
+ GetSummary(ctx context.Context, sessionKey string) (string, error) + + // SetSummary updates the conversation summary for a session. + SetSummary(ctx context.Context, sessionKey, summary string) error + + // TruncateHistory removes all but the last keepLast messages from a session. + // If keepLast <= 0, all messages are removed. + TruncateHistory(ctx context.Context, sessionKey string, keepLast int) error + + // SetHistory replaces all messages in a session with the provided history. + SetHistory(ctx context.Context, sessionKey string, history []providers.Message) error + + // Compact reclaims storage by physically removing logically truncated + // data. Backends that do not accumulate dead data may return nil. + Compact(ctx context.Context, sessionKey string) error + + // Close releases any resources held by the store. + Close() error +} diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go index 1bb15f771..1b250b9b4 100644 --- a/pkg/providers/anthropic/provider.go +++ b/pkg/providers/anthropic/provider.go @@ -31,6 +31,9 @@ type Provider struct { baseURL string } +// SupportsThinking implements providers.ThinkingCapable. +func (p *Provider) SupportsThinking() bool { return true } + func NewProvider(token string) *Provider { return NewProviderWithBaseURL(token, "") } @@ -182,9 +185,80 @@ func buildParams( params.Tools = translateTools(tools) } + // Extended Thinking / Adaptive Thinking + // The thinking_level value directly determines the API parameter format: + // "adaptive" → {thinking: {type: "adaptive"}} + output_config.effort + // "low/medium/high/xhigh" → {thinking: {type: "enabled", budget_tokens: N}} + if level, ok := options["thinking_level"].(string); ok && level != "" && level != "off" { + applyThinkingConfig(&params, level) + } + return params, nil } +// applyThinkingConfig sets thinking parameters based on the level value. +// "adaptive" uses the adaptive thinking API (Claude 4.6+). 
+// All other levels use budget_tokens which is universally supported. +// +// Anthropic API constraint: temperature must not be set when thinking is enabled. +// budget_tokens must be strictly less than max_tokens. +func applyThinkingConfig(params *anthropic.MessageNewParams, level string) { + // Anthropic API rejects requests with temperature set alongside thinking. + // Reset to zero value (omitted from JSON serialization). + if params.Temperature.Valid() { + log.Printf("anthropic: temperature cleared because thinking is enabled (level=%s)", level) + } + params.Temperature = anthropic.MessageNewParams{}.Temperature + + if level == "adaptive" { + adaptive := anthropic.NewThinkingConfigAdaptiveParam() + params.Thinking = anthropic.ThinkingConfigParamUnion{OfAdaptive: &adaptive} + params.OutputConfig = anthropic.OutputConfigParam{ + Effort: anthropic.OutputConfigEffortHigh, + } + return + } + + budget := int64(levelToBudget(level)) + if budget <= 0 { + return + } + + // budget_tokens must be < max_tokens; clamp to respect user's max_tokens setting. + if budget >= params.MaxTokens { + log.Printf("anthropic: budget_tokens (%d) clamped to %d (max_tokens-1)", budget, params.MaxTokens-1) + budget = params.MaxTokens - 1 + } else if budget > params.MaxTokens*80/100 { + log.Printf("anthropic: thinking budget (%d) exceeds 80%% of max_tokens (%d), output may be truncated", + budget, params.MaxTokens) + } + params.Thinking = anthropic.ThinkingConfigParamOfEnabled(budget) +} + +// levelToBudget maps a thinking level to budget_tokens. 
+// Values are based on Anthropic's recommendations and community best practices: +// +// low = 4,096 — simple reasoning, quick debugging (Claude Code "think") +// medium = 16,384 — Anthropic recommended sweet spot for most tasks +// high = 32,000 — complex architecture, deep analysis (diminishing returns above this) +// xhigh = 64,000 — extreme reasoning, research problems, benchmarks +// +// Note: For Claude 4.6+, prefer adaptive thinking over manual budget_tokens. +func levelToBudget(level string) int { + switch level { + case "low": + return 4096 + case "medium": + return 16384 + case "high": + return 32000 + case "xhigh": + return 64000 + default: + return 0 + } +} + func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { result := make([]anthropic.ToolUnionParam, 0, len(tools)) for _, t := range tools { @@ -213,10 +287,14 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { func parseResponse(resp *anthropic.Message) *LLMResponse { var content strings.Builder + var reasoning strings.Builder var toolCalls []ToolCall for _, block := range resp.Content { switch block.Type { + case "thinking": + tb := block.AsThinking() + reasoning.WriteString(tb.Thinking) case "text": tb := block.AsText() content.WriteString(tb.Text) @@ -247,6 +325,7 @@ func parseResponse(resp *anthropic.Message) *LLMResponse { return &LLMResponse{ Content: content.String(), + Reasoning: reasoning.String(), ToolCalls: toolCalls, FinishReason: finishReason, Usage: &UsageInfo{ diff --git a/pkg/providers/anthropic/thinking_test.go b/pkg/providers/anthropic/thinking_test.go new file mode 100644 index 000000000..e69a3869e --- /dev/null +++ b/pkg/providers/anthropic/thinking_test.go @@ -0,0 +1,212 @@ +package anthropicprovider + +import ( + "encoding/json" + "testing" + + "github.com/anthropics/anthropic-sdk-go" +) + +func TestApplyThinkingConfig_Adaptive(t *testing.T) { + params := anthropic.MessageNewParams{ + MaxTokens: 16000, + Temperature: 
anthropic.Float(0.7), + } + applyThinkingConfig(&params, "adaptive") + + if params.Thinking.OfAdaptive == nil { + t.Fatal("expected adaptive thinking") + } + if params.Thinking.OfEnabled != nil { + t.Error("should not set enabled thinking in adaptive mode") + } + if params.OutputConfig.Effort != anthropic.OutputConfigEffortHigh { + t.Errorf("effort = %q, want %q", params.OutputConfig.Effort, anthropic.OutputConfigEffortHigh) + } + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking is enabled") + } +} + +func TestApplyThinkingConfig_BudgetLevels(t *testing.T) { + tests := []struct { + level string + wantBudget int64 + }{ + {"low", 4096}, + {"medium", 16384}, + {"high", 32000}, + {"xhigh", 64000}, + } + + for _, tt := range tests { + t.Run(tt.level, func(t *testing.T) { + params := anthropic.MessageNewParams{ + MaxTokens: 200000, + Temperature: anthropic.Float(0.5), + } + applyThinkingConfig(&params, tt.level) + + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if params.Thinking.OfAdaptive != nil { + t.Error("should not set adaptive thinking") + } + if params.Thinking.OfEnabled.BudgetTokens != tt.wantBudget { + t.Errorf("budget_tokens = %d, want %d", params.Thinking.OfEnabled.BudgetTokens, tt.wantBudget) + } + if params.OutputConfig.Effort != "" { + t.Errorf("effort = %q, want empty", params.OutputConfig.Effort) + } + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking is enabled") + } + }) + } +} + +func TestApplyThinkingConfig_BudgetClamp(t *testing.T) { + // budget_tokens must be < max_tokens; clamp budget down to respect user's max_tokens. 
+ params := anthropic.MessageNewParams{MaxTokens: 4096} + applyThinkingConfig(&params, "high") // budget=32000 > maxTokens=4096 + + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if params.Thinking.OfEnabled.BudgetTokens != 4095 { + t.Errorf("budget_tokens = %d, want 4095 (maxTokens-1)", params.Thinking.OfEnabled.BudgetTokens) + } + if params.MaxTokens != 4096 { + t.Errorf("max_tokens should not be modified, got %d", params.MaxTokens) + } +} + +func TestApplyThinkingConfig_UnknownLevel(t *testing.T) { + params := anthropic.MessageNewParams{MaxTokens: 16000} + applyThinkingConfig(&params, "unknown") + + if params.Thinking.OfEnabled != nil { + t.Error("should not set enabled thinking for unknown level") + } + if params.Thinking.OfAdaptive != nil { + t.Error("should not set adaptive thinking for unknown level") + } +} + +func TestLevelToBudget(t *testing.T) { + tests := []struct { + name string + level string + want int + }{ + {"low", "low", 4096}, + {"medium", "medium", 16384}, + {"high", "high", 32000}, + {"xhigh", "xhigh", 64000}, + {"off", "off", 0}, + {"empty", "", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := levelToBudget(tt.level); got != tt.want { + t.Errorf("levelToBudget(%q) = %d, want %d", tt.level, got, tt.want) + } + }) + } +} + +func TestBuildParams_ThinkingClearsTemperature(t *testing.T) { + msgs := []Message{{Role: "user", Content: "hello"}} + opts := map[string]any{ + "max_tokens": 200000, + "temperature": 0.8, + "thinking_level": "medium", + } + + params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts) + if err != nil { + t.Fatal(err) + } + + if params.Temperature.Valid() { + t.Error("temperature should be cleared when thinking_level is set") + } + if params.Thinking.OfEnabled == nil { + t.Fatal("expected enabled thinking") + } + if params.Thinking.OfEnabled.BudgetTokens != 16384 { + t.Errorf("budget_tokens = %d, want 16384", params.Thinking.OfEnabled.BudgetTokens) + } 
+} + +// unmarshalBlocks constructs []ContentBlockUnion via JSON round-trip so that +// the internal JSON.raw field is populated (required by AsText/AsThinking). +func unmarshalBlocks(t *testing.T, jsonStr string) []anthropic.ContentBlockUnion { + t.Helper() + var blocks []anthropic.ContentBlockUnion + if err := json.Unmarshal([]byte(jsonStr), &blocks); err != nil { + t.Fatalf("unmarshalBlocks: %v", err) + } + return blocks +} + +func TestParseResponse_ThinkingBlock(t *testing.T) { + resp := &anthropic.Message{ + Content: unmarshalBlocks(t, `[ + {"type":"thinking","thinking":"Let me reason step by step...","signature":"sig"}, + {"type":"text","text":"The answer is 42."} + ]`), + StopReason: anthropic.StopReasonEndTurn, + } + + result := parseResponse(resp) + + if result.Reasoning != "Let me reason step by step..." { + t.Errorf("Reasoning = %q, want thinking content", result.Reasoning) + } + if result.Content != "The answer is 42." { + t.Errorf("Content = %q, want text content", result.Content) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", result.FinishReason) + } +} + +func TestParseResponse_NoThinkingBlock(t *testing.T) { + resp := &anthropic.Message{ + Content: unmarshalBlocks(t, `[ + {"type":"text","text":"Just a normal response."} + ]`), + StopReason: anthropic.StopReasonEndTurn, + } + + result := parseResponse(resp) + + if result.Reasoning != "" { + t.Errorf("Reasoning = %q, want empty", result.Reasoning) + } + if result.Content != "Just a normal response." 
{ + t.Errorf("Content = %q, want text content", result.Content) + } +} + +func TestBuildParams_NoThinkingKeepsTemperature(t *testing.T) { + msgs := []Message{{Role: "user", Content: "hello"}} + opts := map[string]any{ + "temperature": 0.8, + } + + params, err := buildParams(msgs, nil, "claude-sonnet-4-6", opts) + if err != nil { + t.Fatal(err) + } + + if !params.Temperature.Valid() { + t.Error("temperature should be preserved when thinking is not set") + } + if params.Temperature.Value != 0.8 { + t.Errorf("temperature = %f, want 0.8", params.Temperature.Value) + } +} diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go index d4ee528b7..8a1890212 100644 --- a/pkg/providers/antigravity_provider.go +++ b/pkg/providers/antigravity_provider.go @@ -640,7 +640,10 @@ func FetchAntigravityProjectID(accessToken string) (string, error) { } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading loadCodeAssist response: %w", err) + } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("loadCodeAssist failed: %s", string(body)) } @@ -681,7 +684,10 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn } defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading fetchAvailableModels response: %w", err) + } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf( "fetchAvailableModels failed (HTTP %d): %s", diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 5b3e42b9e..a0d09a835 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -181,6 +181,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.model = "deepseek-chat" } } + case "avian": + if cfg.Providers.Avian.APIKey != "" { + sel.apiKey = cfg.Providers.Avian.APIKey + sel.apiBase = 
cfg.Providers.Avian.APIBase + sel.proxy = cfg.Providers.Avian.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.avian.io/v1" + } + } case "mistral": if cfg.Providers.Mistral.APIKey != "" { sel.apiKey = cfg.Providers.Mistral.APIKey @@ -300,6 +309,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { if sel.apiBase == "" { sel.apiBase = "https://api.mistral.ai/v1" } + case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "": + sel.apiKey = cfg.Providers.Avian.APIKey + sel.apiBase = cfg.Providers.Avian.APIBase + sel.proxy = cfg.Providers.Avian.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.avian.io/v1" + } case cfg.Providers.VLLM.APIBase != "": sel.apiKey = cfg.Providers.VLLM.APIKey sel.apiBase = cfg.Providers.VLLM.APIBase diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 155317a3b..c05fb0ad4 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -94,7 +94,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", - "volcengine", "vllm", "qwen", "mistral": + "volcengine", "vllm", "qwen", "mistral", "avian": // All other OpenAI-compatible HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) @@ -208,6 +208,8 @@ func getDefaultAPIBase(protocol string) string { return "http://localhost:8000/v1" case "mistral": return "https://api.mistral.ai/v1" + case "avian": + return "https://api.avian.io/v1" default: return "" } diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 3a18b8b16..1904ee153 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -116,7 +116,7 @@ func (p *Provider) 
Chat( requestBody := map[string]any{ "model": model, - "messages": stripSystemParts(messages), + "messages": serializeMessages(messages), } if len(tools) > 0 { @@ -296,19 +296,57 @@ type openaiMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } -// stripSystemParts converts []Message to []openaiMessage, dropping the -// SystemParts field so it doesn't leak into the JSON payload sent to -// OpenAI-compatible APIs (some strict endpoints reject unknown fields). -func stripSystemParts(messages []Message) []openaiMessage { - out := make([]openaiMessage, len(messages)) - for i, m := range messages { - out[i] = openaiMessage{ - Role: m.Role, - Content: m.Content, - ReasoningContent: m.ReasoningContent, - ToolCalls: m.ToolCalls, - ToolCallID: m.ToolCallID, +// serializeMessages converts internal Message structs to the OpenAI wire format. +// - Strips SystemParts (unknown to third-party endpoints) +// - Converts messages with Media to multipart content format (text + image_url parts) +// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages +func serializeMessages(messages []Message) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + if len(m.Media) == 0 { + out = append(out, openaiMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + }) + continue } + + // Multipart content format for messages with media + parts := make([]map[string]any, 0, 1+len(m.Media)) + if m.Content != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": m.Content, + }) + } + for _, mediaURL := range m.Media { + if strings.HasPrefix(mediaURL, "data:image/") { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } + } + + msg := map[string]any{ + "role": m.Role, + "content": parts, + } + if m.ToolCallID != "" { + msg["tool_call_id"] = m.ToolCallID + } + if 
len(m.ToolCalls) > 0 { + msg["tool_calls"] = m.ToolCalls + } + if m.ReasoningContent != "" { + msg["reasoning_content"] = m.ReasoningContent + } + out = append(out, msg) } return out } diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 53b9e75ee..174bcf00d 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -5,8 +5,11 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { @@ -416,3 +419,97 @@ func TestProvider_FunctionalOptionRequestTimeoutNonPositive(t *testing.T) { t.Fatalf("http timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout) } } + +func TestSerializeMessages_PlainText(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, + } + result := serializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatal(err) + } + + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["content"] != "hello" { + t.Fatalf("expected plain string content, got %v", msgs[0]["content"]) + } + if msgs[1]["reasoning_content"] != "thinking..." 
{ + t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + } +} + +func TestSerializeMessages_WithMedia(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } + + textPart := content[0].(map[string]any) + if textPart["type"] != "text" || textPart["text"] != "describe this" { + t.Fatalf("text part mismatch: %v", textPart) + } + + imgPart := content[1].(map[string]any) + if imgPart["type"] != "image_url" { + t.Fatalf("expected image_url type, got %v", imgPart["type"]) + } + imgURL := imgPart["image_url"].(map[string]any) + if imgURL["url"] != "data:image/png;base64,abc123" { + t.Fatalf("image url mismatch: %v", imgURL["url"]) + } +} + +func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["tool_call_id"] != "call_1" { + t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"]) + } + // Content should be multipart array + if _, ok := msgs[0]["content"].([]any); !ok { + t.Fatalf("expected array content, got %T", msgs[0]["content"]) + } +} + +func TestSerializeMessages_StripsSystemParts(t *testing.T) { + messages := []protocoltypes.Message{ + { + Role: "system", + Content: "you are helpful", + SystemParts: []protocoltypes.ContentBlock{ + {Type: "text", 
Text: "you are helpful"}, + }, + }, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + raw := string(data) + if strings.Contains(raw, "system_parts") { + t.Fatal("system_parts should not appear in serialized output") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 99f13334e..194c1aa6f 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -65,6 +65,7 @@ type ContentBlock struct { type Message struct { Role string `json:"role"` Content string `json:"content"` + Media []string `json:"media,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"` SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters ToolCalls []ToolCall `json:"tool_calls,omitempty"` diff --git a/pkg/providers/types.go b/pkg/providers/types.go index f0c168bc6..68bbd1e65 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -37,6 +37,13 @@ type StatefulProvider interface { Close() } +// ThinkingCapable is an optional interface for providers that support +// extended thinking (e.g. Anthropic). Used by the agent loop to warn +// when thinking_level is configured but the active provider cannot use it. +type ThinkingCapable interface { + SupportsThinking() bool +} + // FailoverReason classifies why an LLM request failed for fallback decisions. 
type FailoverReason string diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go index f78197bbe..bd4bed8fb 100644 --- a/pkg/skills/clawhub_registry.go +++ b/pkg/skills/clawhub_registry.go @@ -259,15 +259,7 @@ func (c *ClawHubRegistry) DownloadAndInstall( } u.RawQuery = q.Encode() - req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize)) + tmpPath, err := c.downloadToTempFileWithRetry(ctx, u.String()) if err != nil { return nil, fmt.Errorf("download failed: %w", err) } @@ -284,17 +276,12 @@ func (c *ClawHubRegistry) DownloadAndInstall( // --- HTTP helper --- func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) + req, err := c.newGetRequest(ctx, urlStr, "application/json") if err != nil { return nil, err } - req.Header.Set("Accept", "application/json") - if c.authToken != "" { - req.Header.Set("Authorization", "Bearer "+c.authToken) - } - - resp, err := c.client.Do(req) + resp, err := utils.DoRequestWithRetry(c.client, req) if err != nil { return nil, err } @@ -312,3 +299,64 @@ func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, err return body, nil } + +func (c *ClawHubRegistry) newGetRequest(ctx context.Context, urlStr, accept string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", accept) + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) + } + return req, nil +} + +func (c *ClawHubRegistry) downloadToTempFileWithRetry(ctx context.Context, urlStr string) (string, error) { + req, err := c.newGetRequest(ctx, 
urlStr, "application/zip") + if err != nil { + return "", err + } + + resp, err := utils.DoRequestWithRetry(c.client, req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + errBody := make([]byte, 512) + n, _ := io.ReadFull(resp.Body, errBody) + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n])) + } + + tmpFile, err := os.CreateTemp("", "picoclaw-dl-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + cleanup := func() { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + } + + src := io.LimitReader(resp.Body, int64(c.maxZipSize)+1) + written, err := io.Copy(tmpFile, src) + if err != nil { + cleanup() + return "", fmt.Errorf("download write failed: %w", err) + } + + if written > int64(c.maxZipSize) { + cleanup() + return "", fmt.Errorf("download too large: %d bytes (max %d)", written, c.maxZipSize) + } + + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpPath) + return "", fmt.Errorf("failed to close temp file: %w", err) + } + + return tmpPath, nil +} diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go index 65ee638da..055da22dc 100644 --- a/pkg/skills/clawhub_registry_test.go +++ b/pkg/skills/clawhub_registry_test.go @@ -54,6 +54,39 @@ func TestClawHubRegistrySearch(t *testing.T) { assert.Equal(t, "clawhub", results[0].RegistryName) } +func TestClawHubRegistrySearchRetries429(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + + slug := "github" + name := "GitHub Integration" + summary := "Interact with GitHub repos" + version := "1.0.0" + + json.NewEncoder(w).Encode(clawhubSearchResponse{ + Results: []clawhubSearchResult{ + 
{Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version}, + }, + }) + })) + defer srv.Close() + + reg := newTestRegistry(srv.URL, "") + results, err := reg.Search(context.Background(), "github", 5) + + require.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, 2, attempts) + assert.Equal(t, "github", results[0].Slug) +} + func TestClawHubRegistryGetSkillMeta(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/api/v1/skills/github", r.URL.Path) @@ -137,6 +170,54 @@ func TestClawHubRegistryDownloadAndInstall(t *testing.T) { assert.Contains(t, string(readmeContent), "# Test Skill") } +func TestClawHubRegistryDownloadAndInstallRetries429(t *testing.T) { + zipBuf := createTestZip(t, map[string]string{ + "SKILL.md": "---\nname: retry-skill\ndescription: A test\n---\nHello skill", + }) + + downloadAttempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/skills/retry-skill": + json.NewEncoder(w).Encode(clawhubSkillResponse{ + Slug: "retry-skill", + DisplayName: "Retry Skill", + Summary: "A retry test skill", + LatestVersion: &clawhubVersionInfo{Version: "1.0.0"}, + }) + case "/api/v1/download": + downloadAttempts++ + if downloadAttempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + assert.Equal(t, "retry-skill", r.URL.Query().Get("slug")) + w.Header().Set("Content-Type", "application/zip") + w.Write(zipBuf) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + tmpDir := t.TempDir() + targetDir := filepath.Join(tmpDir, "retry-skill") + + reg := newTestRegistry(srv.URL, "") + result, err := reg.DownloadAndInstall(context.Background(), "retry-skill", "", targetDir) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "1.0.0", result.Version) + 
assert.Equal(t, 2, downloadAttempts) + + skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md")) + require.NoError(t, err) + assert.Contains(t, string(skillContent), "Hello skill") +} + func TestClawHubRegistryAuthToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index fcbcf934b..30d84635a 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -64,6 +64,29 @@ type SkillsLoader struct { builtinSkills string // builtin skills } +// SkillRoots returns all unique skill root directories used by this loader. +// The order follows resolution priority: workspace > global > builtin. +func (sl *SkillsLoader) SkillRoots() []string { + roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills} + seen := make(map[string]struct{}, len(roots)) + out := make([]string, 0, len(roots)) + + for _, root := range roots { + trimmed := strings.TrimSpace(root) + if trimmed == "" { + continue + } + clean := filepath.Clean(trimmed) + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} + out = append(out, clean) + } + + return out +} + func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader { return &SkillsLoader{ workspace: workspace, diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go index 9428bea62..31619f9c2 100644 --- a/pkg/skills/loader_test.go +++ b/pkg/skills/loader_test.go @@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) { }) } } + +func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) { + tmp := t.TempDir() + workspace := filepath.Join(tmp, "workspace") + global := filepath.Join(tmp, "global") + builtin := filepath.Join(tmp, "builtin") + + sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n") + roots := sl.SkillRoots() + + assert.Equal(t, []string{ + filepath.Join(workspace, 
"skills"), + global, + builtin, + }, roots) +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 770d8cb04..ec743e164 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -10,11 +10,38 @@ type Tool interface { Execute(ctx context.Context, args map[string]any) *ToolResult } -// ContextualTool is an optional interface that tools can implement -// to receive the current message context (channel, chatID) -type ContextualTool interface { - Tool - SetContext(channel, chatID string) +// --- Request-scoped tool context (channel / chatID) --- +// +// Carried via context.Value so that concurrent tool calls each receive +// their own immutable copy — no mutable state on singleton tool instances. +// +// Keys are unexported pointer-typed vars — guaranteed collision-free, +// and only accessible through the helper functions below. + +type toolCtxKey struct{ name string } + +var ( + ctxKeyChannel = &toolCtxKey{"channel"} + ctxKeyChatID = &toolCtxKey{"chatID"} +) + +// WithToolContext returns a child context carrying channel and chatID. +func WithToolContext(ctx context.Context, channel, chatID string) context.Context { + ctx = context.WithValue(ctx, ctxKeyChannel, channel) + ctx = context.WithValue(ctx, ctxKeyChatID, chatID) + return ctx +} + +// ToolChannel extracts the channel from ctx, or "" if unset. +func ToolChannel(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyChannel).(string) + return v +} + +// ToolChatID extracts the chatID from ctx, or "" if unset. +func ToolChatID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyChatID).(string) + return v } // AsyncCallback is a function type that async tools use to notify completion. @@ -22,51 +49,36 @@ type ContextualTool interface { // // The ctx parameter allows the callback to be canceled if the agent is shutting down. // The result parameter contains the tool's execution result. 
-// -// Example usage in an async tool: -// -// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { -// // Start async work in background -// go func() { -// result := doAsyncWork() -// if t.callback != nil { -// t.callback(ctx, result) -// } -// }() -// return AsyncResult("Async task started") -// } type AsyncCallback func(ctx context.Context, result *ToolResult) -// AsyncTool is an optional interface that tools can implement to support +// AsyncExecutor is an optional interface that tools can implement to support // asynchronous execution with completion callbacks. // -// Async tools return immediately with an AsyncResult, then notify completion -// via the callback set by SetCallback. +// Unlike the old AsyncTool pattern (SetCallback + Execute), AsyncExecutor +// receives the callback as a parameter of ExecuteAsync. This eliminates the +// data race where concurrent calls could overwrite each other's callbacks +// on a shared tool instance. // // This is useful for: -// - Long-running operations that shouldn't block the agent loop -// - Subagent spawns that complete independently -// - Background tasks that need to report results later +// - Long-running operations that shouldn't block the agent loop +// - Subagent spawns that complete independently +// - Background tasks that need to report results later // // Example: // -// type SpawnTool struct { -// callback AsyncCallback -// } -// -// func (t *SpawnTool) SetCallback(cb AsyncCallback) { -// t.callback = cb -// } -// -// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { -// go t.runSubagent(ctx, args) +// func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { +// go func() { +// result := t.runSubagent(ctx, args) +// if cb != nil { cb(ctx, result) } +// }() // return AsyncResult("Subagent spawned, will report back") // } -type AsyncTool interface { +type AsyncExecutor 
interface { Tool - // SetCallback registers a callback function to be invoked when the async operation completes. - // The callback will be called from a goroutine and should handle thread-safety if needed. - SetCallback(cb AsyncCallback) + // ExecuteAsync runs the tool asynchronously. The callback cb will be + // invoked (possibly from another goroutine) when the async operation + // completes. cb is guaranteed to be non-nil by the caller (registry). + ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult } func ToolToSchema(tool Tool) map[string]any { diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 6888d1326..31ac9ab88 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" @@ -24,9 +23,6 @@ type CronTool struct { executor JobExecutor msgBus *bus.MessageBus execTool *ExecTool - channel string - chatID string - mu sync.RWMutex } // NewCronTool creates a new CronTool @@ -102,14 +98,6 @@ func (t *CronTool) Parameters() map[string]any { } } -// SetContext sets the current session context for job creation -func (t *CronTool) SetContext(channel, chatID string) { - t.mu.Lock() - defer t.mu.Unlock() - t.channel = channel - t.chatID = chatID -} - // Execute runs the tool with the given arguments func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult { action, ok := args["action"].(string) @@ -119,7 +107,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult switch action { case "add": - return t.addJob(args) + return t.addJob(ctx, args) case "list": return t.listJobs() case "remove": @@ -133,11 +121,9 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult } } -func (t *CronTool) addJob(args map[string]any) *ToolResult { - t.mu.RLock() - channel := t.channel - chatID := t.chatID - t.mu.RUnlock() +func (t *CronTool) addJob(ctx context.Context, args 
map[string]any) *ToolResult { + channel := ToolChannel(ctx) + chatID := ToolChatID(ctx) if channel == "" || chatID == "" { return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 15ef4ff73..438ceeddd 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -3,15 +3,14 @@ package tools import ( "context" "fmt" + "sync/atomic" ) type SendCallback func(channel, chatID, content string) error type MessageTool struct { - sendCallback SendCallback - defaultChannel string - defaultChatID string - sentInRound bool // Tracks whether a message was sent in the current processing round + sendCallback SendCallback + sentInRound atomic.Bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -47,15 +46,15 @@ func (t *MessageTool) Parameters() map[string]any { } } -func (t *MessageTool) SetContext(channel, chatID string) { - t.defaultChannel = channel - t.defaultChatID = chatID - t.sentInRound = false // Reset send tracking for new processing round +// ResetSentInRound resets the per-round send tracker. +// Called by the agent loop at the start of each inbound message processing round. +func (t *MessageTool) ResetSentInRound() { + t.sentInRound.Store(false) } // HasSentInRound returns true if the message tool sent a message during the current round. 
func (t *MessageTool) HasSentInRound() bool { - return t.sentInRound + return t.sentInRound.Load() } func (t *MessageTool) SetSendCallback(callback SendCallback) { @@ -72,10 +71,10 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes chatID, _ := args["chat_id"].(string) if channel == "" { - channel = t.defaultChannel + channel = ToolChannel(ctx) } if chatID == "" { - chatID = t.defaultChatID + chatID = ToolChatID(ctx) } if channel == "" || chatID == "" { @@ -94,7 +93,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes } } - t.sentInRound = true + t.sentInRound.Store(true) // Silent: user already received the message directly return &ToolResult{ ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 717c1117b..05630972e 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -8,7 +8,6 @@ import ( func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") var sentChannel, sentChatID, sentContent string tool.SetSendCallback(func(channel, chatID, content string) error { @@ -18,7 +17,7 @@ func TestMessageTool_Execute_Success(t *testing.T) { return nil }) - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Hello, world!", } @@ -60,7 +59,6 @@ func TestMessageTool_Execute_Success(t *testing.T) { func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() - tool.SetContext("default-channel", "default-chat-id") var sentChannel, sentChatID string tool.SetSendCallback(func(channel, chatID, content string) error { @@ -69,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { return nil }) - ctx := context.Background() + ctx := WithToolContext(context.Background(), "default-channel", "default-chat-id") args := 
map[string]any{ "content": "Test message", "channel": "custom-channel", @@ -96,14 +94,13 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") sendErr := errors.New("network error") tool.SetSendCallback(func(channel, chatID, content string) error { return sendErr }) - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Test message", } @@ -133,9 +130,8 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { func TestMessageTool_Execute_MissingContent(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{} // content missing result := tool.Execute(ctx, args) @@ -151,7 +147,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) { func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() - // No SetContext called, so defaultChannel and defaultChatID are empty + // No WithToolContext — channel/chatID are empty tool.SetSendCallback(func(channel, chatID, content string) error { return nil @@ -175,10 +171,9 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { func TestMessageTool_Execute_NotConfigured(t *testing.T) { tool := NewMessageTool() - tool.SetContext("test-channel", "test-chat-id") // No SetSendCallback called - ctx := context.Background() + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") args := map[string]any{ "content": "Test message", } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 0ba983e02..ca8436c67 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -45,8 +45,9 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string } // 
ExecuteWithContext executes a tool with channel/chatID context and optional async callback. -// If the tool implements AsyncTool and a non-nil callback is provided, -// the callback will be set on the tool before execution. +// If the tool implements AsyncExecutor and a non-nil callback is provided, +// ExecuteAsync is called instead of Execute — the callback is a parameter, +// never stored as mutable state on the tool. func (r *ToolRegistry) ExecuteWithContext( ctx context.Context, name string, @@ -69,22 +70,23 @@ func (r *ToolRegistry) ExecuteWithContext( return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } - // If tool implements ContextualTool, set context - if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { - contextualTool.SetContext(channel, chatID) - } + // Inject channel/chatID into ctx so tools read them via ToolChannel(ctx)/ToolChatID(ctx). + // Always inject — tools validate what they require. + ctx = WithToolContext(ctx, channel, chatID) - // If tool implements AsyncTool and callback is provided, set callback - if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { - asyncTool.SetCallback(asyncCallback) - logger.DebugCF("tool", "Async callback injected", + // If tool implements AsyncExecutor and callback is provided, use ExecuteAsync. + // The callback is a call parameter, not mutable state on the tool instance. 
+ var result *ToolResult + start := time.Now() + if asyncExec, ok := tool.(AsyncExecutor); ok && asyncCallback != nil { + logger.DebugCF("tool", "Executing async tool via ExecuteAsync", map[string]any{ "tool": name, }) + result = asyncExec.ExecuteAsync(ctx, args, asyncCallback) + } else { + result = tool.Execute(ctx, args) } - - start := time.Now() - result := tool.Execute(ctx, args) duration := time.Since(start) // Log based on result type diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go index 8fe88ca78..92d7d5abd 100644 --- a/pkg/tools/registry_test.go +++ b/pkg/tools/registry_test.go @@ -25,24 +25,24 @@ func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolRes return m.result } -type mockCtxTool struct { +type mockContextAwareTool struct { mockRegistryTool - channel string - chatID string + lastCtx context.Context } -func (m *mockCtxTool) SetContext(channel, chatID string) { - m.channel = channel - m.chatID = chatID +func (m *mockContextAwareTool) Execute(ctx context.Context, _ map[string]any) *ToolResult { + m.lastCtx = ctx + return m.result } type mockAsyncRegistryTool struct { mockRegistryTool - cb AsyncCallback + lastCB AsyncCallback } -func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) { - m.cb = cb +func (m *mockAsyncRegistryTool) ExecuteAsync(_ context.Context, args map[string]any, cb AsyncCallback) *ToolResult { + m.lastCB = cb + return m.result } // --- helpers --- @@ -136,34 +136,44 @@ func TestToolRegistry_Execute_NotFound(t *testing.T) { } } -func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) { +func TestToolRegistry_ExecuteWithContext_InjectsToolContext(t *testing.T) { r := NewToolRegistry() - ct := &mockCtxTool{ + ct := &mockContextAwareTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } r.Register(ct) r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil) - if ct.channel != "telegram" { - t.Errorf("expected channel 'telegram', 
got %q", ct.channel) + if ct.lastCtx == nil { + t.Fatal("expected Execute to be called") } - if ct.chatID != "chat-42" { - t.Errorf("expected chatID 'chat-42', got %q", ct.chatID) + if got := ToolChannel(ct.lastCtx); got != "telegram" { + t.Errorf("expected channel 'telegram', got %q", got) + } + if got := ToolChatID(ct.lastCtx); got != "chat-42" { + t.Errorf("expected chatID 'chat-42', got %q", got) } } -func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) { +func TestToolRegistry_ExecuteWithContext_EmptyContext(t *testing.T) { r := NewToolRegistry() - ct := &mockCtxTool{ + ct := &mockContextAwareTool{ mockRegistryTool: *newMockTool("ctx_tool", "needs context"), } r.Register(ct) r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil) - if ct.channel != "" || ct.chatID != "" { - t.Error("SetContext should not be called with empty channel/chatID") + if ct.lastCtx == nil { + t.Fatal("expected Execute to be called") + } + // Empty values are still injected; tools decide what to do with them. 
+ if got := ToolChannel(ct.lastCtx); got != "" { + t.Errorf("expected empty channel, got %q", got) + } + if got := ToolChatID(ct.lastCtx); got != "" { + t.Errorf("expected empty chatID, got %q", got) } } @@ -179,14 +189,14 @@ func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) { cb := func(_ context.Context, _ *ToolResult) { called = true } result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb) - if at.cb == nil { - t.Error("expected SetCallback to have been called") + if at.lastCB == nil { + t.Error("expected ExecuteAsync to have received a callback") } if !result.Async { t.Error("expected async result") } - at.cb(context.Background(), SilentResult("done")) + at.lastCB(context.Background(), SilentResult("done")) if !called { t.Error("expected callback to be invoked") } diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 8b166b41f..be40ffda2 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -8,25 +8,18 @@ import ( type SpawnTool struct { manager *SubagentManager - originChannel string - originChatID string allowlistCheck func(targetAgentID string) bool - callback AsyncCallback // For async completion notification } +// Compile-time check: SpawnTool implements AsyncExecutor. 
+var _ AsyncExecutor = (*SpawnTool)(nil) + func NewSpawnTool(manager *SubagentManager) *SpawnTool { return &SpawnTool{ - manager: manager, - originChannel: "cli", - originChatID: "direct", + manager: manager, } } -// SetCallback implements AsyncTool interface for async completion notification -func (t *SpawnTool) SetCallback(cb AsyncCallback) { - t.callback = cb -} - func (t *SpawnTool) Name() string { return "spawn" } @@ -56,16 +49,21 @@ func (t *SpawnTool) Parameters() map[string]any { } } -func (t *SpawnTool) SetContext(channel, chatID string) { - t.originChannel = channel - t.originChatID = chatID -} - func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { t.allowlistCheck = check } func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + return t.execute(ctx, args, nil) +} + +// ExecuteAsync implements AsyncExecutor. The callback is passed through to the +// subagent manager as a call parameter — never stored on the SpawnTool instance. +func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { + return t.execute(ctx, args, cb) +} + +func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult { task, ok := args["task"].(string) if !ok || strings.TrimSpace(task) == "" { return ErrorResult("task is required and must be a non-empty string") @@ -85,8 +83,20 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul return ErrorResult("Subagent manager not configured") } + // Read channel/chatID from context (injected by registry). + // Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests) + // to preserve the same defaults as the original NewSpawnTool constructor. 
+ channel := ToolChannel(ctx) + if channel == "" { + channel = "cli" + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = "direct" + } + // Pass callback to manager for async completion notification - result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback) + result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb) if err != nil { return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 69f1a49a2..429340047 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -252,16 +252,12 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { // Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion // and returns the result directly in the ToolResult. type SubagentTool struct { - manager *SubagentManager - originChannel string - originChatID string + manager *SubagentManager } func NewSubagentTool(manager *SubagentManager) *SubagentTool { return &SubagentTool{ - manager: manager, - originChannel: "cli", - originChatID: "direct", + manager: manager, } } @@ -290,11 +286,6 @@ func (t *SubagentTool) Parameters() map[string]any { } } -func (t *SubagentTool) SetContext(channel, chatID string) { - t.originChannel = channel - t.originChatID = chatID -} - func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult { task, ok := args["task"].(string) if !ok { @@ -341,13 +332,24 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe } } + // Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests) + // to preserve the same defaults as the original NewSubagentTool constructor. 
+ channel := ToolChannel(ctx) + if channel == "" { + channel = "cli" + } + chatID := ToolChatID(ctx) + if chatID == "" { + chatID = "direct" + } + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ Provider: sm.provider, Model: sm.defaultModel, Tools: tools, MaxIterations: maxIter, LLMOptions: llmOptions, - }, messages, t.originChannel, t.originChatID) + }, messages, channel, chatID) if err != nil { return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) } diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go index 59bfdffae..a1450410a 100644 --- a/pkg/tools/subagent_tool_test.go +++ b/pkg/tools/subagent_tool_test.go @@ -50,9 +50,8 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) { manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) manager.SetLLMOptions(2048, 0.6) tool := NewSubagentTool(manager) - tool.SetContext("cli", "direct") - ctx := context.Background() + ctx := WithToolContext(context.Background(), "cli", "direct") args := map[string]any{"task": "Do something"} result := tool.Execute(ctx, args) @@ -147,28 +146,14 @@ func TestSubagentTool_Parameters(t *testing.T) { } } -// TestSubagentTool_SetContext verifies context setting -func TestSubagentTool_SetContext(t *testing.T) { - provider := &MockLLMProvider{} - manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) - tool := NewSubagentTool(manager) - - tool.SetContext("test-channel", "test-chat") - - // Verify context is set (we can't directly access private fields, - // but we can verify it doesn't crash) - // The actual context usage is tested in Execute tests -} - // TestSubagentTool_Execute_Success tests successful execution func TestSubagentTool_Execute_Success(t *testing.T) { provider := &MockLLMProvider{} msgBus := bus.NewMessageBus() manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) tool := NewSubagentTool(manager) - tool.SetContext("telegram", "chat-123") - 
ctx := context.Background() + ctx := WithToolContext(context.Background(), "telegram", "chat-123") args := map[string]any{ "task": "Write a haiku about coding", "label": "haiku-task", @@ -297,12 +282,9 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) { manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) tool := NewSubagentTool(manager) - // Set context channel := "test-channel" chatID := "test-chat" - tool.SetContext(channel, chatID) - - ctx := context.Background() + ctx := WithToolContext(context.Background(), channel, chatID) args := map[string]any{ "task": "Test context passing", } diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index cdfe0d6ce..244f0d4a2 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -121,37 +122,53 @@ func RunToolLoop( } messages = append(messages, assistantMsg) - // 7. Execute tool calls - for _, tc := range normalizedToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]any{ - "tool": tc.Name, - "iteration": iteration, - }) + // 7. 
Execute tool calls in parallel + type indexedResult struct { + result *ToolResult + tc providers.ToolCall + } - // Execute tool (no async callback for subagents - they run independently) - var toolResult *ToolResult - if config.Tools != nil { - toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) - } else { - toolResult = ErrorResult("No tools available") + results := make([]indexedResult, len(normalizedToolCalls)) + var wg sync.WaitGroup + + for i, tc := range normalizedToolCalls { + results[i].tc = tc + + wg.Add(1) + go func(idx int, tc providers.ToolCall) { + defer wg.Done() + + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + results[idx].result = toolResult + }(i, tc) + } + wg.Wait() + + // Append results in original order + for _, r := range results { + contentForLLM := r.result.ForLLM + if contentForLLM == "" && r.result.Err != nil { + contentForLLM = r.result.Err.Error() } - // Determine content for LLM - contentForLLM := toolResult.ForLLM - if contentForLLM == "" && toolResult.Err != nil { - contentForLLM = toolResult.Err.Error() - } - - // Add tool result message - toolResultMsg := providers.Message{ + messages = append(messages, providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) + ToolCallID: r.tc.ID, + }) } } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 10498126b..eeceabd98 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -109,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query 
string, count in return "", fmt.Errorf("failed to read response: %w", err) } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body)) + } + var searchResp struct { Web struct { Results []struct { @@ -391,6 +395,150 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil } +type SearXNGSearchProvider struct { + baseURL string +} + +func (p *SearXNGSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("%s/search?q=%s&format=json&categories=general", + strings.TrimSuffix(p.baseURL, "/"), + url.QueryEscape(query)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("SearXNG returned status %d", resp.StatusCode) + } + + var result struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` + Engine string `json:"engine"` + Score float64 `json:"score"` + } `json:"results"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + // Limit results to requested count + if len(result.Results) > count { + result.Results = result.Results[:count] + } + + // Format results in standard PicoClaw format + var b strings.Builder + b.WriteString(fmt.Sprintf("Results for: %s (via SearXNG)\n", query)) + for i, r := range result.Results { + 
b.WriteString(fmt.Sprintf("%d. %s\n", i+1, r.Title)) + b.WriteString(fmt.Sprintf(" %s\n", r.URL)) + if r.Content != "" { + b.WriteString(fmt.Sprintf(" %s\n", r.Content)) + } + } + + return b.String(), nil +} + +type GLMSearchProvider struct { + apiKey string + baseURL string + searchEngine string + proxy string + client *http.Client +} + +func (p *GLMSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := p.baseURL + if searchURL == "" { + searchURL = "https://open.bigmodel.cn/api/paas/v4/web_search" + } + + payload := map[string]any{ + "search_query": query, + "search_engine": p.searchEngine, + "search_intent": false, + "count": count, + "content_size": "medium", + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewReader(bodyBytes)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.apiKey) + + resp, err := p.client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("GLM Search API error (status %d): %s", resp.StatusCode, string(body)) + } + + var searchResp struct { + SearchResult []struct { + Title string `json:"title"` + Content string `json:"content"` + Link string `json:"link"` + } `json:"search_result"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.SearchResult + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines 
[]string + lines = append(lines, fmt.Sprintf("Results for: %s (via GLM Search)", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.Link)) + if item.Content != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Content)) + } + } + + return strings.Join(lines, "\n"), nil +} + type WebSearchTool struct { provider SearchProvider maxResults int @@ -409,6 +557,14 @@ type WebSearchToolOptions struct { PerplexityAPIKey string PerplexityMaxResults int PerplexityEnabled bool + SearXNGBaseURL string + SearXNGMaxResults int + SearXNGEnabled bool + GLMSearchAPIKey string + GLMSearchBaseURL string + GLMSearchEngine string + GLMSearchMaxResults int + GLMSearchEnabled bool Proxy string } @@ -416,7 +572,7 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 - // Priority: Perplexity > Brave > Tavily > DuckDuckGo + // Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo > GLM Search if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { client, err := createHTTPClient(opts.Proxy, perplexityTimeout) if err != nil { @@ -435,6 +591,11 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } + } else if opts.SearXNGEnabled && opts.SearXNGBaseURL != "" { + provider = &SearXNGSearchProvider{baseURL: opts.SearXNGBaseURL} + if opts.SearXNGMaxResults > 0 { + maxResults = opts.SearXNGMaxResults + } } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" { client, err := createHTTPClient(opts.Proxy, searchTimeout) if err != nil { @@ -458,6 +619,25 @@ func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { if opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } + } else if opts.GLMSearchEnabled && opts.GLMSearchAPIKey != "" { + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != 
nil { + return nil, fmt.Errorf("failed to create HTTP client for GLM Search: %w", err) + } + searchEngine := opts.GLMSearchEngine + if searchEngine == "" { + searchEngine = "search_std" + } + provider = &GLMSearchProvider{ + apiKey: opts.GLMSearchAPIKey, + baseURL: opts.GLMSearchBaseURL, + searchEngine: searchEngine, + proxy: opts.Proxy, + client: client, + } + if opts.GLMSearchMaxResults > 0 { + maxResults = opts.GLMSearchMaxResults + } } else { return nil, nil } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 8a8b88131..bdd30d385 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -681,3 +681,135 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser) } } + +func TestWebTool_GLMSearch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + } + if r.Header.Get("Authorization") != "Bearer test-glm-key" { + t.Errorf("Expected Authorization Bearer test-glm-key, got %s", r.Header.Get("Authorization")) + } + + var payload map[string]any + json.NewDecoder(r.Body).Decode(&payload) + if payload["search_query"] != "test query" { + t.Errorf("Expected search_query 'test query', got %v", payload["search_query"]) + } + if payload["search_engine"] != "search_std" { + t.Errorf("Expected search_engine 'search_std', got %v", payload["search_engine"]) + } + + response := map[string]any{ + "id": "web-search-test", + "created": 1709568000, + "search_result": []map[string]any{ + { + "title": "Test GLM Result", + "content": "GLM search snippet", + "link": "https://example.com/glm", + "media": "Example", + "publish_date": "2026-03-04", + }, + }, + } + w.Header().Set("Content-Type", 
"application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-glm-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + if !strings.Contains(result.ForUser, "Test GLM Result") { + t.Errorf("Expected 'Test GLM Result' in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "https://example.com/glm") { + t.Errorf("Expected URL in output, got: %s", result.ForUser) + } + if !strings.Contains(result.ForUser, "via GLM Search") { + t.Errorf("Expected 'via GLM Search' in output, got: %s", result.ForUser) + } +} + +func TestWebTool_GLMSearch_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error":"invalid api key"}`)) + })) + defer server.Close() + + tool, err := NewWebSearchTool(WebSearchToolOptions{ + GLMSearchEnabled: true, + GLMSearchAPIKey: "bad-key", + GLMSearchBaseURL: server.URL, + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + result := tool.Execute(context.Background(), map[string]any{ + "query": "test query", + }) + + if !result.IsError { + t.Errorf("Expected IsError=true for 401 response") + } + if !strings.Contains(result.ForLLM, "status 401") { + t.Errorf("Expected status 401 in error, got: %s", result.ForLLM) + } +} + +func TestWebTool_GLMSearch_Priority(t *testing.T) { + // GLM Search should only be selected when all other providers are disabled + tool, err := NewWebSearchTool(WebSearchToolOptions{ + 
DuckDuckGoEnabled: true, + DuckDuckGoMaxResults: 5, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + + // DuckDuckGo should win over GLM Search + if _, ok := tool.provider.(*DuckDuckGoSearchProvider); !ok { + t.Errorf("Expected DuckDuckGoSearchProvider when both enabled, got %T", tool.provider) + } + + // With DuckDuckGo disabled, GLM Search should be selected + tool2, err := NewWebSearchTool(WebSearchToolOptions{ + DuckDuckGoEnabled: false, + GLMSearchEnabled: true, + GLMSearchAPIKey: "test-key", + GLMSearchBaseURL: "https://example.com", + GLMSearchEngine: "search_std", + }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } + if _, ok := tool2.provider.(*GLMSearchProvider); !ok { + t.Errorf("Expected GLMSearchProvider when only GLM enabled, got %T", tool2.provider) + } +} diff --git a/pkg/utils/media.go b/pkg/utils/media.go index a34889fb8..3e1c5d88e 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -3,6 +3,7 @@ package utils import ( "io" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -52,11 +53,12 @@ type DownloadOptions struct { Timeout time.Duration ExtraHeaders map[string]string LoggerPrefix string + ProxyURL string } // DownloadFile downloads a file from URL to a local temp directory. // Returns the local file path or empty string on error. 
-func DownloadFile(url, filename string, opts DownloadOptions) string { +func DownloadFile(urlStr, filename string, opts DownloadOptions) string { // Set defaults if opts.Timeout == 0 { opts.Timeout = 60 * time.Second @@ -78,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest("GET", urlStr, nil) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{ "error": err.Error(), @@ -92,11 +94,24 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } client := &http.Client{Timeout: opts.Timeout} + if opts.ProxyURL != "" { + proxyURL, parseErr := url.Parse(opts.ProxyURL) + if parseErr != nil { + logger.ErrorCF(opts.LoggerPrefix, "Invalid proxy URL for download", map[string]any{ + "error": parseErr.Error(), + "proxy": opts.ProxyURL, + }) + return "" + } + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } resp, err := client.Do(req) if err != nil { logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{ "error": err.Error(), - "url": url, + "url": urlStr, }) return "" } @@ -105,7 +120,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { if resp.StatusCode != http.StatusOK { logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{ "status": resp.StatusCode, - "url": url, + "url": urlStr, }) return "" } diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index f973e77fe..e949d7a22 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -10,12 +10,19 @@ import ( "net/http" "os" "path/filepath" + "strings" "time" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) +type Transcriber interface { + Name() string 
+ Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) +} + type GroqTranscriber struct { apiKey string apiBase string @@ -152,8 +159,22 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) return &result, nil } -func (t *GroqTranscriber) IsAvailable() bool { - available := t.apiKey != "" - logger.DebugCF("voice", "Checking transcriber availability", map[string]any{"available": available}) - return available +func (t *GroqTranscriber) Name() string { + return "groq" +} + +// DetectTranscriber inspects cfg and returns the appropriate Transcriber, or +// nil if no supported transcription provider is configured. +func DetectTranscriber(cfg *config.Config) Transcriber { + // Direct Groq provider config takes priority. + if key := cfg.Providers.Groq.APIKey; key != "" { + return NewGroqTranscriber(key) + } + // Fall back to any model-list entry that uses the groq/ protocol. + for _, mc := range cfg.ModelList { + if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { + return NewGroqTranscriber(mc.APIKey) + } + } + return nil } diff --git a/pkg/voice/transcriber_test.go b/pkg/voice/transcriber_test.go new file mode 100644 index 000000000..9b6add333 --- /dev/null +++ b/pkg/voice/transcriber_test.go @@ -0,0 +1,160 @@ +package voice + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// Ensure GroqTranscriber satisfies the Transcriber interface at compile time. 
+var _ Transcriber = (*GroqTranscriber)(nil) + +func TestGroqTranscriberName(t *testing.T) { + tr := NewGroqTranscriber("sk-test") + if got := tr.Name(); got != "groq" { + t.Errorf("Name() = %q, want %q", got, "groq") + } +} + +func TestDetectTranscriber(t *testing.T) { + tests := []struct { + name string + cfg *config.Config + wantNil bool + wantName string + }{ + { + name: "no config", + cfg: &config.Config{}, + wantNil: true, + }, + { + name: "groq provider key", + cfg: &config.Config{ + Providers: config.ProvidersConfig{ + Groq: config.ProviderConfig{APIKey: "sk-groq-direct"}, + }, + }, + wantName: "groq", + }, + { + name: "groq via model list", + cfg: &config.Config{ + ModelList: []config.ModelConfig{ + {Model: "openai/gpt-4o", APIKey: "sk-openai"}, + {Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"}, + }, + }, + wantName: "groq", + }, + { + name: "groq model list entry without key is skipped", + cfg: &config.Config{ + ModelList: []config.ModelConfig{ + {Model: "groq/llama-3.3-70b", APIKey: ""}, + }, + }, + wantNil: true, + }, + { + name: "provider key takes priority over model list", + cfg: &config.Config{ + Providers: config.ProvidersConfig{ + Groq: config.ProviderConfig{APIKey: "sk-groq-direct"}, + }, + ModelList: []config.ModelConfig{ + {Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"}, + }, + }, + wantName: "groq", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tr := DetectTranscriber(tc.cfg) + if tc.wantNil { + if tr != nil { + t.Errorf("DetectTranscriber() = %v, want nil", tr) + } + return + } + if tr == nil { + t.Fatal("DetectTranscriber() = nil, want non-nil") + } + if got := tr.Name(); got != tc.wantName { + t.Errorf("Name() = %q, want %q", got, tc.wantName) + } + }) + } +} + +func TestTranscribe(t *testing.T) { + // Write a minimal fake audio file so the transcriber can open and send it. 
+ tmpDir := t.TempDir() + audioPath := filepath.Join(tmpDir, "clip.ogg") + if err := os.WriteFile(audioPath, []byte("fake-audio-data"), 0o644); err != nil { + t.Fatalf("failed to write fake audio file: %v", err) + } + + t.Run("success", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/audio/transcriptions" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer sk-test" { + t.Errorf("unexpected Authorization header: %s", r.Header.Get("Authorization")) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(TranscriptionResponse{ + Text: "hello world", + Language: "en", + Duration: 1.5, + }) + })) + defer srv.Close() + + tr := NewGroqTranscriber("sk-test") + tr.apiBase = srv.URL + + resp, err := tr.Transcribe(context.Background(), audioPath) + if err != nil { + t.Fatalf("Transcribe() error: %v", err) + } + if resp.Text != "hello world" { + t.Errorf("Text = %q, want %q", resp.Text, "hello world") + } + if resp.Language != "en" { + t.Errorf("Language = %q, want %q", resp.Language, "en") + } + }) + + t.Run("api error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, `{"error":"invalid_api_key"}`, http.StatusUnauthorized) + })) + defer srv.Close() + + tr := NewGroqTranscriber("sk-bad") + tr.apiBase = srv.URL + + _, err := tr.Transcribe(context.Background(), audioPath) + if err == nil { + t.Fatal("expected error for non-200 response, got nil") + } + }) + + t.Run("missing file", func(t *testing.T) { + tr := NewGroqTranscriber("sk-test") + _, err := tr.Transcribe(context.Background(), filepath.Join(tmpDir, "nonexistent.ogg")) + if err == nil { + t.Fatal("expected error for missing file, got nil") + } + }) +}