diff --git a/README.fr.md b/README.fr.md index 08a1926b6..574402a3e 100644 --- a/README.fr.md +++ b/README.fr.md @@ -649,7 +649,6 @@ PicoClaw stocke les données dans votre workspace configuré (par défaut : `~/. ├── HEARTBEAT.md # Invites de tâches périodiques (vérifiées toutes les 30 min) ├── IDENTITY.md # Identité de l'Agent ├── SOUL.md # Âme de l'Agent -├── TOOLS.md # Description des outils └── USER.md # Préférences utilisateur ``` @@ -980,6 +979,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obtenir Clé](https://cerebras.ai) | | **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obtenir Clé](https://console.volcengine.com) | | **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.ja.md b/README.ja.md index c4c5b27a0..1eb47cfdc 100644 --- a/README.ja.md +++ b/README.ja.md @@ -610,7 +610,6 @@ PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw ├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認) ├── IDENTITY.md # エージェントのアイデンティティ ├── SOUL.md # エージェントのソウル -├── TOOLS.md # ツールの説明 └── USER.md # ユーザー設定 ``` @@ -921,6 +920,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [キーを取得](https://cerebras.ai) | | **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [キーを取得](https://console.volcengine.com) | | **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) | | **Antigravity** | 
`antigravity/` | Google Cloud | カスタム | OAuthのみ | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.md b/README.md index bae3fa681..55e9fb187 100644 --- a/README.md +++ b/README.md @@ -787,7 +787,6 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa ├── HEARTBEAT.md # Periodic task prompts (checked every 30 min) ├── IDENTITY.md # Agent identity ├── SOUL.md # Agent soul -├── TOOLS.md # Tool descriptions └── USER.md # User preferences ``` @@ -1034,6 +1033,7 @@ This design also enables **multi-agent support** with flexible provider selectio | **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) | | **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | | **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | @@ -1504,3 +1504,4 @@ This happens when another instance of the bot is running. Make sure only one `pi | **SearXNG** | Unlimited (self-hosted) | Privacy-focused metasearch (70+ engines) | | **Groq** | Free tier available | Fast inference (Llama, Mixtral) | | **Cerebras** | Free tier available | Fast inference (Llama, Qwen, etc.) 
| +| **LongCat** | Up to 5M tokens/day | Fast inference (free tier) | diff --git a/README.pt-br.md b/README.pt-br.md index 5f37ba457..066d71d6a 100644 --- a/README.pt-br.md +++ b/README.pt-br.md @@ -645,7 +645,6 @@ O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/worksp ├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min) ├── IDENTITY.md # Identidade do Agente ├── SOUL.md # Alma do Agente -├── TOOLS.md # Descrição das ferramentas └── USER.md # Preferencias do usuario ``` @@ -976,6 +975,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Obter Chave](https://cerebras.ai) | | **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Obter Chave](https://console.volcengine.com) | | **ShengsuanYun** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) | | **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.vi.md b/README.vi.md index 92c6ecbae..66573a1c5 100644 --- a/README.vi.md +++ b/README.vi.md @@ -617,7 +617,6 @@ PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: ├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút) ├── IDENTITY.md # Danh tính Agent ├── SOUL.md # Tâm hồn/Tính cách Agent -├── TOOLS.md # Mô tả công cụ └── USER.md # Tùy chọn người dùng ``` @@ -945,6 +944,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Lấy Khóa](https://cerebras.ai) | | **Volcengine** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Lấy Khóa](https://console.volcengine.com) | | **ShengsuanYun** | `shengsuanyun/` | 
`https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) | | **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | diff --git a/README.zh.md b/README.zh.md index c744e0d20..a3a4c7f5f 100644 --- a/README.zh.md +++ b/README.zh.md @@ -365,7 +365,6 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work ├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次) ├── IDENTITY.md # Agent 身份设定 ├── SOUL.md # Agent 灵魂/性格 -├── TOOLS.md # 工具描述 └── USER.md # 用户偏好 ``` @@ -517,6 +516,7 @@ Agent 读取 HEARTBEAT.md | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [获取密钥](https://cerebras.ai) | | **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [获取密钥](https://console.volcengine.com) | | **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - | +| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) | | **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth | | **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - | @@ -879,3 +879,4 @@ Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN) | **Brave Search** | 2000 次查询/月 | 网络搜索功能 | | **Tavily** | 1000 次查询/月 | AI Agent 搜索优化 | | **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) | +| **LongCat** | 最多 5M tokens/天 | 推理速度快 (免费额度) | diff --git a/assets/wechat.png b/assets/wechat.png index 4442ef2c7..4cfcbbb1a 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw-launcher-tui/internal/ui/app.go b/cmd/picoclaw-launcher-tui/internal/ui/app.go index a2ccddf70..f26b6125c 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/app.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/app.go @@ -325,7 +325,7 @@ func (s *appState) viewGatewayLog() { } func (s *appState) 
selectedModelName() string { - modelName := strings.TrimSpace(s.config.Agents.Defaults.Model) + modelName := strings.TrimSpace(s.config.Agents.Defaults.GetModelName()) if modelName == "" { return "" } @@ -413,7 +413,7 @@ func (s *appState) isGatewayRunning() bool { } func (s *appState) validateAgentModel() error { - modelName := strings.TrimSpace(s.config.Agents.Defaults.Model) + modelName := strings.TrimSpace(s.config.Agents.Defaults.GetModelName()) if modelName == "" { return nil } @@ -422,7 +422,7 @@ func (s *appState) validateAgentModel() error { } func (s *appState) isActiveModelValid() bool { - modelName := strings.TrimSpace(s.config.Agents.Defaults.Model) + modelName := strings.TrimSpace(s.config.Agents.Defaults.GetModelName()) if modelName == "" { return false } diff --git a/cmd/picoclaw-launcher-tui/internal/ui/model.go b/cmd/picoclaw-launcher-tui/internal/ui/model.go index 47ca5a355..93069ac7b 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/model.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/model.go @@ -15,7 +15,7 @@ import ( func (s *appState) modelMenu() tview.Primitive { items := make([]MenuItem, 0, 1+len(s.config.ModelList)) - currentModel := strings.TrimSpace(s.config.Agents.Defaults.Model) + currentModel := strings.TrimSpace(s.config.Agents.Defaults.ModelName) for i := range s.config.ModelList { index := i model := s.config.ModelList[i] @@ -77,9 +77,9 @@ func (s *appState) modelMenu() tview.Primitive { ) return nil } - s.config.Agents.Defaults.Model = model.ModelName + s.config.Agents.Defaults.ModelName = model.ModelName s.dirty = true - refreshModelMenu(menu, s.config.Agents.Defaults.Model, s.config.ModelList) + refreshModelMenu(menu, s.config.Agents.Defaults.GetModelName(), s.config.ModelList) refreshMainMenuIfPresent(s) } return nil @@ -105,8 +105,8 @@ func (s *appState) modelForm(index int) tview.Primitive { } oldName := model.ModelName model.ModelName = value - if s.config.Agents.Defaults.Model == oldName { - s.config.Agents.Defaults.Model 
= value + if s.config.Agents.Defaults.ModelName == oldName { + s.config.Agents.Defaults.ModelName = value } s.dirty = true form.SetTitle(fmt.Sprintf("Model: %s", model.ModelName)) @@ -258,7 +258,7 @@ func refreshModelMenu(menu *Menu, currentModel string, models []picoclawconfig.M func refreshModelMenuFromState(menu *Menu, s *appState) { items := make([]MenuItem, 0, 1+len(s.config.ModelList)) - currentModel := strings.TrimSpace(s.config.Agents.Defaults.Model) + currentModel := strings.TrimSpace(s.config.Agents.Defaults.ModelName) for i := range s.config.ModelList { index := i model := s.config.ModelList[i] diff --git a/cmd/picoclaw/internal/auth/helpers.go b/cmd/picoclaw/internal/auth/helpers.go index a0a229167..02c78cf4e 100644 --- a/cmd/picoclaw/internal/auth/helpers.go +++ b/cmd/picoclaw/internal/auth/helpers.go @@ -56,9 +56,6 @@ func authLoginOpenAI(useDeviceCode bool) error { appCfg, err := internal.LoadConfig() if err == nil { - // Update Providers (legacy format) - appCfg.Providers.OpenAI.AuthMethod = "oauth" - // Update or add openai in ModelList foundOpenAI := false for i := range appCfg.ModelList { @@ -130,9 +127,6 @@ func authLoginGoogleAntigravity() error { appCfg, err := internal.LoadConfig() if err == nil { - // Update Providers (legacy format, for backward compatibility) - appCfg.Providers.Antigravity.AuthMethod = "oauth" - // Update or add antigravity in ModelList foundAntigravity := false for i := range appCfg.ModelList { @@ -210,8 +204,6 @@ func authLoginAnthropicSetupToken() error { appCfg, err := internal.LoadConfig() if err == nil { - appCfg.Providers.Anthropic.AuthMethod = "oauth" - found := false for i := range appCfg.ModelList { if isAnthropicModel(appCfg.ModelList[i].Model) { @@ -287,7 +279,6 @@ func authLoginPasteToken(provider string) error { if err == nil { switch provider { case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "token" // Update ModelList found := false for i := range appCfg.ModelList { @@ -306,7 +297,6 @@ func 
authLoginPasteToken(provider string) error { appCfg.Agents.Defaults.ModelName = defaultAnthropicModel } case "openai": - appCfg.Providers.OpenAI.AuthMethod = "token" // Update ModelList found := false for i := range appCfg.ModelList { @@ -365,15 +355,6 @@ func authLogoutCmd(provider string) error { } } } - // Clear AuthMethod in Providers (legacy) - switch provider { - case "openai": - appCfg.Providers.OpenAI.AuthMethod = "" - case "anthropic": - appCfg.Providers.Anthropic.AuthMethod = "" - case "google-antigravity", "antigravity": - appCfg.Providers.Antigravity.AuthMethod = "" - } config.SaveConfig(internal.GetConfigPath(), appCfg) } @@ -392,10 +373,6 @@ func authLogoutCmd(provider string) error { for i := range appCfg.ModelList { appCfg.ModelList[i].AuthMethod = "" } - // Clear all AuthMethods in Providers (legacy) - appCfg.Providers.OpenAI.AuthMethod = "" - appCfg.Providers.Anthropic.AuthMethod = "" - appCfg.Providers.Antigravity.AuthMethod = "" config.SaveConfig(internal.GetConfigPath(), appCfg) } diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go index e04bccffb..120b740d8 100644 --- a/cmd/picoclaw/internal/helpers.go +++ b/cmd/picoclaw/internal/helpers.go @@ -4,19 +4,20 @@ import ( "os" "path/filepath" + "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/config" ) -const Logo = "🦞" +const Logo = pkg.Logo // GetPicoclawHome returns the picoclaw home directory. 
// Priority: $PICOCLAW_HOME > ~/.picoclaw func GetPicoclawHome() string { - if home := os.Getenv("PICOCLAW_HOME"); home != "" { + if home := os.Getenv(pkg.PicoClawHome); home != "" { return home } home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw") + return filepath.Join(home, pkg.DefaultPicoClawHome) } func GetConfigPath() string { diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go index 583751781..6e5123152 100644 --- a/cmd/picoclaw/internal/helpers_test.go +++ b/cmd/picoclaw/internal/helpers_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/sipeed/picoclaw/pkg" ) func TestGetConfigPath(t *testing.T) { @@ -20,7 +22,7 @@ func TestGetConfigPath(t *testing.T) { } func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) { - t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv(pkg.PicoClawHome, "/custom/picoclaw") t.Setenv("HOME", "/tmp/home") got := GetConfigPath() @@ -31,7 +33,7 @@ func TestGetConfigPath_WithPICOCLAW_HOME(t *testing.T) { func TestGetConfigPath_WithPICOCLAW_CONFIG(t *testing.T) { t.Setenv("PICOCLAW_CONFIG", "/custom/config.json") - t.Setenv("PICOCLAW_HOME", "/custom/picoclaw") + t.Setenv(pkg.PicoClawHome, "/custom/picoclaw") t.Setenv("HOME", "/tmp/home") got := GetConfigPath() diff --git a/cmd/picoclaw/internal/status/helpers.go b/cmd/picoclaw/internal/status/helpers.go index dd7063fe6..43c5786a8 100644 --- a/cmd/picoclaw/internal/status/helpers.go +++ b/cmd/picoclaw/internal/status/helpers.go @@ -42,48 +42,6 @@ func statusCmd() { if _, err := os.Stat(configPath); err == nil { fmt.Printf("Model: %s\n", cfg.Agents.Defaults.GetModelName()) - hasOpenRouter := cfg.Providers.OpenRouter.APIKey != "" - hasAnthropic := cfg.Providers.Anthropic.APIKey != "" - hasOpenAI := cfg.Providers.OpenAI.APIKey != "" - hasGemini := cfg.Providers.Gemini.APIKey != "" - hasZhipu := cfg.Providers.Zhipu.APIKey != "" - hasQwen := 
cfg.Providers.Qwen.APIKey != "" - hasGroq := cfg.Providers.Groq.APIKey != "" - hasVLLM := cfg.Providers.VLLM.APIBase != "" - hasMoonshot := cfg.Providers.Moonshot.APIKey != "" - hasDeepSeek := cfg.Providers.DeepSeek.APIKey != "" - hasVolcEngine := cfg.Providers.VolcEngine.APIKey != "" - hasNvidia := cfg.Providers.Nvidia.APIKey != "" - hasOllama := cfg.Providers.Ollama.APIBase != "" - - status := func(enabled bool) string { - if enabled { - return "✓" - } - return "not set" - } - fmt.Println("OpenRouter API:", status(hasOpenRouter)) - fmt.Println("Anthropic API:", status(hasAnthropic)) - fmt.Println("OpenAI API:", status(hasOpenAI)) - fmt.Println("Gemini API:", status(hasGemini)) - fmt.Println("Zhipu API:", status(hasZhipu)) - fmt.Println("Qwen API:", status(hasQwen)) - fmt.Println("Groq API:", status(hasGroq)) - fmt.Println("Moonshot API:", status(hasMoonshot)) - fmt.Println("DeepSeek API:", status(hasDeepSeek)) - fmt.Println("VolcEngine API:", status(hasVolcEngine)) - fmt.Println("Nvidia API:", status(hasNvidia)) - if hasVLLM { - fmt.Printf("vLLM/Local: ✓ %s\n", cfg.Providers.VLLM.APIBase) - } else { - fmt.Println("vLLM/Local: not set") - } - if hasOllama { - fmt.Printf("Ollama: ✓ %s\n", cfg.Providers.Ollama.APIBase) - } else { - fmt.Println("Ollama: not set") - } - store, _ := auth.LoadStore() if store != nil && len(store.Credentials) > 0 { fmt.Println("\nOAuth/Token Auth:") diff --git a/config/config.example.json b/config/config.example.json index 49658b9f2..1eea37683 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -35,6 +35,11 @@ "model": "deepseek/deepseek-chat", "api_key": "sk-your-deepseek-key" }, + { + "model_name": "longcat", + "model": "longcat/LongCat-Flash-Thinking", + "api_key": "your-longcat-api-key" + }, { "model_name": "loadbalanced-gpt4", "model": "openai/gpt-5.2", @@ -274,6 +279,10 @@ "avian": { "api_key": "", "api_base": "https://api.avian.io/v1" + }, + "longcat": { + "api_key": "", + "api_base": 
"https://api.longcat.chat/openai" } }, "tools": { @@ -477,6 +486,9 @@ "enabled": false, "monitor_usb": true }, + "voice": { + "echo_transcription": false + }, "gateway": { "host": "127.0.0.1", "port": 18790 diff --git a/docs/channels/matrix/README.md b/docs/channels/matrix/README.md index c213aa80b..233f5c0a3 100644 --- a/docs/channels/matrix/README.md +++ b/docs/channels/matrix/README.md @@ -22,7 +22,8 @@ Add this to `config.json`: "enabled": true, "text": "Thinking..." }, - "reasoning_channel_id": "" + "reasoning_channel_id": "", + "message_format": "richtext" } } } @@ -42,10 +43,12 @@ Add this to `config.json`: | group_trigger | object | No | Group trigger strategy (`mention_only` / `prefixes`) | | placeholder | object | No | Placeholder message config | | reasoning_channel_id | string | No | Target channel for reasoning output | +| message_format | string | No | Output format: `"richtext"` (default) renders markdown as HTML; `"plain"` sends plain text only | ## 3. Currently Supported -- Text message send/receive +- Text message send/receive with markdown rendering (bold, italic, headers, code blocks, etc.) +- Configurable message format (`richtext` / `plain`) - Incoming image/audio/video/file download (MediaStore first, local path fallback) - Incoming audio normalization into existing transcription flow (`[audio: ...]`) - Outgoing image/audio/video/file upload and send diff --git a/docs/config-versioning.md b/docs/config-versioning.md new file mode 100644 index 000000000..36d7fdd25 --- /dev/null +++ b/docs/config-versioning.md @@ -0,0 +1,230 @@ +# Config Schema Versioning Guide + +## Overview + +PicoClaw uses a schema versioning system for `config.json` to ensure smooth upgrades as the configuration format evolves. 
+ +## Version History + +### Version 1 +- **Introduction**: Initial version with version field support +- **Changes**: Added `version` field to Config struct +- **Migration**: No structural changes needed for existing configs + +## How It Works + +### Automatic Migration +When you load a config file: +1. The system first reads the `version` field from the JSON +2. Based on the detected version, it loads the appropriate config struct (`ConfigV0`, `ConfigV1`, etc.) +3. If the loaded version is less than the latest, migrations are applied incrementally +4. The version number is updated automatically +5. The migrated config is automatically saved back to disk + +### Version Field +The `version` field in `config.json` indicates the schema version: +- `0` or missing: Legacy config (no version field) +- `1`: Current version with versioning support + +```json +{ + "version": 1, + "agents": {...}, + ... +} +``` + +## Adding a New Migration + +When making breaking changes to the config schema: + +### Step 1: Define the New Version Struct + +Create a new struct for the new version if the structure changes significantly: + +```go +// ConfigV2 represents version 2 config structure +type ConfigV2 struct { + Version int `json:"version"` + Agents AgentsConfig `json:"agents"` + // ... other fields with new structure +} +``` + +### Step 2: Update Current Config Version + +```go +const CurrentConfigVersion = 2 // Increment this +``` + +### Step 3: Add a Loader Function + +```go +// loadConfigV2 loads a version 2 config +func loadConfigV2(data []byte) (*Config, error) { + cfg := DefaultConfig() + + // Parse to ConfigV2 struct + var v2 ConfigV2 + if err := json.Unmarshal(data, &v2); err != nil { + return nil, err + } + + // Convert to current Config + cfg.Version = v2.Version + cfg.Agents = v2.Agents + // ... 
map other fields + + return cfg, nil +} +``` + +### Step 4: Add Migration Logic + +```go +// applyMigration applies a single migration step from fromVersion to toVersion +func applyMigration(cfg *Config, fromVersion, toVersion int) (*Config, error) { + switch toVersion { + case 1: + // Migration from version 0 to 1 + return &Config{ + Version: 1, + Agents: cfg.Agents, + // ... copy all fields + }, nil + case 2: + // Migration from version 1 to 2 + // Example: Move or rename fields + migrated := *cfg + migrated.Version = 2 + // Apply structural changes + if cfg.SomeOldField != "" { + migrated.SomeNewField = cfg.SomeOldField + } + return &migrated, nil + default: + return nil, fmt.Errorf("unsupported migration target version: %d", toVersion) + } +} +``` + +### Step 5: Update LoadConfig Switch + +```go +func LoadConfig(path string) (*Config, error) { + // ... read file ... + + switch versionInfo.Version { + case 0: + cfg, err = loadConfigV0(data) + case 1: + cfg, err = loadConfigV1(data) + case 2: + cfg, err = loadConfigV2(data) + default: + return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version) + } + + // ... migrate and validate ... +} +``` + +### Step 6: Test Your Migration + +Create a test in `config_migration_test.go`: + +```go +func TestMigrateV1ToV2(t *testing.T) { + // Create a version 1 config + v1Config := Config{ + Version: 1, + // ... set up test data + } + + // Apply migration + migrated, err := applyMigration(&v1Config, 1, 2) + if err != nil { + t.Fatalf("Migration failed: %v", err) + } + + // Verify version is updated + if migrated.Version != 2 { + t.Errorf("Expected version 2, got %d", migrated.Version) + } + + // Verify data is preserved/transformed correctly + // ... +} +``` + +## Migration Best Practices + +1. **Version-Specific Structs**: Define a separate struct for each version that has structural changes +2. **Backward Compatibility**: Ensure old configs can still be loaded with their specific structs +3. 
**No Data Loss**: Migrations should preserve all user settings +4. **Idempotent**: Running the same migration multiple times should be safe +5. **Auto-Save**: Migrated configs are automatically saved to update the user's file +6. **Test Thoroughly**: Test with real user config files +7. **Update Defaults**: Keep `defaults.go` in sync with the latest schema + +## Example Migration + +### Scenario: Adding a new field with default value + +Old config (version 1): +```json +{ + "version": 1, + "agents": { + "defaults": { + "max_tokens": 32768 + } + } +} +``` + +Migration to version 2: +```go +case 2: + migrated := *cfg + migrated.Version = 2 + + // Add new field with default value if not set + if migrated.Agents.Defaults.NewFeatureEnabled == false { + // Use default value + } + + return &migrated, nil +``` + +New config (version 2): +```json +{ + "version": 2, + "agents": { + "defaults": { + "max_tokens": 32768, + "new_feature_enabled": false + } + } +} +``` + +## Troubleshooting + +### Config Not Upgrading +- Check that `CurrentConfigVersion` is incremented +- Verify migration logic in `applyMigration()` handles the target version +- Ensure `migrateConfig()` is called in `LoadConfig()` + +### Migration Errors +- Check error messages for specific migration failures +- Review migration logic for edge cases +- Ensure all required fields are properly initialized +- Verify the loader function for the source version + +### Data Loss After Migration +- Ensure all fields are copied during migration +- Check that the migration doesn't overwrite values with defaults unnecessarily +- Review the conversion logic in the loader functions + diff --git a/go.mod b/go.mod index f60be046f..3762015e9 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/ergochat/irc-go v0.5.0 github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 + github.com/gomarkdown/markdown v0.0.0-20260217112301-37c66b85d6ab github.com/gorilla/websocket v1.5.3 github.com/h2non/filetype 
v1.1.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 @@ -20,6 +21,7 @@ require ( github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 github.com/rivo/tview v0.42.0 + github.com/rs/zerolog v1.34.0 github.com/slack-go/slack v0.17.3 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 @@ -49,7 +51,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rs/zerolog v1.34.0 // indirect github.com/segmentio/asm v1.1.3 // indirect github.com/segmentio/encoding v0.5.3 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/go.sum b/go.sum index 4060997f8..2e2b1a1ec 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/gomarkdown/markdown v0.0.0-20260217112301-37c66b85d6ab h1:VYNivV7P8IRHUam2swVUNkhIdp0LRRFKe4hXNnoZKTc= +github.com/gomarkdown/markdown v0.0.0-20260217112301-37c66b85d6ab/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 5a84c45e2..dd030d1b1 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" 
"github.com/sipeed/picoclaw/pkg/providers" @@ -52,14 +53,14 @@ func (cb *ContextBuilder) WithToolDiscovery(useBM25, useRegex bool) *ContextBuil } func getGlobalConfigDir() string { - if home := os.Getenv("PICOCLAW_HOME"); home != "" { + if home := os.Getenv(pkg.PicoClawHome); home != "" { return home } home, err := os.UserHomeDir() if err != nil { return "" } - return filepath.Join(home, ".picoclaw") + return filepath.Join(home, pkg.DefaultPicoClawHome) } func NewContextBuilder(workspace string) *ContextBuilder { diff --git a/pkg/agent/instance_test.go b/pkg/agent/instance_test.go index 4f41ecd1c..335e236a0 100644 --- a/pkg/agent/instance_test.go +++ b/pkg/agent/instance_test.go @@ -18,7 +18,7 @@ func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 1234, MaxToolIterations: 5, }, @@ -50,7 +50,7 @@ func TestNewAgentInstance_DefaultsTemperatureWhenZero(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 1234, MaxToolIterations: 5, }, @@ -79,7 +79,7 @@ func TestNewAgentInstance_DefaultsTemperatureWhenUnset(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 1234, MaxToolIterations: 5, }, @@ -133,7 +133,7 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: tt.aliasName, + ModelName: tt.aliasName, }, }, ModelList: []config.ModelConfig{ diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 235d42fcc..28e549ce0 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -25,7 +25,6 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" 
"github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/mcp" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" @@ -48,6 +47,7 @@ type AgentLoop struct { mediaStore media.MediaStore transcriber voice.Transcriber cmdRegistry *commands.Registry + mcp mcpRuntime } // processOptions configures how a message is processed @@ -239,119 +239,8 @@ func registerSharedTools( func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) - - // Initialize MCP servers for all agents - if al.cfg.Tools.IsToolEnabled("mcp") { - mcpManager := mcp.NewManager() - // Ensure MCP connections are cleaned up on exit, regardless of initialization success - // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails - defer func() { - if err := mcpManager.Close(); err != nil { - logger.ErrorCF("agent", "Failed to close MCP manager", - map[string]any{ - "error": err.Error(), - }) - } - }() - - defaultAgent := al.registry.GetDefaultAgent() - var workspacePath string - if defaultAgent != nil && defaultAgent.Workspace != "" { - workspacePath = defaultAgent.Workspace - } else { - workspacePath = al.cfg.WorkspacePath() - } - - if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil { - logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", - map[string]any{ - "error": err.Error(), - }) - } else { - // Register MCP tools for all agents - servers := mcpManager.GetServers() - uniqueTools := 0 - totalRegistrations := 0 - agentIDs := al.registry.ListAgentIDs() - agentCount := len(agentIDs) - - for serverName, conn := range servers { - uniqueTools += len(conn.Tools) - for _, tool := range conn.Tools { - for _, agentID := range agentIDs { - agent, ok := al.registry.GetAgent(agentID) - if !ok { - continue - } - - mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) - - if al.cfg.Tools.MCP.Discovery.Enabled { - 
agent.Tools.RegisterHidden(mcpTool) - } else { - agent.Tools.Register(mcpTool) - } - - totalRegistrations++ - logger.DebugCF("agent", "Registered MCP tool", - map[string]any{ - "agent_id": agentID, - "server": serverName, - "tool": tool.Name, - "name": mcpTool.Name(), - }) - } - } - } - logger.InfoCF("agent", "MCP tools registered successfully", - map[string]any{ - "server_count": len(servers), - "unique_tools": uniqueTools, - "total_registrations": totalRegistrations, - "agent_count": agentCount, - }) - - // Initializes Discovery Tools only if enabled by configuration - if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled { - useBM25 := al.cfg.Tools.MCP.Discovery.UseBM25 - useRegex := al.cfg.Tools.MCP.Discovery.UseRegex - - // Fail fast: If discovery is enabled but no search method is turned on - if !useBM25 && !useRegex { - return fmt.Errorf( - "tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration", - ) - } - - ttl := al.cfg.Tools.MCP.Discovery.TTL - if ttl <= 0 { - ttl = 5 // Default value - } - - maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults - if maxSearchResults <= 0 { - maxSearchResults = 5 // Default value - } - - logger.InfoCF("agent", "Initializing tool discovery", map[string]any{ - "bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults, - }) - - for _, agentID := range agentIDs { - agent, ok := al.registry.GetAgent(agentID) - if !ok { - continue - } - - if useRegex { - agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults)) - } - if useBM25 { - agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults)) - } - } - } - } + if err := al.ensureMCPInitialized(ctx); err != nil { + return err } for al.running.Load() { @@ -431,6 +320,17 @@ func (al *AgentLoop) Stop() { // Close releases resources held by agent session stores. Call after Stop. 
func (al *AgentLoop) Close() { + mcpManager := al.mcp.takeManager() + + if mcpManager != nil { + if err := mcpManager.Close(); err != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": err.Error(), + }) + } + } + al.registry.Close() } @@ -467,9 +367,10 @@ var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) // transcribeAudioInMessage resolves audio media refs, transcribes them, and // replaces audio annotations in msg.Content with the transcribed text. -func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) bus.InboundMessage { +// Returns the (possibly modified) message and true if audio was transcribed. +func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.InboundMessage) (bus.InboundMessage, bool) { if al.transcriber == nil || al.mediaStore == nil || len(msg.Media) == 0 { - return msg + return msg, false } // Transcribe each audio media ref in order. @@ -493,9 +394,11 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou } if len(transcriptions) == 0 { - return msg + return msg, false } + al.sendTranscriptionFeedback(ctx, msg.Channel, msg.ChatID, msg.MessageID, transcriptions) + // Replace audio annotations sequentially with transcriptions. idx := 0 newContent := audioAnnotationRe.ReplaceAllStringFunc(msg.Content, func(match string) string { @@ -513,7 +416,48 @@ func (al *AgentLoop) transcribeAudioInMessage(ctx context.Context, msg bus.Inbou } msg.Content = newContent - return msg + return msg, true +} + +// sendTranscriptionFeedback sends feedback to the user with the result of +// audio transcription if the option is enabled. It uses Manager.SendMessage +// which executes synchronously (rate limiting, splitting, retry) so that +// ordering with the subsequent placeholder is guaranteed. 
+func (al *AgentLoop) sendTranscriptionFeedback( + ctx context.Context, + channel, chatID, messageID string, + validTexts []string, +) { + if !al.cfg.Voice.EchoTranscription { + return + } + if al.channelManager == nil { + return + } + + var nonEmpty []string + for _, t := range validTexts { + if t != "" { + nonEmpty = append(nonEmpty, t) + } + } + + var feedbackMsg string + if len(nonEmpty) > 0 { + feedbackMsg = "Transcript: " + strings.Join(nonEmpty, "\n") + } else { + feedbackMsg = "No voice detected in the audio" + } + + err := al.channelManager.SendMessage(ctx, bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: feedbackMsg, + ReplyToMessageID: messageID, + }) + if err != nil { + logger.WarnCF("voice", "Failed to send transcription feedback", map[string]any{"error": err.Error()}) + } } // inferMediaType determines the media type ("image", "audio", "video", "file") @@ -575,6 +519,10 @@ func (al *AgentLoop) ProcessDirectWithChannel( ctx context.Context, content, sessionKey, channel, chatID string, ) (string, error) { + if err := al.ensureMCPInitialized(ctx); err != nil { + return "", err + } + msg := bus.InboundMessage{ Channel: channel, SenderID: "cron", @@ -627,7 +575,14 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }, ) - msg = al.transcribeAudioInMessage(ctx, msg) + var hadAudio bool + msg, hadAudio = al.transcribeAudioInMessage(ctx, msg) + + // For audio messages the placeholder was deferred by the channel. + // Now that transcription (and optional feedback) is done, send it. 
+ if hadAudio && al.channelManager != nil { + al.channelManager.SendPlaceholder(ctx, msg.Channel, msg.ChatID) + } // Route system messages to processSystemMessage if msg.Channel == "system" { diff --git a/pkg/agent/loop_mcp.go b/pkg/agent/loop_mcp.go new file mode 100644 index 000000000..2795db52a --- /dev/null +++ b/pkg/agent/loop_mcp.go @@ -0,0 +1,184 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "context" + "fmt" + "sync" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/mcp" + "github.com/sipeed/picoclaw/pkg/tools" +) + +type mcpRuntime struct { + initOnce sync.Once + mu sync.Mutex + manager *mcp.Manager + initErr error +} + +func (r *mcpRuntime) setManager(manager *mcp.Manager) { + r.mu.Lock() + r.manager = manager + r.initErr = nil + r.mu.Unlock() +} + +func (r *mcpRuntime) setInitErr(err error) { + r.mu.Lock() + r.initErr = err + r.mu.Unlock() +} + +func (r *mcpRuntime) getInitErr() error { + r.mu.Lock() + defer r.mu.Unlock() + return r.initErr +} + +func (r *mcpRuntime) takeManager() *mcp.Manager { + r.mu.Lock() + defer r.mu.Unlock() + manager := r.manager + r.manager = nil + return manager +} + +func (r *mcpRuntime) hasManager() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r.manager != nil +} + +// ensureMCPInitialized loads MCP servers/tools once so both Run() and direct +// agent mode share the same initialization path. 
+func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error { + if !al.cfg.Tools.IsToolEnabled("mcp") { + return nil + } + + al.mcp.initOnce.Do(func() { + mcpManager := mcp.NewManager() + + defaultAgent := al.registry.GetDefaultAgent() + workspacePath := al.cfg.WorkspacePath() + if defaultAgent != nil && defaultAgent.Workspace != "" { + workspacePath = defaultAgent.Workspace + } + + if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil { + logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", + map[string]any{ + "error": err.Error(), + }) + if closeErr := mcpManager.Close(); closeErr != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": closeErr.Error(), + }) + } + return + } + + // Register MCP tools for all agents + servers := mcpManager.GetServers() + uniqueTools := 0 + totalRegistrations := 0 + agentIDs := al.registry.ListAgentIDs() + agentCount := len(agentIDs) + + for serverName, conn := range servers { + uniqueTools += len(conn.Tools) + for _, tool := range conn.Tools { + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok { + continue + } + + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) + + if al.cfg.Tools.MCP.Discovery.Enabled { + agent.Tools.RegisterHidden(mcpTool) + } else { + agent.Tools.Register(mcpTool) + } + + totalRegistrations++ + logger.DebugCF("agent", "Registered MCP tool", + map[string]any{ + "agent_id": agentID, + "server": serverName, + "tool": tool.Name, + "name": mcpTool.Name(), + }) + } + } + } + logger.InfoCF("agent", "MCP tools registered successfully", + map[string]any{ + "server_count": len(servers), + "unique_tools": uniqueTools, + "total_registrations": totalRegistrations, + "agent_count": agentCount, + }) + + // Initializes Discovery Tools only if enabled by configuration + if al.cfg.Tools.MCP.Enabled && al.cfg.Tools.MCP.Discovery.Enabled { + useBM25 := 
al.cfg.Tools.MCP.Discovery.UseBM25 + useRegex := al.cfg.Tools.MCP.Discovery.UseRegex + + // Fail fast: If discovery is enabled but no search method is turned on + if !useBM25 && !useRegex { + al.mcp.setInitErr(fmt.Errorf( + "tool discovery is enabled but neither 'use_bm25' nor 'use_regex' is set to true in the configuration", + )) + if closeErr := mcpManager.Close(); closeErr != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": closeErr.Error(), + }) + } + return + } + + ttl := al.cfg.Tools.MCP.Discovery.TTL + if ttl <= 0 { + ttl = 5 // Default value + } + + maxSearchResults := al.cfg.Tools.MCP.Discovery.MaxSearchResults + if maxSearchResults <= 0 { + maxSearchResults = 5 // Default value + } + + logger.InfoCF("agent", "Initializing tool discovery", map[string]any{ + "bm25": useBM25, "regex": useRegex, "ttl": ttl, "max_results": maxSearchResults, + }) + + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok { + continue + } + + if useRegex { + agent.Tools.Register(tools.NewRegexSearchTool(agent.Tools, ttl, maxSearchResults)) + } + if useBM25 { + agent.Tools.Register(tools.NewBM25SearchTool(agent.Tools, ttl, maxSearchResults)) + } + } + } + + al.mcp.setManager(mcpManager) + }) + + return al.mcp.getInitErr() +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 2e456fa60..6f90c6155 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -42,7 +42,7 @@ func newTestAgentLoop( Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -101,7 +101,7 @@ func TestNewAgentLoop_StateInitialized(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -137,7 +137,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) { 
Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -194,7 +194,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -230,7 +230,7 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Defaults.Workspace = tmpDir - cfg.Agents.Defaults.Model = "test-model" + cfg.Agents.Defaults.ModelName = "test-model" cfg.Agents.Defaults.MaxTokens = 4096 cfg.Agents.Defaults.MaxToolIterations = 10 @@ -274,7 +274,7 @@ func TestAgentLoop_Stop(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -394,7 +394,7 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -450,7 +450,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -530,7 +530,7 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { Defaults: config.AgentDefaults{ Workspace: tmpDir, Provider: "openai", - Model: "before-switch", + ModelName: "before-switch", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -587,7 +587,7 @@ func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -629,7 
+629,7 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -700,7 +700,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -770,6 +770,56 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { } } +func TestProcessDirectWithChannel_InitializesMCPInAgentMode(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Tools: config.ToolsConfig{ + MCP: config.MCPConfig{ + ToolConfig: config.ToolConfig{ + Enabled: true, + }, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + defer al.Close() + + if al.mcp.hasManager() { + t.Fatal("expected MCP manager to be nil before first direct processing") + } + + _, err = al.ProcessDirectWithChannel( + context.Background(), + "hello", + "session-1", + "cli", + "direct", + ) + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed: %v", err) + } + + if !al.mcp.hasManager() { + t.Fatal("expected MCP manager to be initialized in direct agent mode") + } +} + func TestTargetReasoningChannelID_AllChannels(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -781,7 +831,7 @@ func TestTargetReasoningChannelID_AllChannels(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: 
"test-model", MaxTokens: 4096, MaxToolIterations: 10, }, @@ -851,7 +901,7 @@ func TestHandleReasoning(t *testing.T) { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: tmpDir, - Model: "test-model", + ModelName: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go index 518bb441f..b173ef967 100644 --- a/pkg/agent/registry_test.go +++ b/pkg/agent/registry_test.go @@ -29,7 +29,7 @@ func testCfg(agents []config.AgentConfig) *config.Config { Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: "/tmp/picoclaw-test-registry", - Model: "gpt-4", + ModelName: "gpt-4", MaxTokens: 8192, MaxToolIterations: 10, }, diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 2e55d4877..dff011ee2 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -6,6 +6,7 @@ import ( "path/filepath" "time" + "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/fileutil" ) @@ -39,11 +40,11 @@ func (c *AuthCredential) NeedsRefresh() bool { } func authFilePath() string { - if home := os.Getenv("PICOCLAW_HOME"); home != "" { + if home := os.Getenv(pkg.PicoClawHome); home != "" { return filepath.Join(home, "auth.json") } home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw", "auth.json") + return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json") } func LoadStore() (*AuthStore, error) { diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 7ad8f0417..12da3f1dd 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -30,9 +30,10 @@ type InboundMessage struct { } type OutboundMessage struct { - Channel string `json:"channel"` - ChatID string `json:"chat_id"` - Content string `json:"content"` + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Content string `json:"content"` + ReplyToMessageID string `json:"reply_to_message_id,omitempty"` } // MediaPart describes a single media attachment to send. 
diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 063a66523..edb5b6f08 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/binary" "encoding/hex" + "regexp" "strconv" "strings" "sync/atomic" @@ -32,6 +33,9 @@ func init() { uniqueIDPrefix = hex.EncodeToString(b[:]) } +// audioAnnotationRe matches audio/voice annotations injected by channels (e.g. [voice], [audio: file.ogg]). +var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) + // uniqueID generates a process-unique ID using a random prefix and an atomic counter. // This ID is intended for internal correlation (e.g. media scope keys) and is NOT // cryptographically secure — it must not be used in contexts where unpredictability matters. @@ -284,10 +288,15 @@ func (c *BaseChannel) HandleMessage( c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) } } - // Placeholder — independent pipeline - if pc, ok := c.owner.(PlaceholderCapable); ok { - if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { - c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + // Placeholder — independent pipeline. + // Skip when the message contains audio: the agent will send the + // placeholder after transcription completes, so the user sees + // "Thinking…" only once the voice has been processed. 
+ if !audioAnnotationRe.MatchString(content) { + if pc, ok := c.owner.(PlaceholderCapable); ok { + if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { + c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + } } } } diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index 8642ad362..c03122892 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -10,6 +10,7 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + dinglog "github.com/open-dingtalk/dingtalk-stream-sdk-go/logger" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -39,6 +40,9 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } + // Set the logger for the Stream SDK + dinglog.SetLogger(logger.NewLogger("dingtalk")) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(20000), channels.WithGroupTrigger(cfg.GroupTrigger), diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index c3bcbff8d..83a04907c 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -45,6 +45,14 @@ type DiscordChannel struct { } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { + discordgo.Logger = logger.NewLogger("discord"). 
+ WithLevels(map[int]logger.LogLevel{ + discordgo.LogError: logger.ERROR, + discordgo.LogWarning: logger.WARN, + discordgo.LogInformational: logger.INFO, + discordgo.LogDebug: logger.DEBUG, + }).Log + session, err := discordgo.New("Bot " + cfg.Token) if err != nil { return nil, fmt.Errorf("failed to create discord session: %w", err) @@ -134,7 +142,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return nil } - return c.sendChunk(ctx, channelID, msg.Content) + return c.sendChunk(ctx, channelID, msg.Content, msg.ReplyToMessageID) } // SendMedia implements the channels.MediaSender interface. @@ -259,14 +267,29 @@ func (c *DiscordChannel) SendPlaceholder(ctx context.Context, chatID string) (st return msg.ID, nil } -func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { +func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content, replyToID string) error { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) defer cancel() done := make(chan error, 1) go func() { - _, err := c.session.ChannelMessageSend(channelID, content) + var err error + + // If we have an ID, we send the message as "Reply" + if replyToID != "" { + _, err = c.session.ChannelMessageSendComplex(channelID, &discordgo.MessageSend{ + Content: content, + Reference: &discordgo.MessageReference{ + MessageID: replyToID, + ChannelID: channelID, + }, + }) + } else { + // Otherwise, we send a normal message + _, err = c.session.ChannelMessageSend(channelID, content) + } + done <- err }() diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 1a24bb980..472895a7a 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -102,6 +102,27 @@ func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { m.placeholders.Store(key, placeholderEntry{id: placeholderID, createdAt: time.Now()}) } +// SendPlaceholder sends a "Thinking…" placeholder for 
the given channel/chatID +// and records it for later editing. Returns true if a placeholder was sent. +func (m *Manager) SendPlaceholder(ctx context.Context, channel, chatID string) bool { + m.mu.RLock() + ch, ok := m.channels[channel] + m.mu.RUnlock() + if !ok { + return false + } + pc, ok := ch.(PlaceholderCapable) + if !ok { + return false + } + phID, err := pc.SendPlaceholder(ctx, chatID) + if err != nil || phID == "" { + return false + } + m.RecordPlaceholder(channel, chatID, phID) + return true +} + // RecordTypingStop registers a typing stop function for later invocation. // Implements PlaceholderRecorder. func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { @@ -813,6 +834,39 @@ func (m *Manager) UnregisterChannel(name string) { delete(m.channels, name) } +// SendMessage sends an outbound message synchronously through the channel +// worker's rate limiter and retry logic. It blocks until the message is +// delivered (or all retries are exhausted), which preserves ordering when +// a subsequent operation depends on the message having been sent. 
+func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error { + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + return fmt.Errorf("channel %s not found", msg.Channel) + } + if !wExists || w == nil { + return fmt.Errorf("channel %s has no active worker", msg.Channel) + } + + maxLen := 0 + if mlp, ok := w.ch.(MessageLengthProvider); ok { + maxLen = mlp.MaxMessageLength() + } + if maxLen > 0 && len([]rune(msg.Content)) > maxLen { + for _, chunk := range SplitMessage(msg.Content, maxLen) { + chunkMsg := msg + chunkMsg.Content = chunk + m.sendWithRetry(ctx, msg.Channel, w, chunkMsg) + } + } else { + m.sendWithRetry(ctx, msg.Channel, w, msg) + } + return nil +} + func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() _, exists := m.channels[channelName] diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index f09ecfe2f..1f3a628c2 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -17,16 +17,32 @@ import ( // mockChannel is a test double that delegates Send to a configurable function. 
type mockChannel struct { BaseChannel - sendFn func(ctx context.Context, msg bus.OutboundMessage) error + sendFn func(ctx context.Context, msg bus.OutboundMessage) error + sentMessages []bus.OutboundMessage + placeholdersSent int + editedMessages int + lastPlaceholderID string } func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + m.sentMessages = append(m.sentMessages, msg) return m.sendFn(ctx, msg) } func (m *mockChannel) Start(ctx context.Context) error { return nil } func (m *mockChannel) Stop(ctx context.Context) error { return nil } +func (m *mockChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + m.placeholdersSent++ + m.lastPlaceholderID = "mock-ph-123" + return m.lastPlaceholderID, nil +} + +func (m *mockChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + m.editedMessages++ + return nil +} + // newTestManager creates a minimal Manager suitable for unit tests. func newTestManager() *Manager { return &Manager{ @@ -860,3 +876,286 @@ func TestBuildMediaScope_WithMessageID(t *testing.T) { t.Fatalf("expected %s, got %s", expected, scope) } } + +func TestManager_PlaceholderConsumedByResponse(t *testing.T) { + mgr := &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + placeholders: sync.Map{}, + } + + mockCh := &mockChannel{ + sendFn: func(ctx context.Context, msg bus.OutboundMessage) error { + return nil + }, + } + worker := newChannelWorker("mock", mockCh) + mgr.channels["mock"] = mockCh + mgr.workers["mock"] = worker + + ctx := context.Background() + key := "mock:chat-1" + + // Simulate a placeholder recorded by base.go HandleMessage + mgr.RecordPlaceholder("mock", "chat-1", "ph-123") + + if _, ok := mgr.placeholders.Load(key); !ok { + t.Fatal("expected placeholder to be recorded") + } + + // Transcription feedback arrives first — it should consume the placeholder + // and be delivered via EditMessage, not Send. 
+ msgTranscript := bus.OutboundMessage{ + Channel: "mock", + ChatID: "chat-1", + Content: "Transcript: hello", + } + mgr.sendWithRetry(ctx, "mock", worker, msgTranscript) + + if mockCh.editedMessages != 1 { + t.Errorf("expected 1 edited message (placeholder consumed by transcript), got %d", mockCh.editedMessages) + } + if len(mockCh.sentMessages) != 0 { + t.Errorf("expected 0 normal messages (transcript used edit), got %d", len(mockCh.sentMessages)) + } + + // Placeholder should be gone now + if _, ok := mgr.placeholders.Load(key); ok { + t.Error("expected placeholder to be removed after being consumed") + } + + // Final LLM response arrives — no placeholder left, so it goes through Send + msgFinal := bus.OutboundMessage{ + Channel: "mock", + ChatID: "chat-1", + Content: "Final Answer", + } + mgr.sendWithRetry(ctx, "mock", worker, msgFinal) + + if len(mockCh.sentMessages) != 1 { + t.Errorf("expected 1 normal message sent, got %d", len(mockCh.sentMessages)) + } +} + +func TestSendMessage_Synchronous(t *testing.T) { + m := newTestManager() + + var received []bus.OutboundMessage + ch := &mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + received = append(received, msg) + return nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "hello world", + ReplyToMessageID: "msg-456", + } + + err := m.SendMessage(context.Background(), msg) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // SendMessage is synchronous — message should already be delivered + if len(received) != 1 { + t.Fatalf("expected 1 message sent, got %d", len(received)) + } + if received[0].ReplyToMessageID != "msg-456" { + t.Fatalf("expected ReplyToMessageID msg-456, got %s", received[0].ReplyToMessageID) + } + if received[0].Content != "hello world" { + t.Fatalf("expected content 'hello 
world', got %s", received[0].Content) + } +} + +func TestSendMessage_UnknownChannel(t *testing.T) { + m := newTestManager() + + msg := bus.OutboundMessage{ + Channel: "nonexistent", + ChatID: "123", + Content: "hello", + } + + err := m.SendMessage(context.Background(), msg) + if err == nil { + t.Fatal("expected error for unknown channel") + } +} + +func TestSendMessage_NoWorker(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { return nil }, + } + m.channels["test"] = ch + // No worker registered + + msg := bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "hello", + } + + err := m.SendMessage(context.Background(), msg) + if err == nil { + t.Fatal("expected error when no worker exists") + } +} + +func TestSendMessage_WithRetry(t *testing.T) { + m := newTestManager() + + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("transient: %w", ErrTemporary) + } + return nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "retry me", + } + + err := m.SendMessage(context.Background(), msg) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (1 failure + 1 success), got %d", callCount) + } +} + +func TestSendMessage_WithSplitting(t *testing.T) { + m := newTestManager() + + var received []string + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + received = append(received, msg.Content) + return nil + }, + }, + maxLen: 5, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + 
msg := bus.OutboundMessage{ + Channel: "test", + ChatID: "123", + Content: "hello world", + } + + err := m.SendMessage(context.Background(), msg) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(received) < 2 { + t.Fatalf("expected message to be split into at least 2 chunks, got %d", len(received)) + } +} + +func TestSendMessage_PreservesOrdering(t *testing.T) { + m := newTestManager() + + var order []string + ch := &mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + order = append(order, msg.Content) + return nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + // Send two messages sequentially — they must arrive in order + _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + Channel: "test", ChatID: "1", Content: "first", + }) + _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + Channel: "test", ChatID: "1", Content: "second", + }) + + if len(order) != 2 { + t.Fatalf("expected 2 messages, got %d", len(order)) + } + if order[0] != "first" || order[1] != "second" { + t.Fatalf("expected [first, second], got %v", order) + } +} + +func TestManager_SendPlaceholder(t *testing.T) { + mgr := &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + placeholders: sync.Map{}, + } + + mockCh := &mockChannel{ + sendFn: func(ctx context.Context, msg bus.OutboundMessage) error { + return nil + }, + } + mgr.channels["mock"] = mockCh + + ctx := context.Background() + + // SendPlaceholder should send a placeholder and record it + ok := mgr.SendPlaceholder(ctx, "mock", "chat-1") + if !ok { + t.Fatal("expected SendPlaceholder to succeed") + } + if mockCh.placeholdersSent != 1 { + t.Errorf("expected 1 placeholder sent, got %d", mockCh.placeholdersSent) + } + + key := "mock:chat-1" + if _, loaded := mgr.placeholders.Load(key); !loaded { + t.Error("expected placeholder to be 
recorded in manager") + } + + // SendPlaceholder on unknown channel should return false + ok = mgr.SendPlaceholder(ctx, "unknown", "chat-1") + if ok { + t.Error("expected SendPlaceholder to fail for unknown channel") + } +} diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index d51eee8fb..a45207f12 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -13,6 +13,9 @@ import ( "sync" "time" + "github.com/gomarkdown/markdown" + mdhtml "github.com/gomarkdown/markdown/html" + "github.com/gomarkdown/markdown/parser" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -268,6 +271,12 @@ func (c *MatrixChannel) Stop(ctx context.Context) error { return nil } +func markdownToHTML(md string) string { + p := parser.NewWithExtensions(parser.CommonExtensions | parser.AutoHeadingIDs) + renderer := mdhtml.NewRenderer(mdhtml.RendererOptions{Flags: mdhtml.CommonFlags}) + return strings.TrimSpace(string(markdown.ToHTML([]byte(md), p, renderer))) +} + func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning @@ -283,16 +292,22 @@ func (c *MatrixChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } - _, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{ - MsgType: event.MsgText, - Body: content, - }) + _, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, c.messageContent(content)) if err != nil { return fmt.Errorf("matrix send: %w", channels.ErrTemporary) } return nil } +func (c *MatrixChannel) messageContent(text string) *event.MessageEventContent { + mc := &event.MessageEventContent{MsgType: event.MsgText, Body: text} + if c.config.MessageFormat != "plain" { + mc.Format = event.FormatHTML + mc.FormattedBody = markdownToHTML(text) + } + return mc +} + // SendMedia implements channels.MediaSender. 
func (c *MatrixChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { if !c.IsRunning() { @@ -482,10 +497,7 @@ func (c *MatrixChannel) EditMessage(ctx context.Context, chatID string, messageI return fmt.Errorf("matrix message ID is empty") } - editContent := &event.MessageEventContent{ - MsgType: event.MsgText, - Body: content, - } + editContent := c.messageContent(content) editContent.SetEdit(id.EventID(messageID)) _, err := c.client.SendMessageEvent(ctx, roomID, event.EventMessage, editContent) diff --git a/pkg/channels/matrix/matrix_test.go b/pkg/channels/matrix/matrix_test.go index e76db0d3e..806a98739 100644 --- a/pkg/channels/matrix/matrix_test.go +++ b/pkg/channels/matrix/matrix_test.go @@ -4,12 +4,15 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + + "github.com/sipeed/picoclaw/pkg/config" ) func TestMatrixLocalpartMentionRegexp(t *testing.T) { @@ -289,3 +292,50 @@ func TestMatrixOutboundContent(t *testing.T) { t.Fatalf("unexpected fallback body: %q", noCaption.Body) } } + +func TestMarkdownToHTML(t *testing.T) { + tests := []struct { + name string + input string + contains string + }{ + {"bold", "**hello**", "hello"}, + {"italic", "_world_", "world"}, + {"header", "### Title", ""}, + {"inline code", "`x`", "x"}, + {"plain text", "just text", "just text"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := markdownToHTML(tt.input) + if !strings.Contains(got, tt.contains) { + t.Fatalf("markdownToHTML(%q) = %q, want it to contain %q", tt.input, got, tt.contains) + } + }) + } +} + +func TestMessageContent(t *testing.T) { + richtext := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "richtext"}} + plain := &MatrixChannel{config: config.MatrixConfig{MessageFormat: "plain"}} + defaultt := &MatrixChannel{config: config.MatrixConfig{}} + + for _, c := range []*MatrixChannel{richtext, defaultt} { + 
mc := c.messageContent("**hi**") + if mc.Format != event.FormatHTML { + t.Errorf("format %q: expected FormatHTML, got %q", c.config.MessageFormat, mc.Format) + } + if !strings.Contains(mc.FormattedBody, "hi") { + t.Errorf("format %q: FormattedBody %q missing ", c.config.MessageFormat, mc.FormattedBody) + } + if mc.Body != "**hi**" { + t.Errorf("format %q: Body should remain plain, got %q", c.config.MessageFormat, mc.Body) + } + } + + mc := plain.messageContent("**hi**") + if mc.Format != "" || mc.FormattedBody != "" { + t.Errorf("plain: expected no formatting, got format=%q formattedBody=%q", mc.Format, mc.FormattedBody) + } +} diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 540e3b7af..73200f64e 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -78,6 +78,7 @@ func (c *QQChannel) Start(ctx context.Context) error { return fmt.Errorf("QQ app_id and app_secret not configured") } + botgo.SetLogger(logger.NewLogger("botgo")) logger.InfoC("qq", "Starting QQ bot (WebSocket mode)") // Reinitialize shutdown signal for clean restart. 
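Reviewer note on the Matrix hunks above: `messageContent` keeps the raw Markdown in `Body` and adds an HTML `FormattedBody` for every `message_format` value except `"plain"`, so the empty default behaves like rich text. The following is a minimal, dependency-free sketch of that branch, not the code from the diff. The `messageContent` struct, `buildContent`, and `render` are hypothetical stand-ins (the real code uses `event.MessageEventContent` and the gomarkdown renderer); `org.matrix.custom.html` is the format identifier defined by the Matrix specification.

```go
package main

import (
	"fmt"
	"html"
	"strings"
)

// formatHTML is the Matrix identifier for HTML-formatted message bodies.
const formatHTML = "org.matrix.custom.html"

// messageContent is a hypothetical stand-in for event.MessageEventContent.
type messageContent struct {
	MsgType       string
	Body          string
	Format        string
	FormattedBody string
}

// render is a placeholder for a real Markdown renderer (assumption): it only
// escapes the text and wraps it in a paragraph so the sketch has no dependencies.
func render(md string) string {
	return "<p>" + html.EscapeString(strings.TrimSpace(md)) + "</p>"
}

// buildContent mirrors the branch in the diff: every value except "plain",
// including the empty default, produces an HTML formatted body.
func buildContent(messageFormat, text string) messageContent {
	mc := messageContent{MsgType: "m.text", Body: text}
	if messageFormat != "plain" {
		mc.Format = formatHTML
		mc.FormattedBody = render(text)
	}
	return mc
}

func main() {
	for _, f := range []string{"", "richtext", "plain"} {
		mc := buildContent(f, "**hi**")
		fmt.Printf("%q -> format=%q formatted=%q body=%q\n", f, mc.Format, mc.FormattedBody, mc.Body)
	}
}
```

Keeping `Body` untouched means clients that ignore `formatted_body` still show readable text, which is what `TestMessageContent` checks.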
diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 024b1b023..3ee849621 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -122,7 +122,11 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error slack.MsgOptionText(msg.Content, false), } - if threadTS != "" { + if msg.ReplyToMessageID != "" && threadTS == "" { + // Answer to the message by creating a Thread under it + opts = append(opts, slack.MsgOptionTS(msg.ReplyToMessageID)) + } else if threadTS != "" { + // If we are already in a thread, continue in the thread opts = append(opts, slack.MsgOptionTS(threadTS)) } diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index b04beeb6e..34ee46b7b 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -77,6 +77,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" { opts = append(opts, telego.WithAPIServer(baseURL)) } + opts = append(opts, telego.WithLogger(logger.NewLogger("telego"))) bot, err := telego.NewBot(telegramCfg.Token, opts...) if err != nil { @@ -180,6 +181,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // The Manager already splits messages to ≤4000 chars (WithMaxMessageLength), // so msg.Content is guaranteed to be within that limit. We still need to // check if HTML expansion pushes it beyond Telegram's 4096-char API limit. 
+ replyToID := msg.ReplyToMessageID queue := []string{msg.Content} for len(queue) > 0 { chunk := queue[0] @@ -200,9 +202,11 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err continue } - if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk); err != nil { + if err := c.sendHTMLChunk(ctx, chatID, threadID, htmlContent, chunk, replyToID); err != nil { return err } + // Only the first chunk should be a reply; subsequent chunks are normal messages. + replyToID = "" } return nil @@ -211,12 +215,20 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // sendHTMLChunk sends a single HTML message, falling back to the original // markdown as plain text on parse failure so users never see raw HTML tags. func (c *TelegramChannel) sendHTMLChunk( - ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, + ctx context.Context, chatID int64, threadID int, htmlContent, mdFallback string, replyToID string, ) error { tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML tgMsg.MessageThreadID = threadID + if replyToID != "" { + if mid, parseErr := strconv.Atoi(replyToID); parseErr == nil { + tgMsg.ReplyParameters = &telego.ReplyParameters{ + MessageID: mid, + } + } + } + if _, err := c.bot.SendMessage(ctx, tgMsg); err != nil { logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{ "error": err.Error(), diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go index 7f230494f..7d07041ad 100644 --- a/pkg/channels/wecom/app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -209,7 +209,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } }) - t.Run("empty token skips verification", func(t *testing.T) { + t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { cfgEmpty := config.WeComAppConfig{ CorpID: "test_corp_id", CorpSecret: "test_secret", @@ -218,8 +218,8 @@ func 
TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") + if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should reject verification (fail-closed)") } }) } diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go index c053578b1..d223bb6b6 100644 --- a/pkg/channels/wecom/bot_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -189,8 +189,7 @@ func TestWeComBotVerifySignature(t *testing.T) { } }) - t.Run("empty token skips verification", func(t *testing.T) { - // Create a channel manually with empty token to test the behavior + t.Run("empty token rejects verification (fail-closed)", func(t *testing.T) { cfgEmpty := config.WeComConfig{ Token: "", WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", @@ -199,8 +198,8 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") + if verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should reject verification (fail-closed)") } }) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go index 6510e6f81..9a622a2fc 100644 --- a/pkg/channels/wecom/common.go +++ b/pkg/channels/wecom/common.go @@ -31,7 +31,7 @@ func computeSignature(token, timestamp, nonce, encrypt string) string { // This is a common function used by both WeCom Bot and WeCom App func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { if token == "" { - return true // Skip verification if token is not set + return false } return computeSignature(token, timestamp, nonce, msgEncrypt) == msgSignature } 
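Reviewer note on the `verifySignature` change above: returning `false` for an empty token makes verification fail closed, so a misconfigured channel rejects every callback instead of accepting all of them. The sketch below illustrates that behaviour under stated assumptions. The body of `computeSignature` is not shown in this diff, so the sorted-concatenation SHA-1 scheme here is an assumption about it, and the constant-time comparison is an optional hardening that is not part of the change.

```go
package main

import (
	"crypto/sha1"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
	"sort"
	"strings"
)

// computeSignature is an assumed stand-in for the helper in common.go:
// sort the four inputs, concatenate them, and hex-encode the SHA-1 digest.
func computeSignature(token, timestamp, nonce, encrypt string) string {
	parts := []string{token, timestamp, nonce, encrypt}
	sort.Strings(parts)
	sum := sha1.Sum([]byte(strings.Join(parts, "")))
	return hex.EncodeToString(sum[:])
}

// verifySignature fails closed: an empty token never verifies.
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
	if token == "" {
		return false
	}
	want := computeSignature(token, timestamp, nonce, msgEncrypt)
	// Constant-time comparison is an optional hardening, not part of the diff.
	return subtle.ConstantTimeCompare([]byte(want), []byte(msgSignature)) == 1
}

func main() {
	sig := computeSignature("secret", "1700000000", "nonce", "payload")
	fmt.Println("valid signature:", verifySignature("secret", sig, "1700000000", "nonce", "payload"))
	fmt.Println("empty token:", verifySignature("", sig, "1700000000", "nonce", "payload"))
	fmt.Println("forged signature:", verifySignature("secret", "forged", "1700000000", "nonce", "payload"))
}
```

With the previous behaviour, the second call would have returned `true` for any signature, which is the case the updated tests now reject.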
diff --git a/pkg/config/config.go b/pkg/config/config.go index 13d5a7306..8bc46dfc4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,12 +4,15 @@ import ( "encoding/json" "fmt" "os" + "path/filepath" "strings" "sync/atomic" "github.com/caarlos0/env/v11" + "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/fileutil" + "github.com/sipeed/picoclaw/pkg/logger" ) // rrCounter is a global counter for round-robin load balancing across models. @@ -17,6 +20,8 @@ var rrCounter atomic.Uint64 // FlexibleStringSlice is a []string that also accepts JSON numbers, // so allow_from can contain both "123" and 123. +// It also supports parsing comma-separated strings from environment variables, +// including both English (,) and Chinese (，) commas. type FlexibleStringSlice []string func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { @@ -48,17 +53,46 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { return nil } +// UnmarshalText implements encoding.TextUnmarshaler to support env variable parsing. +// It handles comma-separated values with both English (,) and Chinese (，) commas. 
+func (f *FlexibleStringSlice) UnmarshalText(text []byte) error { + if len(text) == 0 { + *f = nil + return nil + } + + s := string(text) + // Replace Chinese comma with English comma, then split + s = strings.ReplaceAll(s, "，", ",") + parts := strings.Split(s, ",") + + result := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + result = append(result, part) + } + } + *f = result + return nil +} + +// CurrentVersion is the latest config schema version +const CurrentVersion = 1 + +// Config is the current config structure with version support type Config struct { + Version int `json:"version"` // Config schema version for migration Agents AgentsConfig `json:"agents"` Bindings []AgentBinding `json:"bindings,omitempty"` Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` - Providers ProvidersConfig `json:"providers,omitempty"` ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` Devices DevicesConfig `json:"devices"` + Voice VoiceConfig `json:"voice"` // BuildInfo contains build-time version information BuildInfo BuildInfo `json:"build_info,omitempty"` } @@ -73,19 +107,14 @@ type BuildInfo struct { // MarshalJSON implements custom JSON marshaling for Config // to omit providers section when empty and session when empty -func (c Config) MarshalJSON() ([]byte, error) { +func (c *Config) MarshalJSON() ([]byte, error) { type Alias Config aux := &struct { Providers *ProvidersConfig `json:"providers,omitempty"` Session *SessionConfig `json:"session,omitempty"` *Alias }{ - Alias: (*Alias)(&c), - } - - // Only include providers if not empty - if !c.Providers.IsEmpty() { - aux.Providers = &c.Providers + Alias: (*Alias)(c), } // Only include session if not empty @@ -196,7 +225,6 @@ type AgentDefaults struct { 
AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead ModelFallbacks []string `json:"model_fallbacks,omitempty"` ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` @@ -221,10 +249,7 @@ func (d *AgentDefaults) GetMaxMediaSize() int { // GetModelName returns the effective model name for the agent defaults. // It prefers the new "model_name" field but falls back to "model" for backward compatibility. func (d *AgentDefaults) GetModelName() string { - if d.ModelName != "" { - return d.ModelName - } - return d.Model + return d.ModelName } type ChannelsConfig struct { @@ -349,16 +374,17 @@ type SlackConfig struct { } type MatrixConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"` - Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"` - UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"` - AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"` - DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"` - JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MATRIX_ENABLED"` + Homeserver string `json:"homeserver" env:"PICOCLAW_CHANNELS_MATRIX_HOMESERVER"` + UserID string `json:"user_id" env:"PICOCLAW_CHANNELS_MATRIX_USER_ID"` + AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_MATRIX_ACCESS_TOKEN"` + 
DeviceID string `json:"device_id,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_DEVICE_ID"` + JoinOnInvite bool `json:"join_on_invite" env:"PICOCLAW_CHANNELS_MATRIX_JOIN_ON_INVITE"` + MessageFormat string `json:"message_format,omitempty" env:"PICOCLAW_CHANNELS_MATRIX_MESSAGE_FORMAT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MATRIX_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` Placeholder PlaceholderConfig `json:"placeholder,omitempty"` - ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MATRIX_REASONING_CHANNEL_ID"` } type LINEConfig struct { @@ -472,6 +498,10 @@ type DevicesConfig struct { MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"` } +type VoiceConfig struct { + EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` +} + type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` OpenAI OpenAIProviderConfig `json:"openai"` @@ -495,6 +525,7 @@ type ProvidersConfig struct { Mistral ProviderConfig `json:"mistral"` Avian ProviderConfig `json:"avian"` Minimax ProviderConfig `json:"minimax"` + LongCat ProviderConfig `json:"longcat"` } // IsEmpty checks if all provider configs are empty (no API keys or API bases set) @@ -521,7 +552,8 @@ func (p ProvidersConfig) IsEmpty() bool { p.Qwen.APIKey == "" && p.Qwen.APIBase == "" && p.Mistral.APIKey == "" && p.Mistral.APIBase == "" && p.Avian.APIKey == "" && p.Avian.APIBase == "" && - p.Minimax.APIKey == "" && p.Minimax.APIBase == "" + p.Minimax.APIKey == "" && p.Minimax.APIBase == "" && + p.LongCat.APIKey == "" && p.LongCat.APIBase == "" } // MarshalJSON implements custom JSON marshaling for ProvidersConfig @@ -668,6 +700,7 @@ type CronToolsConfig struct { type ExecConfig struct { ToolConfig ` envPrefix:"PICOCLAW_TOOLS_EXEC_"` EnableDenyPatterns bool ` 
env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS" json:"enable_deny_patterns"` + AllowRemote bool ` env:"PICOCLAW_TOOLS_EXEC_ALLOW_REMOTE" json:"allow_remote"` CustomDenyPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS" json:"custom_deny_patterns"` CustomAllowPatterns []string ` env:"PICOCLAW_TOOLS_EXEC_CUSTOM_ALLOW_PATTERNS" json:"custom_allow_patterns"` TimeoutSeconds int ` env:"PICOCLAW_TOOLS_EXEC_TIMEOUT_SECONDS" json:"timeout_seconds"` // 0 means use default (60s) @@ -766,44 +799,53 @@ type MCPConfig struct { } func LoadConfig(path string) (*Config, error) { - cfg := DefaultConfig() - data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { - return cfg, nil + return DefaultConfig(), nil } return nil, err } - // Pre-scan the JSON to check how many model_list entries the user provided. - // Go's JSON decoder reuses existing slice backing-array elements rather than - // zero-initializing them, so fields absent from the user's JSON (e.g. api_base) - // would silently inherit values from the DefaultConfig template at the same - // index position. We only reset cfg.ModelList when the user actually provides - // entries; when count is 0 we keep DefaultConfig's built-in list as fallback. 
- var tmp Config - if err := json.Unmarshal(data, &tmp); err != nil { - return nil, err + // First, try to detect config version by reading the version field + var versionInfo struct { + Version int `json:"version"` } - if len(tmp.ModelList) > 0 { - cfg.ModelList = nil + if e := json.Unmarshal(data, &versionInfo); e != nil { + return nil, fmt.Errorf("failed to detect config version: %w", e) } - if err := json.Unmarshal(data, cfg); err != nil { - return nil, err + // Load config based on detected version + var cfg *Config + switch versionInfo.Version { + case 0: + // Legacy config (no version field) + v, e := loadConfigV0(data) + if e != nil { + return nil, e + } + cfg, e = v.Migrate() + if e != nil { + logger.DebugF("config migrate fail", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) + return nil, e + } + logger.DebugF("config migrate success", map[string]any{"from": versionInfo.Version, "to": CurrentVersion}) + defer func() { + _ = SaveConfig(path, cfg) + }() + case CurrentVersion: + // Current version + cfg, err = loadConfig(data) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version) } - if err := env.Parse(cfg); err != nil { - return nil, err - } - - // Migrate legacy channel config fields to new unified structures - cfg.migrateChannelConfigs() - - // Auto-migrate: if only legacy providers config exists, convert to model_list - if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { - cfg.ModelList = ConvertProvidersToModelList(cfg) + // Apply environment variables + if e := env.Parse(cfg); e != nil { + return nil, e } // Validate model_list for uniqueness and required fields @@ -811,23 +853,26 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Ensure Workspace has a default if not set + if cfg.Agents.Defaults.Workspace == "" { + homePath, _ := os.UserHomeDir() + if picoclawHome := os.Getenv(pkg.PicoClawHome); picoclawHome != "" { + homePath 
= picoclawHome + } else if homePath != "" { + homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome) + } + cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName) + } + return cfg, nil } -func (c *Config) migrateChannelConfigs() { - // Discord: mention_only -> group_trigger.mention_only - if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { - c.Channels.Discord.GroupTrigger.MentionOnly = true - } - - // OneBot: group_trigger_prefix -> group_trigger.prefixes - if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && - len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { - c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix - } -} - func SaveConfig(path string, cfg *Config) error { + // Ensure version is always set when saving + if cfg.Version == 0 { + cfg.Version = CurrentVersion + } + data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return err @@ -841,53 +886,6 @@ func (c *Config) WorkspacePath() string { return expandHome(c.Agents.Defaults.Workspace) } -func (c *Config) GetAPIKey() string { - if c.Providers.OpenRouter.APIKey != "" { - return c.Providers.OpenRouter.APIKey - } - if c.Providers.Anthropic.APIKey != "" { - return c.Providers.Anthropic.APIKey - } - if c.Providers.OpenAI.APIKey != "" { - return c.Providers.OpenAI.APIKey - } - if c.Providers.Gemini.APIKey != "" { - return c.Providers.Gemini.APIKey - } - if c.Providers.Zhipu.APIKey != "" { - return c.Providers.Zhipu.APIKey - } - if c.Providers.Groq.APIKey != "" { - return c.Providers.Groq.APIKey - } - if c.Providers.VLLM.APIKey != "" { - return c.Providers.VLLM.APIKey - } - if c.Providers.ShengSuanYun.APIKey != "" { - return c.Providers.ShengSuanYun.APIKey - } - if c.Providers.Cerebras.APIKey != "" { - return c.Providers.Cerebras.APIKey - } - return "" -} - -func (c *Config) GetAPIBase() string { - if c.Providers.OpenRouter.APIKey != "" { - if c.Providers.OpenRouter.APIBase != "" { - return 
c.Providers.OpenRouter.APIBase - } - return "https://openrouter.ai/api/v1" - } - if c.Providers.Zhipu.APIKey != "" { - return c.Providers.Zhipu.APIBase - } - if c.Providers.VLLM.APIKey != "" && c.Providers.VLLM.APIBase != "" { - return c.Providers.VLLM.APIBase - } - return "" -} - func expandHome(path string) string { if path == "" { return path @@ -930,11 +928,6 @@ func (c *Config) findMatches(modelName string) []ModelConfig { return matches } -// HasProvidersConfig checks if any provider in the old providers config has configuration. -func (c *Config) HasProvidersConfig() bool { - return !c.Providers.IsEmpty() -} - // ValidateModelList validates all ModelConfig entries in the model_list. // It checks that each model config is valid. // Note: Multiple entries with the same model_name are allowed for load balancing. diff --git a/pkg/config/config_old.go b/pkg/config/config_old.go new file mode 100644 index 000000000..782c3dc44 --- /dev/null +++ b/pkg/config/config_old.go @@ -0,0 +1,108 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package config + +type agentDefaultsV0 struct { + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + AllowReadOutsideWorkspace bool `json:"allow_read_outside_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_ALLOW_READ_OUTSIDE_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + 
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"` + SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` + Routing *RoutingConfig `json:"routing,omitempty"` +} + +// GetModelName returns the effective model name for the agent defaults. +// It prefers the new "model_name" field but falls back to "model" for backward compatibility. +func (d *agentDefaultsV0) GetModelName() string { + if d.ModelName != "" { + return d.ModelName + } + return d.Model +} + +type agentsConfigV0 struct { + Defaults agentDefaultsV0 `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` +} + +// configV0 represents the config structure before versioning was introduced. +// This struct is used for loading legacy config files (version 0). +// It is unexported since it's only used internally for migration. 
+type configV0 struct { + Agents agentsConfigV0 `json:"agents"` + Bindings []AgentBinding `json:"bindings,omitempty"` + Session SessionConfig `json:"session,omitempty"` + Channels ChannelsConfig `json:"channels"` + Providers ProvidersConfig `json:"providers,omitempty"` + ModelList []ModelConfig `json:"model_list"` + Gateway GatewayConfig `json:"gateway"` + Tools ToolsConfig `json:"tools"` + Heartbeat HeartbeatConfig `json:"heartbeat"` + Devices DevicesConfig `json:"devices"` +} + +func (c *configV0) migrateChannelConfigs() { + // Discord: mention_only -> group_trigger.mention_only + if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { + c.Channels.Discord.GroupTrigger.MentionOnly = true + } + + // OneBot: group_trigger_prefix -> group_trigger.prefixes + if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && + len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix + } +} + +func (c *configV0) Migrate() (*Config, error) { + // Migrate legacy channel config fields to new unified structures + cfg := DefaultConfig() + + // Always copy user's Agents config to preserve settings like Provider, Model, MaxTokens + cfg.Agents.List = c.Agents.List + cfg.Agents.Defaults.Workspace = c.Agents.Defaults.Workspace + cfg.Agents.Defaults.RestrictToWorkspace = c.Agents.Defaults.RestrictToWorkspace + cfg.Agents.Defaults.AllowReadOutsideWorkspace = c.Agents.Defaults.AllowReadOutsideWorkspace + cfg.Agents.Defaults.Provider = c.Agents.Defaults.Provider + cfg.Agents.Defaults.ModelName = c.Agents.Defaults.GetModelName() + cfg.Agents.Defaults.ModelFallbacks = c.Agents.Defaults.ModelFallbacks + cfg.Agents.Defaults.ImageModel = c.Agents.Defaults.ImageModel + cfg.Agents.Defaults.ImageModelFallbacks = c.Agents.Defaults.ImageModelFallbacks + cfg.Agents.Defaults.MaxTokens = c.Agents.Defaults.MaxTokens + cfg.Agents.Defaults.Temperature = c.Agents.Defaults.Temperature + 
cfg.Agents.Defaults.MaxToolIterations = c.Agents.Defaults.MaxToolIterations + cfg.Agents.Defaults.SummarizeMessageThreshold = c.Agents.Defaults.SummarizeMessageThreshold + cfg.Agents.Defaults.SummarizeTokenPercent = c.Agents.Defaults.SummarizeTokenPercent + cfg.Agents.Defaults.MaxMediaSize = c.Agents.Defaults.MaxMediaSize + cfg.Agents.Defaults.Routing = c.Agents.Defaults.Routing + + // Copy other top-level fields + cfg.Bindings = c.Bindings + cfg.Session = c.Session + cfg.Channels = c.Channels + cfg.Gateway = c.Gateway + cfg.Tools = c.Tools + cfg.Heartbeat = c.Heartbeat + cfg.Devices = c.Devices + + // Only override ModelList if user provided values + if len(c.ModelList) > 0 { + cfg.ModelList = c.ModelList + } + + cfg.Version = CurrentVersion + return cfg, nil +} diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 8baf3e6fd..8f495d5ec 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" "runtime" - "strings" "testing" ) @@ -207,15 +206,6 @@ func TestDefaultConfig_WorkspacePath(t *testing.T) { } } -// TestDefaultConfig_Model verifies model is set -func TestDefaultConfig_Model(t *testing.T) { - cfg := DefaultConfig() - - if cfg.Agents.Defaults.Model != "" { - t.Error("Model should be empty") - } -} - // TestDefaultConfig_MaxTokens verifies max tokens has default value func TestDefaultConfig_MaxTokens(t *testing.T) { cfg := DefaultConfig() @@ -255,21 +245,6 @@ func TestDefaultConfig_Gateway(t *testing.T) { } } -// TestDefaultConfig_Providers verifies provider structure -func TestDefaultConfig_Providers(t *testing.T) { - cfg := DefaultConfig() - - if cfg.Providers.Anthropic.APIKey != "" { - t.Error("Anthropic API key should be empty by default") - } - if cfg.Providers.OpenAI.APIKey != "" { - t.Error("OpenAI API key should be empty by default") - } - if cfg.Providers.OpenRouter.APIKey != "" { - t.Error("OpenRouter API key should be empty by default") - } -} - // 
TestDefaultConfig_Channels verifies channels are disabled by default func TestDefaultConfig_Channels(t *testing.T) { cfg := DefaultConfig() @@ -328,25 +303,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) { } } -func TestSaveConfig_IncludesEmptyLegacyModelField(t *testing.T) { - tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "config.json") - - cfg := DefaultConfig() - if err := SaveConfig(path, cfg); err != nil { - t.Fatalf("SaveConfig failed: %v", err) - } - - data, err := os.ReadFile(path) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } - - if !strings.Contains(string(data), `"model": ""`) { - t.Fatalf("saved config should include empty legacy model field, got: %s", string(data)) - } -} - // TestConfig_Complete verifies all config fields are set func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() @@ -354,9 +310,6 @@ func TestConfig_Complete(t *testing.T) { if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } - if cfg.Agents.Defaults.Model != "" { - t.Error("Model should be empty") - } if cfg.Agents.Defaults.Temperature != nil { t.Error("Temperature should be nil when not provided") } @@ -375,19 +328,23 @@ func TestConfig_Complete(t *testing.T) { if !cfg.Heartbeat.Enabled { t.Error("Heartbeat should be enabled by default") } + if !cfg.Tools.Exec.AllowRemote { + t.Error("Exec.AllowRemote should be true by default") + } } -func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) { +func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) { cfg := DefaultConfig() - if !cfg.Providers.OpenAI.WebSearch { - t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true") + if !cfg.Tools.Exec.AllowRemote { + t.Fatal("DefaultConfig().Tools.Exec.AllowRemote should be true") } } -func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { +func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) { dir := t.TempDir() configPath := filepath.Join(dir, "config.json") - if 
err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil { + if err := os.WriteFile(configPath, []byte(`{"version":1,"tools":{"exec":{"enable_deny_patterns":true}}}`), + 0o600); err != nil { t.Fatalf("WriteFile() error: %v", err) } @@ -395,24 +352,8 @@ func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { if err != nil { t.Fatalf("LoadConfig() error: %v", err) } - if !cfg.Providers.OpenAI.WebSearch { - t.Fatal("OpenAI codex web search should remain true when unset in config file") - } -} - -func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) { - dir := t.TempDir() - configPath := filepath.Join(dir, "config.json") - if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil { - t.Fatalf("WriteFile() error: %v", err) - } - - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("LoadConfig() error: %v", err) - } - if cfg.Providers.OpenAI.WebSearch { - t.Fatal("OpenAI codex web search should be false when disabled in config file") + if !cfg.Tools.Exec.AllowRemote { + t.Fatal("tools.exec.allow_remote should remain true when unset in config file") } } @@ -482,3 +423,119 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) { t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want) } } + +// TestFlexibleStringSlice_UnmarshalText tests UnmarshalText with various comma separators +func TestFlexibleStringSlice_UnmarshalText(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "English commas only", + input: "123,456,789", + expected: []string{"123", "456", "789"}, + }, + { + name: "Chinese commas only", + input: "123，456，789", + expected: []string{"123", "456", "789"}, + }, + { + name: "Mixed English and Chinese commas", + input: "123,456,789", + expected: []string{"123", "456", "789"}, + }, + { + name: "Single value", + input: 
"123", + expected: []string{"123"}, + }, + { + name: "Values with whitespace", + input: " 123 , 456 , 789 ", + expected: []string{"123", "456", "789"}, + }, + { + name: "Empty string", + input: "", + expected: nil, + }, + { + name: "Only commas - English", + input: ",,", + expected: []string{}, + }, + { + name: "Only commas - Chinese", + input: ",,", + expected: []string{}, + }, + { + name: "Mixed commas with empty parts", + input: "123,,456,,789", + expected: []string{"123", "456", "789"}, + }, + { + name: "Complex mixed values", + input: "user1@example.com,user2@test.com, admin@domain.org", + expected: []string{"user1@example.com", "user2@test.com", "admin@domain.org"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f FlexibleStringSlice + err := f.UnmarshalText([]byte(tt.input)) + if err != nil { + t.Fatalf("UnmarshalText(%q) error = %v", tt.input, err) + } + + if tt.expected == nil { + if f != nil { + t.Errorf("UnmarshalText(%q) = %v, want nil", tt.input, f) + } + return + } + + if len(f) != len(tt.expected) { + t.Errorf("UnmarshalText(%q) length = %d, want %d", tt.input, len(f), len(tt.expected)) + return + } + + for i, v := range tt.expected { + if f[i] != v { + t.Errorf("UnmarshalText(%q)[%d] = %q, want %q", tt.input, i, f[i], v) + } + } + }) + } +} + +// TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency tests nil vs empty slice behavior +func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) { + t.Run("Empty string returns nil", func(t *testing.T) { + var f FlexibleStringSlice + err := f.UnmarshalText([]byte("")) + if err != nil { + t.Fatalf("UnmarshalText error = %v", err) + } + if f != nil { + t.Errorf("Empty string should return nil, got %v", f) + } + }) + + t.Run("Commas only returns empty slice", func(t *testing.T) { + var f FlexibleStringSlice + err := f.UnmarshalText([]byte(",,,")) + if err != nil { + t.Fatalf("UnmarshalText error = %v", err) + } + if f == nil { + t.Error("Commas 
only should return empty slice, not nil") + } + if len(f) != 0 { + t.Errorf("Expected empty slice, got %v", f) + } + }) +} diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 5bb3bd1d6..938f74e73 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -8,6 +8,8 @@ package config import ( "os" "path/filepath" + + "github.com/sipeed/picoclaw/pkg" ) // DefaultConfig returns the default configuration for PicoClaw. @@ -15,21 +17,21 @@ func DefaultConfig() *Config { // Determine the base path for the workspace. // Priority: $PICOCLAW_HOME > ~/.picoclaw var homePath string - if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" { + if picoclawHome := os.Getenv(pkg.PicoClawHome); picoclawHome != "" { homePath = picoclawHome } else { userHome, _ := os.UserHomeDir() - homePath = filepath.Join(userHome, ".picoclaw") + homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome) } - workspacePath := filepath.Join(homePath, "workspace") + workspacePath := filepath.Join(homePath, pkg.WorkspaceName) return &Config{ + Version: CurrentVersion, Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: workspacePath, RestrictToWorkspace: true, Provider: "", - Model: "", MaxTokens: 32768, Temperature: nil, // nil means use provider default MaxToolIterations: 50, @@ -176,9 +178,6 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, }, }, - Providers: ProvidersConfig{ - OpenAI: OpenAIProviderConfig{WebSearch: true}, - }, ModelList: []ModelConfig{ // ============================================ // Add your API key to the model you want to use @@ -355,6 +354,14 @@ func DefaultConfig() *Config { APIKey: "", }, + // LongCat - https://longcat.chat/platform + { + ModelName: "LongCat-Flash-Thinking", + Model: "longcat/LongCat-Flash-Thinking", + APIBase: "https://api.longcat.chat/openai", + APIKey: "", + }, + // VLLM (local) - http://localhost:8000 { ModelName: "local-model", @@ -427,6 +434,7 @@ func DefaultConfig() *Config { Enabled: true, 
}, EnableDenyPatterns: true, + AllowRemote: true, TimeoutSeconds: 60, }, Skills: SkillsToolsConfig{ @@ -510,6 +518,9 @@ func DefaultConfig() *Config { Enabled: false, MonitorUSB: true, }, + Voice: VoiceConfig{ + EchoTranscription: false, + }, BuildInfo: BuildInfo{ Version: Version, GitCommit: GitCommit, diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 51f21e4f4..4ce02d401 100644 --- a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -6,10 +6,15 @@ package config import ( + "encoding/json" "slices" "strings" ) +type migratable interface { + Migrate() (*Config, error) +} + // buildModelWithProtocol constructs a model string with protocol prefix. // If the model already contains a "/" (indicating it has a protocol prefix), it is returned as-is. // Otherwise, the protocol prefix is added. @@ -21,24 +26,24 @@ func buildModelWithProtocol(protocol, model string) string { return protocol + "/" + model } -// providerMigrationConfig defines how to migrate a provider from old config to new format. -type providerMigrationConfig struct { - // providerNames are the possible names used in agents.defaults.provider - providerNames []string - // protocol is the protocol prefix for the model field - protocol string - // buildConfig creates the ModelConfig from ProviderConfig - buildConfig func(p ProvidersConfig) (ModelConfig, bool) -} - -// ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. +// v0ConvertProvidersToModelList converts the old ProvidersConfig to a slice of ModelConfig. // This enables backward compatibility with existing configurations. // It preserves the user's configured model from agents.defaults.model when possible. -func ConvertProvidersToModelList(cfg *Config) []ModelConfig { +func v0ConvertProvidersToModelList(cfg *configV0) []ModelConfig { if cfg == nil { return nil } + // providerMigrationConfig defines how to migrate a provider from old config to new format. 
+ type providerMigrationConfig struct { + // providerNames are the possible names used in agents.defaults.provider + providerNames []string + // protocol is the protocol prefix for the model field + protocol string + // buildConfig creates the ModelConfig from ProviderConfig + buildConfig func(p ProvidersConfig) (ModelConfig, bool) + } + // Get user's configured provider and model userProvider := strings.ToLower(cfg.Agents.Defaults.Provider) userModel := cfg.Agents.Defaults.GetModelName() @@ -407,6 +412,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"longcat"}, + protocol: "longcat", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.LongCat.APIKey == "" && p.LongCat.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "longcat", + Model: "longcat/LongCat-Flash-Thinking", + APIKey: p.LongCat.APIKey, + APIBase: p.LongCat.APIBase, + Proxy: p.LongCat.Proxy, + RequestTimeout: p.LongCat.RequestTimeout, + }, true + }, + }, } // Process each provider migration @@ -434,3 +456,44 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { return result } + +// loadConfigV0 loads a legacy config (no version field) +func loadConfigV0(data []byte) (migratable, error) { + var v0 configV0 + if err := json.Unmarshal(data, &v0); err != nil { + return nil, err + } + + v0.migrateChannelConfigs() + + // Auto-migrate: if only legacy providers config exists, convert to model_list + if len(v0.ModelList) == 0 && !v0.Providers.IsEmpty() { + v0.ModelList = v0ConvertProvidersToModelList(&v0) + } + + return &v0, nil +} + +// loadConfig loads a version 1 config (current schema) +func loadConfig(data []byte) (*Config, error) { + cfg := DefaultConfig() + + // Pre-scan the JSON to check how many model_list entries the user provided.
+ // Go's JSON decoder reuses existing slice backing-array elements rather than + // zero-initializing them, so fields absent from the user's JSON (e.g. api_base) + // would silently inherit values from the DefaultConfig template at the same + // index position. We only reset cfg.ModelList when the user actually provides + // entries; when count is 0 we keep DefaultConfig's built-in list as fallback. + var tmp Config + if err := json.Unmarshal(data, &tmp); err != nil { + return nil, err + } + if len(tmp.ModelList) > 0 { + cfg.ModelList = nil + } + + if err := json.Unmarshal(data, cfg); err != nil { + return nil, err + } + return cfg, nil +} diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index d3019aab0..edf873b35 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -11,7 +11,7 @@ import ( ) func TestConvertProvidersToModelList_OpenAI(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ ProviderConfig: ProviderConfig{ @@ -22,7 +22,7 @@ func TestConvertProvidersToModelList_OpenAI(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -40,7 +40,7 @@ func TestConvertProvidersToModelList_OpenAI(t *testing.T) { } func TestConvertProvidersToModelList_Anthropic(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ Anthropic: ProviderConfig{ APIKey: "ant-key", @@ -49,7 +49,7 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -64,7 +64,7 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) { } func TestConvertProvidersToModelList_LiteLLM(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: 
ProvidersConfig{ LiteLLM: ProviderConfig{ APIKey: "litellm-key", @@ -73,7 +73,7 @@ func TestConvertProvidersToModelList_LiteLLM(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -91,7 +91,7 @@ func TestConvertProvidersToModelList_LiteLLM(t *testing.T) { } func TestConvertProvidersToModelList_Multiple(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "openai-key"}}, Groq: ProviderConfig{APIKey: "groq-key"}, @@ -99,7 +99,7 @@ func TestConvertProvidersToModelList_Multiple(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 3 { t.Fatalf("len(result) = %d, want 3", len(result)) @@ -119,11 +119,11 @@ func TestConvertProvidersToModelList_Multiple(t *testing.T) { } func TestConvertProvidersToModelList_Empty(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{}, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 0 { t.Errorf("len(result) = %d, want 0", len(result)) @@ -131,7 +131,7 @@ func TestConvertProvidersToModelList_Empty(t *testing.T) { } func TestConvertProvidersToModelList_Nil(t *testing.T) { - result := ConvertProvidersToModelList(nil) + result := v0ConvertProvidersToModelList(nil) if result != nil { t.Errorf("result = %v, want nil", result) @@ -139,7 +139,7 @@ func TestConvertProvidersToModelList_Nil(t *testing.T) { } func TestConvertProvidersToModelList_AllProviders(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}}, LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"}, @@ -162,19 +162,20 @@ func 
TestConvertProvidersToModelList_AllProviders(t *testing.T) { Qwen: ProviderConfig{APIKey: "key17"}, Mistral: ProviderConfig{APIKey: "key18"}, Avian: ProviderConfig{APIKey: "key19"}, + LongCat: ProviderConfig{APIKey: "key-longcat"}, }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) - // All 21 providers should be converted - if len(result) != 21 { - t.Errorf("len(result) = %d, want 21", len(result)) + // All 22 providers should be converted + if len(result) != 22 { + t.Errorf("len(result) = %d, want 22", len(result)) } } func TestConvertProvidersToModelList_Proxy(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ ProviderConfig: ProviderConfig{ @@ -185,7 +186,7 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -197,7 +198,7 @@ func TestConvertProvidersToModelList_Proxy(t *testing.T) { } func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ Ollama: ProviderConfig{ APIKey: "ollama-key", @@ -206,7 +207,7 @@ func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -218,7 +219,7 @@ func TestConvertProvidersToModelList_RequestTimeout(t *testing.T) { } func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { - cfg := &Config{ + cfg := &configV0{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ ProviderConfig: ProviderConfig{ @@ -228,7 +229,7 @@ func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if 
len(result) != 0 { t.Errorf("len(result) = %d, want 0 (AuthMethod alone should not create entry)", len(result)) @@ -238,9 +239,9 @@ func TestConvertProvidersToModelList_AuthMethod(t *testing.T) { // Tests for preserving user's configured model during migration func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "deepseek", Model: "deepseek-reasoner", }, @@ -250,7 +251,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -263,9 +264,9 @@ func TestConvertProvidersToModelList_PreservesUserModel_DeepSeek(t *testing.T) { } func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "openai", Model: "gpt-4-turbo", }, @@ -275,7 +276,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -287,9 +288,9 @@ func TestConvertProvidersToModelList_PreservesUserModel_OpenAI(t *testing.T) { } func TestConvertProvidersToModelList_PreservesUserModel_Anthropic(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "claude", // alternative name Model: "claude-opus-4-20250514", }, @@ -299,7 +300,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_Anthropic(t *testing.T) }, } - result := 
ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -311,9 +312,9 @@ func TestConvertProvidersToModelList_PreservesUserModel_Anthropic(t *testing.T) } func TestConvertProvidersToModelList_PreservesUserModel_Qwen(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "qwen", Model: "qwen-plus", }, @@ -323,7 +324,7 @@ func TestConvertProvidersToModelList_PreservesUserModel_Qwen(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -335,9 +336,9 @@ func TestConvertProvidersToModelList_PreservesUserModel_Qwen(t *testing.T) { } func TestConvertProvidersToModelList_UsesDefaultWhenNoUserModel(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "deepseek", Model: "", // no model specified }, @@ -347,7 +348,7 @@ func TestConvertProvidersToModelList_UsesDefaultWhenNoUserModel(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -360,9 +361,9 @@ func TestConvertProvidersToModelList_UsesDefaultWhenNoUserModel(t *testing.T) { } func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "deepseek", Model: "deepseek-reasoner", }, @@ -373,7 +374,7 @@ func TestConvertProvidersToModelList_MultipleProviders_PreservesUserModel(t *tes }, } - result := ConvertProvidersToModelList(cfg) + result := 
v0ConvertProvidersToModelList(cfg) if len(result) != 2 { t.Fatalf("len(result) = %d, want 2", len(result)) @@ -409,9 +410,9 @@ func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) { for _, tt := range tests { t.Run(tt.providerAlias, func(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: tt.providerAlias, Model: strings.TrimPrefix( tt.expectedModel, @@ -442,7 +443,7 @@ func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) { tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1], ) - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) } @@ -464,9 +465,9 @@ func TestConvertProvidersToModelList_NoProviderField_SingleProvider(t *testing.T // - No provider field set // - model = "glm-4.7" // - Only zhipu has API key configured - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "", // Not set Model: "glm-4.7", }, @@ -476,7 +477,7 @@ func TestConvertProvidersToModelList_NoProviderField_SingleProvider(t *testing.T }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -497,9 +498,9 @@ func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin // When multiple providers are configured but no provider field is set, // the FIRST provider (in migration order) will use userModel as ModelName // for backward compatibility with legacy implicit provider selection - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "", // Not set Model: "some-model", }, @@ -510,7 +511,7 @@ func 
TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 2 { t.Fatalf("len(result) = %d, want 2", len(result)) @@ -530,9 +531,9 @@ func TestConvertProvidersToModelList_NoProviderField_MultipleProviders(t *testin func TestConvertProvidersToModelList_NoProviderField_NoModel(t *testing.T) { // Edge case: no provider, no model - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "", Model: "", }, @@ -542,7 +543,7 @@ func TestConvertProvidersToModelList_NoProviderField_NoModel(t *testing.T) { }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) != 1 { t.Fatalf("len(result) = %d, want 1", len(result)) @@ -583,9 +584,9 @@ func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) { // Test for legacy config with protocol prefix in model name func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) { - cfg := &Config{ - Agents: AgentsConfig{ - Defaults: AgentDefaults{ + cfg := &configV0{ + Agents: agentsConfigV0{ + Defaults: agentDefaultsV0{ Provider: "", // No explicit provider Model: "openrouter/auto", // Model already has protocol prefix }, @@ -595,7 +596,7 @@ func TestConvertProvidersToModelList_LegacyModelWithProtocolPrefix(t *testing.T) }, } - result := ConvertProvidersToModelList(cfg) + result := v0ConvertProvidersToModelList(cfg) if len(result) < 1 { t.Fatalf("len(result) = %d, want at least 1", len(result)) diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go index da6e506f8..db0344311 100644 --- a/pkg/config/model_config_test.go +++ b/pkg/config/model_config_test.go @@ -113,39 +113,7 @@ func TestGetModelConfig_Concurrent(t *testing.T) { } } -func TestAgentDefaults_GetModelName_BackwardCompat(t *testing.T) { - tests := 
[]struct { - name string - defaults AgentDefaults - wantName string - }{ - { - name: "new model_name field only", - defaults: AgentDefaults{ModelName: "new-model"}, - wantName: "new-model", - }, - { - name: "old model field only", - defaults: AgentDefaults{Model: "legacy-model"}, - wantName: "legacy-model", - }, - { - name: "both fields - model_name takes precedence", - defaults: AgentDefaults{ModelName: "new-model", Model: "old-model"}, - wantName: "new-model", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.defaults.GetModelName(); got != tt.wantName { - t.Errorf("GetModelName() = %q, want %q", got, tt.wantName) - } - }) - } -} - -func TestAgentDefaults_JSON_BackwardCompat(t *testing.T) { +func TestAgentDefaultsV0_JSON_BackwardCompat(t *testing.T) { tests := []struct { name string json string @@ -170,7 +138,7 @@ func TestAgentDefaults_JSON_BackwardCompat(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var defaults AgentDefaults + var defaults agentDefaultsV0 if err := json.Unmarshal([]byte(tt.json), &defaults); err != nil { t.Fatalf("Unmarshal error: %v", err) } @@ -181,69 +149,6 @@ func TestAgentDefaults_JSON_BackwardCompat(t *testing.T) { } } -func TestFullConfig_JSON_BackwardCompat(t *testing.T) { - // Test complete config with both old and new formats - oldFormat := `{ - "agents": { - "defaults": { - "workspace": "~/.picoclaw/workspace", - "model": "gpt4", - "max_tokens": 4096 - } - }, - "model_list": [ - { - "model_name": "gpt4", - "model": "openai/gpt-4o", - "api_key": "test-key" - } - ] - }` - - newFormat := `{ - "agents": { - "defaults": { - "workspace": "~/.picoclaw/workspace", - "model_name": "gpt4", - "max_tokens": 4096 - } - }, - "model_list": [ - { - "model_name": "gpt4", - "model": "openai/gpt-4o", - "api_key": "test-key" - } - ] - }` - - for name, jsonStr := range map[string]string{ - "old format (model)": oldFormat, - "new format (model_name)": newFormat, - } { - 
t.Run(name, func(t *testing.T) { - cfg := &Config{} - if err := json.Unmarshal([]byte(jsonStr), cfg); err != nil { - t.Fatalf("Unmarshal error: %v", err) - } - - // Check that GetModelName returns correct value - if got := cfg.Agents.Defaults.GetModelName(); got != "gpt4" { - t.Errorf("GetModelName() = %q, want %q", got, "gpt4") - } - - // Check that GetModelConfig works - modelCfg, err := cfg.GetModelConfig("gpt4") - if err != nil { - t.Fatalf("GetModelConfig error: %v", err) - } - if modelCfg.Model != "openai/gpt-4o" { - t.Errorf("Model = %q, want %q", modelCfg.Model, "openai/gpt-4o") - } - }) - } -} - func TestModelConfig_Validate(t *testing.T) { tests := []struct { name string diff --git a/pkg/env.go b/pkg/env.go new file mode 100644 index 000000000..47f219434 --- /dev/null +++ b/pkg/env.go @@ -0,0 +1,13 @@ +// all environment variables including default values put here + +package pkg + +const ( + Logo = "🦞" + // AppName is the name of the app + AppName = "PicoClaw" + + PicoClawHome = "PICOCLAW_HOME" + DefaultPicoClawHome = ".picoclaw" + WorkspaceName = "workspace" +) diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 56dc87a53..80adcf86c 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -1,24 +1,24 @@ package logger import ( - "encoding/json" "fmt" - "log" "os" + "path/filepath" "runtime" "strings" "sync" - "time" + + "github.com/rs/zerolog" ) -type LogLevel int +type LogLevel = zerolog.Level const ( - DEBUG LogLevel = iota - INFO - WARN - ERROR - FATAL + DEBUG = zerolog.DebugLevel + INFO = zerolog.InfoLevel + WARN = zerolog.WarnLevel + ERROR = zerolog.ErrorLevel + FATAL = zerolog.FatalLevel ) var ( @@ -31,27 +31,24 @@ var ( } currentLevel = INFO - logger *Logger + logger zerolog.Logger + fileLogger zerolog.Logger + logFile *os.File once sync.Once mu sync.RWMutex ) -type Logger struct { - file *os.File -} - -type LogEntry struct { - Level string `json:"level"` - Timestamp string `json:"timestamp"` - Component string 
`json:"component,omitempty"` - Message string `json:"message"` - Fields map[string]any `json:"fields,omitempty"` - Caller string `json:"caller,omitempty"` -} - func init() { once.Do(func() { - logger = &Logger{} + zerolog.SetGlobalLevel(zerolog.InfoLevel) + + consoleWriter := zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: "15:04:05", // TODO: make it configurable??? + } + + logger = zerolog.New(consoleWriter).With().Timestamp().Logger() + fileLogger = zerolog.Logger{} }) } @@ -59,6 +56,7 @@ func SetLevel(level LogLevel) { mu.Lock() defer mu.Unlock() currentLevel = level + zerolog.SetGlobalLevel(level) } func GetLevel() LogLevel { @@ -71,17 +69,22 @@ func EnableFileLogging(filePath string) error { mu.Lock() defer mu.Unlock() - file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + return fmt.Errorf("failed to create log directory: %w", err) + } + + newFile, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { return fmt.Errorf("failed to open log file: %w", err) } - if logger.file != nil { - logger.file.Close() + // Close old file if exists + if logFile != nil { + logFile.Close() } - logger.file = file - log.Println("File logging enabled:", filePath) + logFile = newFile + fileLogger = zerolog.New(logFile).With().Timestamp().Caller().Logger() return nil } @@ -89,10 +92,57 @@ func DisableFileLogging() { mu.Lock() defer mu.Unlock() - if logger.file != nil { - logger.file.Close() - logger.file = nil - log.Println("File logging disabled") + if logFile != nil { + logFile.Close() + logFile = nil + } + fileLogger = zerolog.Logger{} +} + +func getCallerInfo() (string, int, string) { + for i := 2; i < 15; i++ { + pc, file, line, ok := runtime.Caller(i) + if !ok { + continue + } + + fn := runtime.FuncForPC(pc) + if fn == nil { + continue + } + + // bypass common loggers + if strings.HasSuffix(file, "/logger.go") || + strings.HasSuffix(file, 
"/log.go") { + continue + } + + funcName := fn.Name() + if strings.HasPrefix(funcName, "runtime.") { + continue + } + + return filepath.Base(file), line, filepath.Base(funcName) + } + + return "???", 0, "???" +} + +//nolint:zerologlint +func getEvent(logger zerolog.Logger, level LogLevel) *zerolog.Event { + switch level { + case zerolog.DebugLevel: + return logger.Debug() + case zerolog.InfoLevel: + return logger.Info() + case zerolog.WarnLevel: + return logger.Warn() + case zerolog.ErrorLevel: + return logger.Error() + case zerolog.FatalLevel: + return logger.Fatal() + default: + return logger.Info() } } @@ -101,65 +151,41 @@ func logMessage(level LogLevel, component string, message string, fields map[str return } - entry := LogEntry{ - Level: logLevelNames[level], - Timestamp: time.Now().UTC().Format(time.RFC3339), - Component: component, - Message: message, - Fields: fields, - } + callerFile, callerLine, callerFunc := getCallerInfo() - if pc, file, line, ok := runtime.Caller(2); ok { - fn := runtime.FuncForPC(pc) - if fn != nil { - entry.Caller = fmt.Sprintf("%s:%d (%s)", file, line, fn.Name()) - } - } + event := getEvent(logger, level) - if logger.file != nil { - jsonData, err := json.Marshal(entry) - if err == nil { - logger.file.Write(append(jsonData, '\n')) - } - } - - var fieldStr string - if len(fields) > 0 { - fieldStr = " " + formatFields(fields) + // Build combined field with component and caller + if component != "" { + event.Str("caller", fmt.Sprintf("%-6s %s:%d (%s)", component, callerFile, callerLine, callerFunc)) } else { - fieldStr = "" + event.Str("caller", fmt.Sprintf(" %s:%d (%s)", callerFile, callerLine, callerFunc)) } - logLine := fmt.Sprintf("[%s] [%s]%s %s%s", - entry.Timestamp, - logLevelNames[level], - formatComponent(component), - message, - fieldStr, - ) + for k, v := range fields { + event.Interface(k, v) + } - log.Println(logLine) + event.Msg(message) + + // Also log to file if enabled + if fileLogger.GetLevel() != zerolog.NoLevel { + 
fileEvent := getEvent(fileLogger, level) + + if component != "" { + fileEvent.Str("component", component) + } + for k, v := range fields { + fileEvent.Interface(k, v) + } + fileEvent.Msg(message) + } if level == FATAL { os.Exit(1) } } -func formatComponent(component string) string { - if component == "" { - return "" - } - return fmt.Sprintf(" %s:", component) -} - -func formatFields(fields map[string]any) string { - parts := make([]string, 0, len(fields)) - for k, v := range fields { - parts = append(parts, fmt.Sprintf("%s=%v", k, v)) - } - return fmt.Sprintf("{%s}", strings.Join(parts, ", ")) -} - func Debug(message string) { logMessage(DEBUG, "", message, nil) } @@ -232,6 +258,10 @@ func FatalC(component string, message string) { logMessage(FATAL, component, message, nil) } +func Fatalf(message string, ss ...any) { + logMessage(FATAL, "", fmt.Sprintf(message, ss...), nil) +} + func FatalF(message string, fields map[string]any) { logMessage(FATAL, "", message, fields) } diff --git a/pkg/logger/logger_3rd_party.go b/pkg/logger/logger_3rd_party.go new file mode 100644 index 000000000..da50d686a --- /dev/null +++ b/pkg/logger/logger_3rd_party.go @@ -0,0 +1,95 @@ +// this file is for compatible with 3rd party loggers, should not be called in PicoClaw project + +package logger + +import "fmt" + +// Logger implements common Logger interface +type Logger struct { + component string + levels map[int]LogLevel +} + +// Debug logs debug messages +func (b *Logger) Debug(v ...any) { + logMessage(DEBUG, b.component, fmt.Sprint(v...), nil) +} + +// Info logs info messages +func (b *Logger) Info(v ...any) { + logMessage(INFO, b.component, fmt.Sprint(v...), nil) +} + +// Warn logs warning messages +func (b *Logger) Warn(v ...any) { + logMessage(WARN, b.component, fmt.Sprint(v...), nil) +} + +// Error logs error messages +func (b *Logger) Error(v ...any) { + logMessage(ERROR, b.component, fmt.Sprint(v...), nil) +} + +// Debugf logs formatted debug messages +func (b *Logger) 
Debugf(format string, v ...any) { + logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil) +} + +// Infof logs formatted info messages +func (b *Logger) Infof(format string, v ...any) { + logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil) +} + +// Warnf logs formatted warning messages +func (b *Logger) Warnf(format string, v ...any) { + logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil) +} + +// Warningf logs formatted warning messages +func (b *Logger) Warningf(format string, v ...any) { + logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil) +} + +// Errorf logs formatted error messages +func (b *Logger) Errorf(format string, v ...any) { + logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil) +} + +// Fatalf logs formatted fatal messages and exits +func (b *Logger) Fatalf(format string, v ...any) { + logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil) +} + +// Log logs a message at a given level with caller information +// the func name must be this because 3rd party loggers expect this +// msgL: message level (DEBUG, INFO, WARN, ERROR, FATAL) +// caller: unused parameter reserved for compatibility +// format: format string +// a: format arguments +// +//nolint:goprintffuncname +func (b *Logger) Log(msgL, caller int, format string, a ...any) { + level := LogLevel(msgL) + if b.levels != nil { + if lvl, ok := b.levels[msgL]; ok { + level = lvl + } + } + logMessage(level, b.component, fmt.Sprintf(format, a...), nil) +} + +// Sync flushes log buffer (no-op for this implementation) +func (b *Logger) Sync() error { + return nil +} + +// WithLevels sets log levels mapping for this logger +func (b *Logger) WithLevels(levels map[int]LogLevel) *Logger { + b.levels = levels + return b +} + +// NewLogger creates a new logger instance with optional component name +func NewLogger(component string) *Logger { + return &Logger{component: component} +} diff --git a/pkg/memory/migration.go b/pkg/memory/migration.go index 
c9d5176ab..b64c62a9f 100644 --- a/pkg/memory/migration.go +++ b/pkg/memory/migration.go @@ -48,6 +48,12 @@ func MigrateFromJSON( if !strings.HasSuffix(name, ".json") { continue } + // Skip JSONL metadata files. They are part of the new storage format, + // not legacy session snapshots, and re-importing them would overwrite + // the paired .jsonl history with an empty message list. + if strings.HasSuffix(name, ".meta.json") { + continue + } // Skip already-migrated files. if strings.HasSuffix(name, ".migrated") { continue diff --git a/pkg/memory/migration_test.go b/pkg/memory/migration_test.go index 3170758b7..4466c96f9 100644 --- a/pkg/memory/migration_test.go +++ b/pkg/memory/migration_test.go @@ -382,3 +382,55 @@ func TestMigrateFromJSON_NonexistentDir(t *testing.T) { t.Errorf("expected 0, got %d", count) } } + +func TestMigrateFromJSON_SkipsMetaJSONFiles(t *testing.T) { + sessionsDir := t.TempDir() + store, err := NewJSONLStore(sessionsDir) + if err != nil { + t.Fatalf("NewJSONLStore: %v", err) + } + ctx := context.Background() + + if addErr := store.AddMessage(ctx, "agent:main:pico:direct:pico:test", "user", "keep me"); addErr != nil { + t.Fatalf("AddMessage: %v", addErr) + } + if summaryErr := store.SetSummary(ctx, "agent:main:pico:direct:pico:test", "keep summary"); summaryErr != nil { + t.Fatalf("SetSummary: %v", summaryErr) + } + + metaPath := filepath.Join(sessionsDir, "agent_main_pico_direct_pico_test.meta.json") + if _, statErr := os.Stat(metaPath); statErr != nil { + t.Fatalf("meta file missing before migration: %v", statErr) + } + + count, err := MigrateFromJSON(ctx, sessionsDir, store) + if err != nil { + t.Fatalf("MigrateFromJSON: %v", err) + } + if count != 0 { + t.Fatalf("expected 0 migrated, got %d", count) + } + + history, err := store.GetHistory(ctx, "agent:main:pico:direct:pico:test") + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 1 || history[0].Content != "keep me" { + t.Fatalf("history = %+v, want preserved 
single message", history) + } + + summary, err := store.GetSummary(ctx, "agent:main:pico:direct:pico:test") + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if summary != "keep summary" { + t.Fatalf("summary = %q, want %q", summary, "keep summary") + } + + if _, statErr := os.Stat(metaPath); statErr != nil { + t.Fatalf("meta file should remain in place: %v", statErr) + } + if _, statErr := os.Stat(metaPath + ".migrated"); !os.IsNotExist(statErr) { + t.Fatalf("meta file should not be renamed, stat err = %v", statErr) + } +} diff --git a/pkg/migrate/internal/common.go b/pkg/migrate/internal/common.go index c77ab9f26..32c6ac83b 100644 --- a/pkg/migrate/internal/common.go +++ b/pkg/migrate/internal/common.go @@ -5,20 +5,22 @@ import ( "io" "os" "path/filepath" + + "github.com/sipeed/picoclaw/pkg" ) func ResolveTargetHome(override string) (string, error) { if override != "" { return ExpandHome(override), nil } - if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" { + if envHome := os.Getenv(pkg.PicoClawHome); envHome != "" { return ExpandHome(envHome), nil } home, err := os.UserHomeDir() if err != nil { return "", fmt.Errorf("resolving home directory: %w", err) } - return filepath.Join(home, ".picoclaw"), nil + return filepath.Join(home, pkg.DefaultPicoClawHome), nil } func ExpandHome(path string) string { diff --git a/pkg/migrate/sources/openclaw/common.go b/pkg/migrate/sources/openclaw/common.go index d57dbe34f..337c950d0 100644 --- a/pkg/migrate/sources/openclaw/common.go +++ b/pkg/migrate/sources/openclaw/common.go @@ -4,7 +4,6 @@ var migrateableFiles = []string{ "AGENTS.md", "SOUL.md", "USER.md", - "TOOLS.md", "HEARTBEAT.md", } diff --git a/pkg/migrate/sources/openclaw/openclaw_config.go b/pkg/migrate/sources/openclaw/openclaw_config.go index e272d17a9..e95c2f3ec 100644 --- a/pkg/migrate/sources/openclaw/openclaw_config.go +++ b/pkg/migrate/sources/openclaw/openclaw_config.go @@ -1111,6 +1111,7 @@ func (c ToolsConfig) ToStandardTools() 
config.ToolsConfig { Exec: config.ExecConfig{ EnableDenyPatterns: c.Exec.EnableDenyPatterns, CustomDenyPatterns: c.Exec.CustomDenyPatterns, + AllowRemote: config.DefaultConfig().Tools.Exec.AllowRemote, }, } } diff --git a/pkg/migrate/sources/openclaw/openclaw_config_test.go b/pkg/migrate/sources/openclaw/openclaw_config_test.go index 3a7d0c686..802693825 100644 --- a/pkg/migrate/sources/openclaw/openclaw_config_test.go +++ b/pkg/migrate/sources/openclaw/openclaw_config_test.go @@ -290,6 +290,20 @@ func TestConvertToPicoClaw(t *testing.T) { } } +func TestToStandardConfig_ExecAllowRemoteDefaultsTrue(t *testing.T) { + cfg := (&PicoClawConfig{ + Tools: ToolsConfig{ + Exec: ExecConfig{ + EnableDenyPatterns: true, + }, + }, + }).ToStandardConfig() + + if !cfg.Tools.Exec.AllowRemote { + t.Fatal("ToStandardConfig() should preserve the default tools.exec.allow_remote=true") + } +} + func TestConvertToPicoClawWithQQAndDingTalk(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, "openclaw.json") diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index d4d648f5a..228cad9c9 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -416,7 +416,7 @@ func TestCreateProvider_ClaudeCli(t *testing.T) { cfg.ModelList = []config.ModelConfig{ {ModelName: "claude-sonnet-4.6", Model: "claude-cli/claude-sonnet-4.6", Workspace: "/test/ws"}, } - cfg.Agents.Defaults.Model = "claude-sonnet-4.6" + cfg.Agents.Defaults.ModelName = "claude-sonnet-4.6" provider, _, err := CreateProvider(cfg) if err != nil { @@ -437,7 +437,7 @@ func TestCreateProvider_ClaudeCode(t *testing.T) { cfg.ModelList = []config.ModelConfig{ {ModelName: "claude-code", Model: "claude-cli/claude-code"}, } - cfg.Agents.Defaults.Model = "claude-code" + cfg.Agents.Defaults.ModelName = "claude-code" provider, _, err := CreateProvider(cfg) if err != nil { @@ -453,7 +453,7 @@ func TestCreateProvider_ClaudeCodec(t 
*testing.T) { cfg.ModelList = []config.ModelConfig{ {ModelName: "claudecode", Model: "claude-cli/claudecode"}, } - cfg.Agents.Defaults.Model = "claudecode" + cfg.Agents.Defaults.ModelName = "claudecode" provider, _, err := CreateProvider(cfg) if err != nil { @@ -469,7 +469,7 @@ func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { cfg.ModelList = []config.ModelConfig{ {ModelName: "claude-cli", Model: "claude-cli/claude-sonnet"}, } - cfg.Agents.Defaults.Model = "claude-cli" + cfg.Agents.Defaults.ModelName = "claude-cli" cfg.Agents.Defaults.Workspace = "" provider, _, err := CreateProvider(cfg) diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index ee9c11899..354acafcb 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -1,384 +1,7 @@ package providers import ( - "fmt" - "strings" - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" ) -const defaultAnthropicAPIBase = "https://api.anthropic.com/v1" - var getCredential = auth.GetCredential - -type providerType int - -const ( - providerTypeHTTPCompat providerType = iota - providerTypeClaudeAuth - providerTypeCodexAuth - providerTypeCodexCLIToken - providerTypeClaudeCLI - providerTypeCodexCLI - providerTypeGitHubCopilot -) - -type providerSelection struct { - providerType providerType - apiKey string - apiBase string - proxy string - model string - workspace string - connectMode string - enableWebSearch bool -} - -func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { - model := cfg.Agents.Defaults.GetModelName() - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - lowerModel := strings.ToLower(model) - - if providerName == "" && model == "" { - return providerSelection{}, fmt.Errorf("no model configured: agents.defaults.model is empty") - } - - sel := providerSelection{ - providerType: providerTypeHTTPCompat, - model: model, - } - - // First, prefer explicit provider configuration. 
- if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - sel.apiKey = cfg.Providers.Groq.APIKey - sel.apiBase = cfg.Providers.Groq.APIBase - sel.proxy = cfg.Providers.Groq.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.groq.com/openai/v1" - } - } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch - if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - sel.providerType = providerTypeCodexCLIToken - return sel, nil - } - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - sel.providerType = providerTypeCodexAuth - return sel, nil - } - sel.apiKey = cfg.Providers.OpenAI.APIKey - sel.apiBase = cfg.Providers.OpenAI.APIBase - sel.proxy = cfg.Providers.OpenAI.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - sel.apiBase = cfg.Providers.Anthropic.APIBase - if sel.apiBase == "" { - sel.apiBase = defaultAnthropicAPIBase - } - sel.providerType = providerTypeClaudeAuth - return sel, nil - } - sel.apiKey = cfg.Providers.Anthropic.APIKey - sel.apiBase = cfg.Providers.Anthropic.APIBase - sel.proxy = cfg.Providers.Anthropic.Proxy - if sel.apiBase == "" { - sel.apiBase = defaultAnthropicAPIBase - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - sel.apiKey = cfg.Providers.OpenRouter.APIKey - sel.proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - sel.apiBase = cfg.Providers.OpenRouter.APIBase - } else { - sel.apiBase = "https://openrouter.ai/api/v1" - } - } - case "litellm": - if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" { - 
sel.apiKey = cfg.Providers.LiteLLM.APIKey - sel.apiBase = cfg.Providers.LiteLLM.APIBase - sel.proxy = cfg.Providers.LiteLLM.Proxy - if sel.apiBase == "" { - sel.apiBase = "http://localhost:4000/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - sel.apiKey = cfg.Providers.Zhipu.APIKey - sel.apiBase = cfg.Providers.Zhipu.APIBase - sel.proxy = cfg.Providers.Zhipu.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - sel.apiKey = cfg.Providers.Gemini.APIKey - sel.apiBase = cfg.Providers.Gemini.APIBase - sel.proxy = cfg.Providers.Gemini.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - sel.apiKey = cfg.Providers.VLLM.APIKey - sel.apiBase = cfg.Providers.VLLM.APIBase - sel.proxy = cfg.Providers.VLLM.Proxy - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - sel.apiKey = cfg.Providers.ShengSuanYun.APIKey - sel.apiBase = cfg.Providers.ShengSuanYun.APIBase - sel.proxy = cfg.Providers.ShengSuanYun.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "nvidia": - if cfg.Providers.Nvidia.APIKey != "" { - sel.apiKey = cfg.Providers.Nvidia.APIKey - sel.apiBase = cfg.Providers.Nvidia.APIBase - sel.proxy = cfg.Providers.Nvidia.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://integrate.api.nvidia.com/v1" - } - } - case "vivgrid": - if cfg.Providers.Vivgrid.APIKey != "" { - sel.apiKey = cfg.Providers.Vivgrid.APIKey - sel.apiBase = cfg.Providers.Vivgrid.APIBase - sel.proxy = cfg.Providers.Vivgrid.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.vivgrid.com/v1" - } - } - case "claude-cli", "claude-code", "claudecode": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." 
- } - sel.providerType = providerTypeClaudeCLI - sel.workspace = workspace - return sel, nil - case "codex-cli", "codex-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - sel.providerType = providerTypeCodexCLI - sel.workspace = workspace - return sel, nil - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - sel.apiKey = cfg.Providers.DeepSeek.APIKey - sel.apiBase = cfg.Providers.DeepSeek.APIBase - sel.proxy = cfg.Providers.DeepSeek.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - sel.model = "deepseek-chat" - } - } - case "avian": - if cfg.Providers.Avian.APIKey != "" { - sel.apiKey = cfg.Providers.Avian.APIKey - sel.apiBase = cfg.Providers.Avian.APIBase - sel.proxy = cfg.Providers.Avian.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.avian.io/v1" - } - } - case "mistral": - if cfg.Providers.Mistral.APIKey != "" { - sel.apiKey = cfg.Providers.Mistral.APIKey - sel.apiBase = cfg.Providers.Mistral.APIBase - sel.proxy = cfg.Providers.Mistral.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.mistral.ai/v1" - } - } - case "minimax": - if cfg.Providers.Minimax.APIKey != "" { - sel.apiKey = cfg.Providers.Minimax.APIKey - sel.apiBase = cfg.Providers.Minimax.APIBase - sel.proxy = cfg.Providers.Minimax.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.minimaxi.com/v1" - } - } - case "github_copilot", "copilot": - sel.providerType = providerTypeGitHubCopilot - if cfg.Providers.GitHubCopilot.APIBase != "" { - sel.apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - sel.apiBase = "localhost:4321" - } - sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode - return sel, nil - } - } - - // Fallback: infer provider from model and configured keys. 
- if sel.apiKey == "" && sel.apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - sel.apiKey = cfg.Providers.Moonshot.APIKey - sel.apiBase = cfg.Providers.Moonshot.APIBase - sel.proxy = cfg.Providers.Moonshot.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.moonshot.cn/v1" - } - case strings.HasPrefix(model, "openrouter/") || - strings.HasPrefix(model, "anthropic/") || - strings.HasPrefix(model, "openai/") || - strings.HasPrefix(model, "meta-llama/") || - strings.HasPrefix(model, "deepseek/") || - strings.HasPrefix(model, "google/"): - sel.apiKey = cfg.Providers.OpenRouter.APIKey - sel.proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - sel.apiBase = cfg.Providers.OpenRouter.APIBase - } else { - sel.apiBase = "https://openrouter.ai/api/v1" - } - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && - (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - sel.apiBase = cfg.Providers.Anthropic.APIBase - if sel.apiBase == "" { - sel.apiBase = defaultAnthropicAPIBase - } - sel.providerType = providerTypeClaudeAuth - return sel, nil - } - sel.apiKey = cfg.Providers.Anthropic.APIKey - sel.apiBase = cfg.Providers.Anthropic.APIBase - sel.proxy = cfg.Providers.Anthropic.Proxy - if sel.apiBase == "" { - sel.apiBase = defaultAnthropicAPIBase - } - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && - (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch - if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - sel.providerType = providerTypeCodexCLIToken - return sel, nil - } - if cfg.Providers.OpenAI.AuthMethod 
== "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - sel.providerType = providerTypeCodexAuth - return sel, nil - } - sel.apiKey = cfg.Providers.OpenAI.APIKey - sel.apiBase = cfg.Providers.OpenAI.APIBase - sel.proxy = cfg.Providers.OpenAI.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.openai.com/v1" - } - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - sel.apiKey = cfg.Providers.Gemini.APIKey - sel.apiBase = cfg.Providers.Gemini.APIBase - sel.proxy = cfg.Providers.Gemini.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - sel.apiKey = cfg.Providers.Zhipu.APIKey - sel.apiBase = cfg.Providers.Zhipu.APIBase - sel.proxy = cfg.Providers.Zhipu.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - sel.apiKey = cfg.Providers.Groq.APIKey - sel.apiBase = cfg.Providers.Groq.APIBase - sel.proxy = cfg.Providers.Groq.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.groq.com/openai/v1" - } - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - sel.apiKey = cfg.Providers.Nvidia.APIKey - sel.apiBase = cfg.Providers.Nvidia.APIBase - sel.proxy = cfg.Providers.Nvidia.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://integrate.api.nvidia.com/v1" - } - case strings.HasPrefix(model, "vivgrid/") && cfg.Providers.Vivgrid.APIKey != "": - sel.apiKey = cfg.Providers.Vivgrid.APIKey - sel.apiBase = cfg.Providers.Vivgrid.APIBase - sel.proxy = cfg.Providers.Vivgrid.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.vivgrid.com/v1" - } - 
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": - sel.apiKey = cfg.Providers.Ollama.APIKey - sel.apiBase = cfg.Providers.Ollama.APIBase - sel.proxy = cfg.Providers.Ollama.Proxy - if sel.apiBase == "" { - sel.apiBase = "http://localhost:11434/v1" - } - case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "": - sel.apiKey = cfg.Providers.Mistral.APIKey - sel.apiBase = cfg.Providers.Mistral.APIBase - sel.proxy = cfg.Providers.Mistral.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.mistral.ai/v1" - } - case (strings.Contains(lowerModel, "minimax") || strings.HasPrefix(model, "minimax/")) && cfg.Providers.Minimax.APIKey != "": - sel.apiKey = cfg.Providers.Minimax.APIKey - sel.apiBase = cfg.Providers.Minimax.APIBase - sel.proxy = cfg.Providers.Minimax.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.minimaxi.com/v1" - } - case strings.HasPrefix(model, "avian/") && cfg.Providers.Avian.APIKey != "": - sel.apiKey = cfg.Providers.Avian.APIKey - sel.apiBase = cfg.Providers.Avian.APIBase - sel.proxy = cfg.Providers.Avian.Proxy - if sel.apiBase == "" { - sel.apiBase = "https://api.avian.io/v1" - } - case cfg.Providers.VLLM.APIBase != "": - sel.apiKey = cfg.Providers.VLLM.APIKey - sel.apiBase = cfg.Providers.VLLM.APIBase - sel.proxy = cfg.Providers.VLLM.Proxy - default: - if cfg.Providers.OpenRouter.APIKey != "" { - sel.apiKey = cfg.Providers.OpenRouter.APIKey - sel.proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - sel.apiBase = cfg.Providers.OpenRouter.APIBase - } else { - sel.apiBase = "https://openrouter.ai/api/v1" - } - } else { - return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) - } - } - } - - if sel.providerType == providerTypeHTTPCompat { - if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return providerSelection{}, 
fmt.Errorf("no API key configured for provider (model: %s)", model) - } - if sel.apiBase == "" { - return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - } - - return sel, nil -} diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index a798154cb..9749e7a15 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -95,7 +95,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian", - "minimax": + "minimax", "longcat": // All other OpenAI-compatible HTTP providers if cfg.APIKey == "" && cfg.APIBase == "" { return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol) @@ -215,6 +215,8 @@ func getDefaultAPIBase(protocol string) string { return "https://api.avian.io/v1" case "minimax": return "https://api.minimaxi.com/v1" + case "longcat": + return "https://api.longcat.chat/openai" default: return "" } diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index 17bc55d25..6c7bb4795 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -113,6 +113,7 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { {"vllm", "vllm"}, {"deepseek", "deepseek"}, {"ollama", "ollama"}, + {"longcat", "longcat"}, } for _, tt := range tests { @@ -162,6 +163,29 @@ func TestCreateProviderFromConfig_LiteLLM(t *testing.T) { } } +func TestCreateProviderFromConfig_LongCat(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-longcat", + Model: "longcat/LongCat-Flash-Thinking", + APIKey: "test-key", + APIBase: "https://api.longcat.chat/openai", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + 
t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "LongCat-Flash-Thinking" { + t.Errorf("modelID = %q, want %q", modelID, "LongCat-Flash-Thinking") + } + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("expected *HTTPProvider, got %T", provider) + } +} + func TestCreateProviderFromConfig_Anthropic(t *testing.T) { cfg := &config.ModelConfig{ ModelName: "test-anthropic", diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index 36ccda4a1..bd8fbd1c4 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -1,234 +1,15 @@ package providers import ( - "strings" "testing" "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) -func TestResolveProviderSelection(t *testing.T) { - tests := []struct { - name string - setup func(*config.Config) - wantType providerType - wantAPIBase string - wantProxy string - wantErrSubstr string - }{ - { - name: "explicit litellm provider uses configured base", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "litellm" - cfg.Providers.LiteLLM.APIKey = "litellm-key" - cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1" - cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "http://localhost:4000/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "explicit litellm provider defaults base when only key is configured", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "litellm" - cfg.Providers.LiteLLM.APIKey = "litellm-key" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "http://localhost:4000/v1", - }, - { - name: "explicit claude-cli provider routes to cli provider type", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "claude-cli" - cfg.Agents.Defaults.Workspace = "/tmp/ws" - }, - wantType: providerTypeClaudeCLI, - 
}, - { - name: "explicit copilot provider routes to github copilot type", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "copilot" - }, - wantType: providerTypeGitHubCopilot, - wantAPIBase: "localhost:4321", - }, - { - name: "explicit deepseek provider uses deepseek defaults", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "deepseek" - cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" - cfg.Providers.DeepSeek.APIKey = "deepseek-key" - cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://api.deepseek.com/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "explicit shengsuanyun provider uses defaults", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "shengsuanyun" - cfg.Providers.ShengSuanYun.APIKey = "ssy-key" - cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://router.shengsuanyun.com/api/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "explicit nvidia provider uses defaults", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "nvidia" - cfg.Providers.Nvidia.APIKey = "nvapi-test" - cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://integrate.api.nvidia.com/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "explicit vivgrid provider uses defaults", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "vivgrid" - cfg.Providers.Vivgrid.APIKey = "vivgrid-key" - cfg.Providers.Vivgrid.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://api.vivgrid.com/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "openrouter model uses openrouter defaults", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "openrouter/auto" - cfg.Providers.OpenRouter.APIKey = "sk-or-test" - }, - 
wantType: providerTypeHTTPCompat, - wantAPIBase: "https://openrouter.ai/api/v1", - }, - { - name: "anthropic oauth routes to claude auth provider", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "claude-sonnet-4.6" - cfg.Providers.Anthropic.AuthMethod = "oauth" - }, - wantType: providerTypeClaudeAuth, - }, - { - name: "openai oauth routes to codex auth provider", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "gpt-4o" - cfg.Providers.OpenAI.AuthMethod = "oauth" - }, - wantType: providerTypeCodexAuth, - }, - { - name: "openai codex-cli auth routes to codex cli token provider", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "gpt-4o" - cfg.Providers.OpenAI.AuthMethod = "codex-cli" - }, - wantType: providerTypeCodexCLIToken, - }, - { - name: "explicit codex-code provider routes to codex cli provider type", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Provider = "codex-code" - cfg.Agents.Defaults.Workspace = "/tmp/ws" - }, - wantType: providerTypeCodexCLI, - }, - { - name: "zhipu model uses zhipu base default", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "glm-4.7" - cfg.Providers.Zhipu.APIKey = "zhipu-key" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", - }, - { - name: "groq model uses groq base default", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" - cfg.Providers.Groq.APIKey = "gsk-key" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://api.groq.com/openai/v1", - }, - { - name: "ollama model uses ollama base default", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b" - cfg.Providers.Ollama.APIKey = "ollama-key" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "http://localhost:11434/v1", - }, - { - name: "moonshot model keeps proxy and default base", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = 
"moonshot/kimi-k2.5" - cfg.Providers.Moonshot.APIKey = "moonshot-key" - cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" - }, - wantType: providerTypeHTTPCompat, - wantAPIBase: "https://api.moonshot.cn/v1", - wantProxy: "http://127.0.0.1:7890", - }, - { - name: "missing keys returns model config error", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "custom-model" - }, - wantErrSubstr: "no API key configured for model", - }, - { - name: "openrouter prefix without key returns provider key error", - setup: func(cfg *config.Config) { - cfg.Agents.Defaults.Model = "openrouter/auto" - }, - wantErrSubstr: "no API key configured for provider", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cfg := config.DefaultConfig() - tt.setup(cfg) - - got, err := resolveProviderSelection(cfg) - if tt.wantErrSubstr != "" { - if err == nil { - t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) - } - if !strings.Contains(err.Error(), tt.wantErrSubstr) { - t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) - } - return - } - - if err != nil { - t.Fatalf("resolveProviderSelection() error = %v", err) - } - if got.providerType != tt.wantType { - t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) - } - if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { - t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) - } - if tt.wantProxy != "" && got.proxy != tt.wantProxy { - t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) - } - }) - } -} - func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Model = "test-openrouter" + cfg.Agents.Defaults.ModelName = "test-openrouter" cfg.ModelList = []config.ModelConfig{ { ModelName: "test-openrouter", @@ -250,7 +31,7 @@ func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { func TestCreateProviderReturnsCodexCliProviderForCodexCode(t 
*testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Model = "test-codex" + cfg.Agents.Defaults.ModelName = "test-codex" cfg.ModelList = []config.ModelConfig{ { ModelName: "test-codex", @@ -271,7 +52,7 @@ func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) { func TestCreateProviderReturnsClaudeCliProviderForClaudeCli(t *testing.T) { cfg := config.DefaultConfig() - cfg.Agents.Defaults.Model = "test-claude-cli" + cfg.Agents.Defaults.ModelName = "test-claude-cli" cfg.ModelList = []config.ModelConfig{ { ModelName: "test-claude-cli", @@ -304,7 +85,7 @@ func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { } cfg := config.DefaultConfig() - cfg.Agents.Defaults.Model = "test-claude-oauth" + cfg.Agents.Defaults.ModelName = "test-claude-oauth" cfg.ModelList = []config.ModelConfig{ { ModelName: "test-claude-oauth", diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go index 26905159f..4b0815dd4 100644 --- a/pkg/providers/legacy_provider.go +++ b/pkg/providers/legacy_provider.go @@ -18,23 +18,6 @@ import ( func CreateProvider(cfg *config.Config) (LLMProvider, string, error) { model := cfg.Agents.Defaults.GetModelName() - // Ensure model_list is populated from providers config if needed - // This handles two cases: - // 1. ModelList is empty - convert all providers - // 2. ModelList has some entries but not all providers - merge missing ones - if cfg.HasProvidersConfig() { - providerModels := config.ConvertProvidersToModelList(cfg) - existingModelNames := make(map[string]bool) - for _, m := range cfg.ModelList { - existingModelNames[m.ModelName] = true - } - for _, pm := range providerModels { - if !existingModelNames[pm.ModelName] { - cfg.ModelList = append(cfg.ModelList, pm) - } - } - } - // Must have model_list at this point if len(cfg.ModelList) == 0 { return nil, "", fmt.Errorf("no providers configured. 
Please add entries to model_list in your config") diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 0e8db7409..f97bf3acd 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -156,9 +156,10 @@ func (p *Provider) Chat( // The key is typically the agent ID — stable per agent, shared across requests. // See: https://platform.openai.com/docs/guides/prompt-caching // Prompt caching is only supported by OpenAI-native endpoints. - // Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs. + // Non-OpenAI providers (Mistral, Gemini, DeepSeek, etc.) reject unknown + // fields with 422 errors, so only include it for OpenAI APIs. if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" { - if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") { + if supportsPromptCacheKey(p.apiBase) { requestBody["prompt_cache_key"] = cacheKey } } @@ -283,8 +284,8 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { ID string `json:"id"` Type string `json:"type"` Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` } `json:"function"` ExtraContent *struct { Google *struct { @@ -323,12 +324,7 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { if tc.Function != nil { name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) - arguments["raw"] = tc.Function.Arguments - } - } + arguments = decodeToolCallArguments(tc.Function.Arguments, name) } // Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence @@ -361,6 +357,39 @@ func parseResponse(body io.Reader) (*LLMResponse, error) { }, nil } +func decodeToolCallArguments(raw 
json.RawMessage, name string) map[string]any { + arguments := make(map[string]any) + raw = bytes.TrimSpace(raw) + if len(raw) == 0 || bytes.Equal(raw, []byte("null")) { + return arguments + } + + var decoded any + if err := json.Unmarshal(raw, &decoded); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err) + arguments["raw"] = string(raw) + return arguments + } + + switch v := decoded.(type) { + case string: + if strings.TrimSpace(v) == "" { + return arguments + } + if err := json.Unmarshal([]byte(v), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) + arguments["raw"] = v + } + return arguments + case map[string]any: + return v + default: + log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded) + arguments["raw"] = string(raw) + return arguments + } +} + // openaiMessage is the wire-format message for OpenAI-compatible APIs. // It mirrors protocoltypes.Message but omits SystemParts, which is an // internal field that would be unknown to third-party endpoints. @@ -476,3 +505,16 @@ func asFloat(v any) (float64, bool) { return 0, false } } + +// supportsPromptCacheKey reports whether the given API base is known to +// support the prompt_cache_key request field. Currently only OpenAI's own +// API and Azure OpenAI support this. All other OpenAI-compatible providers +// (Mistral, Gemini, DeepSeek, Groq, etc.) reject unknown fields with 422 errors. 
+func supportsPromptCacheKey(apiBase string) bool { + u, err := url.Parse(apiBase) + if err != nil { + return false + } + host := u.Hostname() + return host == "api.openai.com" || strings.HasSuffix(host, ".openai.azure.com") +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 9a3a7acc5..41f278a1b 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -108,6 +108,55 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) { } } +func TestProviderChat_ParsesToolCallsWithObjectArguments(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "tool_calls": []map[string]any{ + { + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": map[string]any{ + "city": "SF", + "metric": true, + }, + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } + if out.ToolCalls[0].Arguments["metric"] != true { + t.Fatalf("ToolCalls[0].Arguments[metric] = %v, want true", out.ToolCalls[0].Arguments["metric"]) + } +} + func TestProviderChat_ParsesReasoningContent(t 
*testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resp := map[string]any{ @@ -669,6 +718,111 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { } } +// chatWithCacheKey sets up a test server, sends a Chat request with prompt_cache_key, +// and returns the decoded request body for assertion. +func chatWithCacheKey(t *testing.T, apiBase string) map[string]any { + t.Helper() + var requestBody map[string]any + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + p.apiBase = apiBase + p.httpClient = &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + r.URL, _ = url.Parse(server.URL + r.URL.Path) + return http.DefaultTransport.RoundTrip(r) + }), + } + + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "test-model", + map[string]any{"prompt_cache_key": "agent-main"}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + return requestBody +} + +func TestProviderChat_PromptCacheKeySentToOpenAI(t *testing.T) { + body := chatWithCacheKey(t, "https://api.openai.com/v1") + if body["prompt_cache_key"] != "agent-main" { + t.Fatalf("prompt_cache_key = %v, want %q", body["prompt_cache_key"], "agent-main") + } +} + +func TestProviderChat_PromptCacheKeyOmittedForNonOpenAI(t *testing.T) { + tests := []struct { + name string + apiBase string + }{ + {"mistral", "https://api.mistral.ai/v1"}, + {"gemini", 
"https://generativelanguage.googleapis.com/v1beta"}, + {"deepseek", "https://api.deepseek.com/v1"}, + {"groq", "https://api.groq.com/openai/v1"}, + {"minimax", "https://api.minimaxi.com/v1"}, + {"ollama_local", "http://localhost:11434/v1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := chatWithCacheKey(t, tt.apiBase) + if _, exists := body["prompt_cache_key"]; exists { + t.Fatalf("prompt_cache_key should NOT be sent to %s, but was included in request", tt.name) + } + }) + } +} + +func TestSupportsPromptCacheKey(t *testing.T) { + tests := []struct { + apiBase string + want bool + }{ + {"https://api.openai.com/v1", true}, + {"https://api.openai.com/v1/", true}, + {"https://myresource.openai.azure.com/openai/deployments/gpt-4", true}, + {"https://eastus.openai.azure.com/v1", true}, + {"https://api.mistral.ai/v1", false}, + {"https://generativelanguage.googleapis.com/v1beta", false}, + {"https://api.deepseek.com/v1", false}, + {"https://api.groq.com/openai/v1", false}, + {"http://localhost:11434/v1", false}, + {"https://openrouter.ai/api/v1", false}, + // Edge cases: proxy URLs with openai.com in path should NOT match + {"https://my-proxy.com/api.openai.com/v1", false}, + {"https://proxy.example.com/openai.azure.com/v1", false}, + // Malformed or empty + {"", false}, + {"not-a-url", false}, + } + for _, tt := range tests { + if got := supportsPromptCacheKey(tt.apiBase); got != tt.want { + t.Errorf("supportsPromptCacheKey(%q) = %v, want %v", tt.apiBase, got, tt.want) + } + } +} + func TestSerializeMessages_StripsSystemParts(t *testing.T) { messages := []protocoltypes.Message{ { diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go index 8255db5f9..fdfc899f9 100644 --- a/pkg/routing/route_test.go +++ b/pkg/routing/route_test.go @@ -11,7 +11,7 @@ func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *co Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: "/tmp/picoclaw-test", - 
Model: "gpt-4", + ModelName: "gpt-4", }, List: agents, }, diff --git a/pkg/session/manager.go b/pkg/session/manager.go index a31dbd55c..ef720b7c5 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -32,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager { } if storage != "" { - os.MkdirAll(storage, 0o755) + os.MkdirAll(storage, 0o700) sm.loadSessions() } @@ -216,7 +216,7 @@ func (sm *SessionManager) Save(key string) error { _ = tmpFile.Close() return err } - if err := tmpFile.Chmod(0o644); err != nil { + if err := tmpFile.Chmod(0o600); err != nil { _ = tmpFile.Close() return err } diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index 30d84635a..f5985a662 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -10,14 +10,15 @@ import ( "regexp" "strings" + "github.com/gomarkdown/markdown" + "github.com/gomarkdown/markdown/ast" + "github.com/gomarkdown/markdown/parser" + "gopkg.in/yaml.v3" + "github.com/sipeed/picoclaw/pkg/logger" ) -var ( - namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) - reFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`) - reStripFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) -) +var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) const ( MaxNameLength = 64 @@ -226,11 +227,20 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { return nil } - frontmatter := sl.extractFrontmatter(string(content)) + frontmatter, bodyContent := splitFrontmatter(string(content)) + dirName := filepath.Base(filepath.Dir(skillPath)) + title, bodyDescription := extractMarkdownMetadata(bodyContent) + + metadata := &SkillMetadata{ + Name: dirName, + Description: bodyDescription, + } + if title != "" && namePattern.MatchString(title) && len(title) <= MaxNameLength { + metadata.Name = title + } + if frontmatter == "" { - return &SkillMetadata{ - Name: 
filepath.Base(filepath.Dir(skillPath)), - } + return metadata } // Try JSON first (for backward compatibility) @@ -239,60 +249,133 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { Description string `json:"description"` } if err := json.Unmarshal([]byte(frontmatter), &jsonMeta); err == nil { - return &SkillMetadata{ - Name: jsonMeta.Name, - Description: jsonMeta.Description, + if jsonMeta.Name != "" { + metadata.Name = jsonMeta.Name } + if jsonMeta.Description != "" { + metadata.Description = jsonMeta.Description + } + return metadata } // Fall back to simple YAML parsing yamlMeta := sl.parseSimpleYAML(frontmatter) - return &SkillMetadata{ - Name: yamlMeta["name"], - Description: yamlMeta["description"], + if name := yamlMeta["name"]; name != "" { + metadata.Name = name } + if description := yamlMeta["description"]; description != "" { + metadata.Description = description + } + return metadata } -// parseSimpleYAML parses simple key: value YAML format -// Example: name: github\n description: "..." 
-// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac) +func extractMarkdownMetadata(content string) (title, description string) { + p := parser.NewWithExtensions(parser.CommonExtensions) + doc := markdown.Parse([]byte(content), p) + if doc == nil { + return "", "" + } + + ast.WalkFunc(doc, func(node ast.Node, entering bool) ast.WalkStatus { + if !entering { + return ast.GoToNext + } + + switch n := node.(type) { + case *ast.Heading: + if title == "" && n.Level == 1 { + title = nodeText(n) + if title != "" && description != "" { + return ast.Terminate + } + } + case *ast.Paragraph: + if description == "" { + description = nodeText(n) + if title != "" && description != "" { + return ast.Terminate + } + } + } + return ast.GoToNext + }) + + return title, description +} + +func nodeText(n ast.Node) string { + var b strings.Builder + ast.WalkFunc(n, func(node ast.Node, entering bool) ast.WalkStatus { + if !entering { + return ast.GoToNext + } + + switch t := node.(type) { + case *ast.Text: + b.Write(t.Literal) + case *ast.Code: + b.Write(t.Literal) + case *ast.Softbreak, *ast.Hardbreak, *ast.NonBlockingSpace: + b.WriteByte(' ') + } + return ast.GoToNext + }) + return strings.Join(strings.Fields(b.String()), " ") +} + +// parseSimpleYAML parses YAML frontmatter and extracts known metadata fields. 
func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { result := make(map[string]string) - // Normalize line endings: convert \r\n and \r to \n - normalized := strings.ReplaceAll(content, "\r\n", "\n") - normalized = strings.ReplaceAll(normalized, "\r", "\n") - - for line := range strings.SplitSeq(normalized, "\n") { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - parts := strings.SplitN(line, ":", 2) - if len(parts) == 2 { - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - // Remove quotes if present - value = strings.Trim(value, "\"'") - result[key] = value - } + var meta struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + } + if err := yaml.Unmarshal([]byte(content), &meta); err != nil { + return result + } + if meta.Name != "" { + result["name"] = meta.Name + } + if meta.Description != "" { + result["description"] = meta.Description } return result } func (sl *SkillsLoader) extractFrontmatter(content string) string { - // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks - match := reFrontmatter.FindStringSubmatch(content) - if len(match) > 1 { - return match[1] - } - return "" + frontmatter, _ := splitFrontmatter(content) + return frontmatter } func (sl *SkillsLoader) stripFrontmatter(content string) string { - return reStripFrontmatter.ReplaceAllString(content, "") + _, body := splitFrontmatter(content) + return body +} + +func splitFrontmatter(content string) (frontmatter, body string) { + normalized := string(parser.NormalizeNewlines([]byte(content))) + lines := strings.Split(normalized, "\n") + if len(lines) == 0 || lines[0] != "---" { + return "", content + } + + end := -1 + for i := 1; i < len(lines); i++ { + if lines[i] == "---" { + end = i + break + } + } + if end == -1 { + return "", content + } + + frontmatter = strings.Join(lines[1:end], "\n") + body = strings.Join(lines[end+1:], 
"\n") + body = strings.TrimLeft(body, "\n") + return frontmatter, body } func escapeXML(s string) string { diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go index 31619f9c2..645d8b7ac 100644 --- a/pkg/skills/loader_test.go +++ b/pkg/skills/loader_test.go @@ -342,3 +342,78 @@ func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) { builtin, }, roots) } + +func TestGetSkillMetadata_UsesMarkdownParagraphWhenNoFrontmatter(t *testing.T) { + tmp := t.TempDir() + skillDir := filepath.Join(tmp, "workspace", "skills", "plain-skill") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + content := "# Plain Skill\n\nThis is parsed from markdown paragraph.\n" + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + + sl := &SkillsLoader{} + meta := sl.getSkillMetadata(filepath.Join(skillDir, "SKILL.md")) + require.NotNil(t, meta) + assert.Equal(t, "plain-skill", meta.Name) + assert.Equal(t, "This is parsed from markdown paragraph.", meta.Description) +} + +func TestGetSkillMetadata_FrontmatterOverridesMarkdown(t *testing.T) { + tmp := t.TempDir() + skillDir := filepath.Join(tmp, "workspace", "skills", "plain-skill") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + content := "---\nname: frontmatter-skill\ndescription: frontmatter description\n---\n\n# Plain Skill\n\nBody description.\n" + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + + sl := &SkillsLoader{} + meta := sl.getSkillMetadata(filepath.Join(skillDir, "SKILL.md")) + require.NotNil(t, meta) + assert.Equal(t, "frontmatter-skill", meta.Name) + assert.Equal(t, "frontmatter description", meta.Description) +} + +func TestGetSkillMetadata_YAMLMultilineDescription(t *testing.T) { + tmp := t.TempDir() + skillDir := filepath.Join(tmp, "workspace", "skills", "plain-skill") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + content := "---\nname: frontmatter-skill\ndescription: |\n line 1: with colon\n line 
2\n---\n\n# Plain Skill\n\nBody description.\n" + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + + sl := &SkillsLoader{} + meta := sl.getSkillMetadata(filepath.Join(skillDir, "SKILL.md")) + require.NotNil(t, meta) + assert.Equal(t, "frontmatter-skill", meta.Name) + assert.Equal(t, "line 1: with colon\nline 2", meta.Description) +} + +func TestGetSkillMetadata_InvalidHeadingNameFallsBackToDirName(t *testing.T) { + tmp := t.TempDir() + skillDir := filepath.Join(tmp, "workspace", "skills", "valid-name") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + content := "# Invalid Heading Name\n\nBody description.\n" + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + + sl := &SkillsLoader{} + meta := sl.getSkillMetadata(filepath.Join(skillDir, "SKILL.md")) + require.NotNil(t, meta) + assert.Equal(t, "valid-name", meta.Name) + assert.Equal(t, "Body description.", meta.Description) +} + +func TestGetSkillMetadata_IgnoresHTMLCommentBlocks(t *testing.T) { + tmp := t.TempDir() + skillDir := filepath.Join(tmp, "workspace", "skills", "biomed-skill") + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + + content := "\n\n# Biomed Skill\n\nSummarize biomedical papers.\n" + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + + sl := &SkillsLoader{} + meta := sl.getSkillMetadata(filepath.Join(skillDir, "SKILL.md")) + require.NotNil(t, meta) + assert.Equal(t, "biomed-skill", meta.Name) + assert.Equal(t, "Summarize biomedical papers.", meta.Description) +} diff --git a/pkg/state/state.go b/pkg/state/state.go index 57f371f12..5da7bbde1 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -40,8 +40,8 @@ func NewManager(workspace string) *Manager { oldStateFile := filepath.Join(workspace, "state.json") // Create state directory if it doesn't exist - if err := os.MkdirAll(stateDir, 0o755); err != nil { - log.Fatalf("[FATAL] state: failed to 
create state directory: %v", err) + if err := os.MkdirAll(stateDir, 0o700); err != nil { + log.Printf("[WARN] state: failed to create state directory %s: %v", stateDir, err) } sm := &Manager{ diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index e5e116ef6..3924e5533 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -2,7 +2,6 @@ package state import ( "encoding/json" - "errors" "fmt" "os" "os/exec" @@ -217,10 +216,7 @@ func TestNewManager_EmptyWorkspace(t *testing.T) { } } -func TestNewManager_MkdirFailureCrashes(t *testing.T) { - // Since log.Fatalf calls os.Exit(1), we cannot test it normally - // Otherwise, the test suite would stop altogether. - // We use the standard pattern of Go: rerun this test in a subprocess. +func TestNewManager_MkdirFailureDoesNotCrash(t *testing.T) { if os.Getenv("BE_CRASHER") == "1" { tmpDir := os.Getenv("CRASH_DIR") @@ -240,15 +236,11 @@ func TestNewManager_MkdirFailureCrashes(t *testing.T) { } defer os.RemoveAll(tmpDir) - cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureCrashes") + cmd := exec.Command(os.Args[0], "-test.run=TestNewManager_MkdirFailureDoesNotCrash") cmd.Env = append(os.Environ(), "BE_CRASHER=1", "CRASH_DIR="+tmpDir) err = cmd.Run() - - var e *exec.ExitError - if errors.As(err, &e) && !e.Success() { - return + if err != nil { + t.Fatalf("NewManager should not crash when state dir creation fails, got: %v", err) } - - t.Fatalf("The process ended without error, a crash was expected via os.Exit(1). 
Err: %v", err) } diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 6af0aa9e1..648cc3c6c 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -8,6 +8,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -73,6 +74,10 @@ func (t *CronTool) Parameters() map[string]any { "type": "string", "description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.", }, + "command_confirm": map[string]any{ + "type": "boolean", + "description": "Required when 'command' is set. Must be true to explicitly confirm scheduling a shell command.", + }, "at_seconds": map[string]any{ "type": "integer", "description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.", @@ -175,12 +180,17 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult deliver = d } + // GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm. + // Non-command reminders (plain messages) remain open to all channels. command, _ := args["command"].(string) + commandConfirm, _ := args["command_confirm"].(bool) if command != "" { - // Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically) - // Actually, let's keep deliver=false to let the system know it's not a simple chat message - // But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set. - // However, logically, it's not "delivered" to chat directly as is.
+ if !constants.IsInternalChannel(channel) { + return ErrorResult("scheduling command execution is restricted to internal channels") + } + if !commandConfirm { + return ErrorResult("command_confirm=true is required to schedule command execution") + } deliver = false } @@ -281,7 +291,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // Execute command if present if job.Payload.Command != "" { args := map[string]any{ - "command": job.Payload.Command, + "command": job.Payload.Command, + "__channel": channel, + "__chat_id": chatID, } result := t.execTool.Execute(ctx, args) diff --git a/pkg/tools/cron_test.go b/pkg/tools/cron_test.go new file mode 100644 index 000000000..1776abc65 --- /dev/null +++ b/pkg/tools/cron_test.go @@ -0,0 +1,116 @@ +package tools + +import ( + "context" + "path/filepath" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/cron" +) + +func newTestCronTool(t *testing.T) *CronTool { + t.Helper() + storePath := filepath.Join(t.TempDir(), "cron.json") + cronService := cron.NewCronService(storePath, nil) + msgBus := bus.NewMessageBus() + cfg := config.DefaultConfig() + tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg) + if err != nil { + t.Fatalf("NewCronTool() error: %v", err) + } + return tool +} + +// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels +func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "command_confirm": true, + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected command scheduling to be blocked from remote channel") + } + if !strings.Contains(result.ForLLM, "restricted to 
internal channels") { + t.Errorf("expected 'restricted to internal channels', got: %s", result.ForLLM) + } +} + +// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required +func TestCronTool_CommandRequiresConfirm(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected error when command_confirm is missing") + } + if !strings.Contains(result.ForLLM, "command_confirm=true") { + t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM) + } +} + +// TestCronTool_CommandAllowedFromInternalChannel verifies command scheduling works from internal channels +func TestCronTool_CommandAllowedFromInternalChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "check disk", + "command": "df -h", + "command_confirm": true, + "at_seconds": float64(60), + }) + + if result.IsError { + t.Fatalf("expected command scheduling to succeed from internal channel, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Cron job added") { + t.Errorf("expected 'Cron job added', got: %s", result.ForLLM) + } +} + +// TestCronTool_AddJobRequiresSessionContext verifies fail-closed when channel/chatID missing +func TestCronTool_AddJobRequiresSessionContext(t *testing.T) { + tool := newTestCronTool(t) + result := tool.Execute(context.Background(), map[string]any{ + "action": "add", + "message": "reminder", + "at_seconds": float64(60), + }) + + if !result.IsError { + t.Fatal("expected error when session context is missing") + } + if !strings.Contains(result.ForLLM, "no session context") { + t.Errorf("expected 'no session context' message, got: %s", result.ForLLM) + } +} 
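Reviewer note: the cron tests in this file exercise a small decision table, which can be sketched in isolation as below. This is a minimal illustration, not the code in `pkg/tools/cron.go`. The function `checkCronAdd` and the `internalChannels` map are hypothetical names, and the map merely stands in for `constants.IsInternalChannel`, whose real channel list is not shown in this diff (the tests only establish that `cli` is treated as internal and `telegram` is not). The sketch also assumes the session-context check runs before the command gates.

```go
package main

import (
	"errors"
	"fmt"
)

// internalChannels is a hypothetical stand-in for constants.IsInternalChannel.
// Only "cli" is listed because that is all the tests in this diff confirm.
var internalChannels = map[string]bool{"cli": true}

func isInternalChannel(channel string) bool { return internalChannels[channel] }

// checkCronAdd (hypothetical) applies the gates in the order assumed above:
// session context first, then the two command-specific checks.
func checkCronAdd(channel, chatID, command string, confirm bool) error {
	if channel == "" || chatID == "" {
		return errors.New("no session context")
	}
	if command == "" {
		// Plain reminders stay open to every channel.
		return nil
	}
	if !isInternalChannel(channel) {
		return errors.New("scheduling command execution is restricted to internal channels")
	}
	if !confirm {
		return errors.New("command_confirm=true is required to schedule command execution")
	}
	return nil
}

func main() {
	// One call per test case: remote command, missing confirm,
	// confirmed internal command, and a plain remote reminder.
	fmt.Println(checkCronAdd("telegram", "chat-1", "df -h", true))
	fmt.Println(checkCronAdd("cli", "direct", "df -h", false))
	fmt.Println(checkCronAdd("cli", "direct", "df -h", true))
	fmt.Println(checkCronAdd("telegram", "chat-1", "", false))
}
```

Keeping the reminder path ahead of the channel check is what lets `TestCronTool_NonCommandJobAllowedFromRemoteChannel` pass while command jobs from the same channel are rejected.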
+ +// TestCronTool_NonCommandJobAllowedFromRemoteChannel verifies regular reminders work from any channel +func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) { + tool := newTestCronTool(t) + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{ + "action": "add", + "message": "time to stretch", + "at_seconds": float64(600), + }) + + if result.IsError { + t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index b8a811d03..67e2ad257 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -14,6 +14,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" ) type ExecTool struct { @@ -23,6 +24,7 @@ type ExecTool struct { allowPatterns []*regexp.Regexp customAllowPatterns []*regexp.Regexp restrictToWorkspace bool + allowRemote bool } var ( @@ -100,10 +102,12 @@ func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) { func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) { denyPatterns := make([]*regexp.Regexp, 0) customAllowPatterns := make([]*regexp.Regexp, 0) + allowRemote := true if config != nil { execConfig := config.Tools.Exec enableDenyPatterns := execConfig.EnableDenyPatterns + allowRemote = execConfig.AllowRemote if enableDenyPatterns { denyPatterns = append(denyPatterns, defaultDenyPatterns...) 
if len(execConfig.CustomDenyPatterns) > 0 { @@ -143,6 +147,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf allowPatterns: nil, customAllowPatterns: customAllowPatterns, restrictToWorkspace: restrict, + allowRemote: allowRemote, }, nil } @@ -177,6 +182,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult return ErrorResult("command is required") } + // GHSA-pv8c-p6jf-3fpp: block exec from remote channels (e.g. Telegram webhooks) + // unless explicitly opted-in via config. Fail-closed: empty channel = blocked. + if !t.allowRemote { + channel := ToolChannel(ctx) + if channel == "" { + channel, _ = args["__channel"].(string) + } + channel = strings.TrimSpace(channel) + if channel == "" || !constants.IsInternalChannel(channel) { + return ErrorResult("exec is restricted to internal channels") + } + } + cwd := t.workingDir if wd, ok := args["working_dir"].(string); ok && wd != "" { if t.restrictToWorkspace && t.workingDir != "" { @@ -201,6 +219,25 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult return ErrorResult(guardError) } + // Re-resolve symlinks immediately before execution to shrink the TOCTOU window + // between validation and cmd.Dir assignment. 
+ if t.restrictToWorkspace && t.workingDir != "" && cwd != t.workingDir { + resolved, err := filepath.EvalSymlinks(cwd) + if err != nil { + return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err)) + } + absWorkspace, _ := filepath.Abs(t.workingDir) + wsResolved, _ := filepath.EvalSymlinks(absWorkspace) + if wsResolved == "" { + wsResolved = absWorkspace + } + rel, err := filepath.Rel(wsResolved, resolved) + if err != nil || !filepath.IsLocal(rel) { + return ErrorResult("Command blocked by safety guard (working directory escaped workspace)") + } + cwd = resolved + } + // timeout == 0 means no timeout var cmdCtx context.Context var cancel context.CancelFunc diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index ff9ea4a15..90265e5bd 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -301,6 +301,85 @@ func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) { } } +// TestShellTool_RemoteChannelBlockedByDefault verifies exec is blocked for remote channels +func TestShellTool_RemoteChannelBlockedByDefault(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if !result.IsError { + t.Fatal("expected remote-channel exec to be blocked") + } + if !strings.Contains(result.ForLLM, "restricted to internal channels") { + t.Errorf("expected 'restricted to internal channels' message, got: %s", result.ForLLM) + } +} + +// TestShellTool_InternalChannelAllowed verifies exec is allowed for internal channels +func TestShellTool_InternalChannelAllowed(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + 
tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "cli", "direct") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if result.IsError { + t.Fatalf("expected internal channel exec to succeed, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "hi") { + t.Errorf("expected output to contain 'hi', got: %s", result.ForLLM) + } +} + +// TestShellTool_EmptyChannelBlockedWhenNotAllowRemote verifies fail-closed when no channel context +func TestShellTool_EmptyChannelBlockedWhenNotAllowRemote(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = false + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "command": "echo hi", + }) + + if !result.IsError { + t.Fatal("expected exec with empty channel to be blocked when allowRemote=false") + } +} + +// TestShellTool_AllowRemoteBypassesChannelCheck verifies allowRemote=true permits any channel +func TestShellTool_AllowRemoteBypassesChannelCheck(t *testing.T) { + cfg := &config.Config{} + cfg.Tools.Exec.EnableDenyPatterns = true + cfg.Tools.Exec.AllowRemote = true + + tool, err := NewExecToolWithConfig("", false, cfg) + if err != nil { + t.Fatalf("NewExecToolWithConfig() error: %v", err) + } + ctx := WithToolContext(context.Background(), "telegram", "chat-1") + result := tool.Execute(ctx, map[string]any{"command": "echo hi"}) + + if result.IsError { + t.Fatalf("expected allowRemote=true to permit remote channel, got: %s", result.ForLLM) + } +} + // TestShellTool_RestrictToWorkspace verifies workspace restriction func TestShellTool_RestrictToWorkspace(t *testing.T) { tmpDir := t.TempDir() diff --git a/pkg/tools/web.go b/pkg/tools/web.go index e248ea966..003cd860c 100644 
--- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/url" "regexp" @@ -818,6 +819,10 @@ func NewWebFetchTool(maxChars int, fetchLimitBytes int64) (*WebFetchTool, error) return NewWebFetchToolWithProxy(maxChars, "", fetchLimitBytes) } +// allowPrivateWebFetchHosts controls whether loopback/private hosts are allowed. +// This is false in normal runtime to reduce SSRF exposure, and tests can override it temporarily. +var allowPrivateWebFetchHosts atomic.Bool + func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) (*WebFetchTool, error) { if maxChars <= 0 { maxChars = defaultMaxChars @@ -826,10 +831,20 @@ func NewWebFetchToolWithProxy(maxChars int, proxy string, fetchLimitBytes int64) if err != nil { return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) } + if transport, ok := client.Transport.(*http.Transport); ok { + dialer := &net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + } + transport.DialContext = newSafeDialContext(dialer) + } client.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) >= maxRedirects { return fmt.Errorf("stopped after %d redirects", maxRedirects) } + if isObviousPrivateHost(req.URL.Hostname()) { + return fmt.Errorf("redirect target is private or local network host") + } return nil } if fetchLimitBytes <= 0 { @@ -888,6 +903,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe return ErrorResult("missing domain in URL") } + // Lightweight pre-flight: block obvious localhost/literal-IP without DNS resolution. + // The real SSRF guard is newSafeDialContext at connect time. 
+ hostname := parsedURL.Hostname() + if isObviousPrivateHost(hostname) { + return ErrorResult("fetching private or local network hosts is not allowed") + } + maxChars := t.maxChars if mc, ok := args["maxChars"].(float64); ok { if int(mc) > 100 { @@ -901,7 +923,6 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe } req.Header.Set("User-Agent", userAgent) - resp, err := t.client.Do(req) if err != nil { return ErrorResult(fmt.Sprintf("request failed: %v", err)) @@ -992,3 +1013,127 @@ func (t *WebFetchTool) extractText(htmlContent string) string { return strings.Join(cleanLines, "\n") } + +// newSafeDialContext re-resolves DNS at connect time to mitigate DNS rebinding (TOCTOU) +// where a hostname resolves to a public IP during pre-flight but a private IP at connect time. +func newSafeDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, address string) (net.Conn, error) { + if allowPrivateWebFetchHosts.Load() { + return dialer.DialContext(ctx, network, address) + } + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("invalid target address %q: %w", address, err) + } + if host == "" { + return nil, fmt.Errorf("empty target host") + } + + if ip := net.ParseIP(host); ip != nil { + if isPrivateOrRestrictedIP(ip) { + return nil, fmt.Errorf("blocked private or local target: %s", host) + } + return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + } + + ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve %s: %w", host, err) + } + + attempted := 0 + var lastErr error + for _, ipAddr := range ipAddrs { + if isPrivateOrRestrictedIP(ipAddr.IP) { + continue + } + attempted++ + conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port)) + if err == nil { + return conn, nil + } + lastErr = err + } + + if 
attempted == 0 { + return nil, fmt.Errorf("all resolved addresses for %s are private or restricted", host) + } + if lastErr != nil { + return nil, fmt.Errorf("failed connecting to public addresses for %s: %w", host, lastErr) + } + return nil, fmt.Errorf("failed connecting to public addresses for %s", host) + } +} + +// isObviousPrivateHost performs a lightweight, no-DNS check for obviously private hosts. +// It catches localhost, literal private IPs, and empty hosts. It does NOT resolve DNS — +// the real SSRF guard is newSafeDialContext which checks IPs at connect time. +func isObviousPrivateHost(host string) bool { + if allowPrivateWebFetchHosts.Load() { + return false + } + + h := strings.ToLower(strings.TrimSpace(host)) + h = strings.TrimSuffix(h, ".") + if h == "" { + return true + } + + if h == "localhost" || strings.HasSuffix(h, ".localhost") { + return true + } + + if ip := net.ParseIP(h); ip != nil { + return isPrivateOrRestrictedIP(ip) + } + + return false +} + +// isPrivateOrRestrictedIP returns true for IPs that should never be reached via web_fetch: +// RFC 1918, loopback, link-local (incl. cloud metadata 169.254.x.x), carrier-grade NAT, +// IPv6 unique-local (fc00::/7), 6to4 (2002::/16), and Teredo (2001:0000::/32). +func isPrivateOrRestrictedIP(ip net.IP) bool { + if ip == nil { + return true + } + + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsMulticast() || ip.IsUnspecified() { + return true + } + + if ip4 := ip.To4(); ip4 != nil { + // IPv4 private, loopback, link-local, and carrier-grade NAT ranges. 
+ if ip4[0] == 10 || + ip4[0] == 127 || + ip4[0] == 0 || + (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) || + (ip4[0] == 192 && ip4[1] == 168) || + (ip4[0] == 169 && ip4[1] == 254) || + (ip4[0] == 100 && ip4[1] >= 64 && ip4[1] <= 127) { + return true + } + return false + } + + if len(ip) == net.IPv6len { + // IPv6 unique local addresses (fc00::/7) + if (ip[0] & 0xfe) == 0xfc { + return true + } + // 6to4 addresses (2002::/16): check the embedded IPv4 at bytes [2:6]. + if ip[0] == 0x20 && ip[1] == 0x02 { + embedded := net.IPv4(ip[2], ip[3], ip[4], ip[5]) + return isPrivateOrRestrictedIP(embedded) + } + // Teredo (2001:0000::/32): client IPv4 is at bytes [12:16], XOR-inverted. + if ip[0] == 0x20 && ip[1] == 0x01 && ip[2] == 0x00 && ip[3] == 0x00 { + client := net.IPv4(ip[12]^0xff, ip[13]^0xff, ip[14]^0xff, ip[15]^0xff) + return isPrivateOrRestrictedIP(client) + } + } + + return false +} diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 188fb8adb..0737d2087 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "net/http" "net/http/httptest" "strings" @@ -18,6 +19,8 @@ const testFetchLimit = int64(10 * 1024 * 1024) // TestWebTool_WebFetch_Success verifies successful URL fetching func TestWebTool_WebFetch_Success(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) @@ -55,6 +58,8 @@ func TestWebTool_WebFetch_Success(t *testing.T) { // TestWebTool_WebFetch_JSON verifies JSON content handling func TestWebTool_WebFetch_JSON(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + testData := map[string]string{"key": "value", "number": "123"} expectedJSON, _ := json.MarshalIndent(testData, "", " ") @@ -163,6 +168,8 @@ func TestWebTool_WebFetch_MissingURL(t *testing.T) { // TestWebTool_WebFetch_Truncation verifies content 
truncation func TestWebTool_WebFetch_Truncation(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + longContent := strings.Repeat("x", 20000) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -205,6 +212,8 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } func TestWebFetchTool_PayloadTooLarge(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + // Create a mock HTTP server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") @@ -290,6 +299,8 @@ func TestWebTool_WebSearch_MissingQuery(t *testing.T) { // TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) @@ -404,6 +415,205 @@ func TestWebFetchTool_extractText(t *testing.T) { } } +func withPrivateWebFetchHostsAllowed(t *testing.T) { + t.Helper() + previous := allowPrivateWebFetchHosts.Load() + allowPrivateWebFetchHosts.Store(true) + t.Cleanup(func() { + allowPrivateWebFetchHosts.Store(previous) + }) +} + +func TestWebTool_WebFetch_PrivateHostBlocked(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://127.0.0.1:0", + }) + + if !result.IsError { + t.Errorf("expected error for private host URL, got success") + } + if !strings.Contains(result.ForLLM, "private or local network") && + !strings.Contains(result.ForUser, "private or local network") { + t.Errorf("expected private host block message, got %q", result.ForLLM) + } +} + +func TestWebTool_WebFetch_PrivateHostAllowedForTests(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + + server 
:= httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer server.Close() + + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + + if result.IsError { + t.Errorf("expected success when private host access is allowed in tests, got %q", result.ForLLM) + } +} + +// TestWebFetch_BlocksIPv4MappedIPv6Loopback verifies ::ffff:127.0.0.1 is blocked +func TestWebFetch_BlocksIPv4MappedIPv6Loopback(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[::ffff:127.0.0.1]:0", + }) + + if !result.IsError { + t.Error("expected error for IPv4-mapped IPv6 loopback URL, got success") + } +} + +// TestWebFetch_BlocksMetadataIP verifies 169.254.169.254 is blocked +func TestWebFetch_BlocksMetadataIP(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://169.254.169.254/latest/meta-data", + }) + + if !result.IsError { + t.Error("expected error for cloud metadata IP, got success") + } +} + +// TestWebFetch_BlocksIPv6UniqueLocal verifies fc00::/7 addresses are blocked +func TestWebFetch_BlocksIPv6UniqueLocal(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[fd00::1]:0", + }) + + if !result.IsError { + t.Error("expected error for IPv6 unique local address, got success") + 
} +} + +// TestWebFetch_Blocks6to4WithPrivateEmbed verifies 6to4 with private embedded IPv4 is blocked +func TestWebFetch_Blocks6to4WithPrivateEmbed(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + // 2002:7f00:0001::1 embeds 127.0.0.1 + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[2002:7f00:0001::1]:0", + }) + + if !result.IsError { + t.Error("expected error for 6to4 with private embedded IPv4, got success") + } +} + +// TestWebFetch_Allows6to4WithPublicEmbed verifies 6to4 with public embedded IPv4 is NOT blocked +func TestWebFetch_Allows6to4WithPublicEmbed(t *testing.T) { + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + t.Fatalf("Failed to create web fetch tool: %v", err) + } + // 2002:0801:0101::1 embeds 8.1.1.1 (public) — pre-flight should pass, + // connection will fail (no listener) but that's after the SSRF check. + result := tool.Execute(context.Background(), map[string]any{ + "url": "http://[2002:0801:0101::1]:0", + }) + + // Should NOT be blocked by SSRF check — error should be connection failure, not "private" + if result.IsError && strings.Contains(result.ForLLM, "private") { + t.Error("6to4 with public embedded IPv4 should not be blocked as private") + } +} + +// TestWebFetch_RedirectToPrivateBlocked verifies redirects to private IPs are blocked +func TestWebFetch_RedirectToPrivateBlocked(t *testing.T) { + withPrivateWebFetchHostsAllowed(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Redirect to a private IP + http.Redirect(w, r, "http://10.0.0.1/secret", http.StatusFound) + })) + defer server.Close() + + // Temporarily disable private host allowance for the redirect check + allowPrivateWebFetchHosts.Store(false) + defer allowPrivateWebFetchHosts.Store(true) + + tool, err := NewWebFetchTool(50000, testFetchLimit) + if err != nil { + 
t.Fatalf("Failed to create web fetch tool: %v", err) + } + result := tool.Execute(context.Background(), map[string]any{ + "url": server.URL, + }) + + if !result.IsError { + t.Error("expected error when redirecting to private IP, got success") + } +} + +// TestIsPrivateOrRestrictedIP_Table tests IP classification logic +func TestIsPrivateOrRestrictedIP_Table(t *testing.T) { + tests := []struct { + ip string + blocked bool + desc string + }{ + {"127.0.0.1", true, "IPv4 loopback"}, + {"10.0.0.1", true, "IPv4 private class A"}, + {"172.16.0.1", true, "IPv4 private class B"}, + {"192.168.1.1", true, "IPv4 private class C"}, + {"169.254.169.254", true, "link-local / cloud metadata"}, + {"100.64.0.1", true, "carrier-grade NAT"}, + {"0.0.0.0", true, "unspecified"}, + {"8.8.8.8", false, "public DNS"}, + {"1.1.1.1", false, "public DNS"}, + {"::1", true, "IPv6 loopback"}, + {"::ffff:127.0.0.1", true, "IPv4-mapped IPv6 loopback"}, + {"::ffff:10.0.0.1", true, "IPv4-mapped IPv6 private"}, + {"fc00::1", true, "IPv6 unique local"}, + {"fd00::1", true, "IPv6 unique local"}, + {"2002:7f00:0001::1", true, "6to4 with embedded 127.x (private)"}, + {"2002:0a00:0001::1", true, "6to4 with embedded 10.0.0.1 (private)"}, + {"2002:0801:0101::1", false, "6to4 with embedded 8.1.1.1 (public)"}, + {"2001:0000:4136:e378:8000:63bf:f5ff:fffe", true, "Teredo with client 10.0.0.1 (private)"}, + {"2001:0000:4136:e378:8000:63bf:f7f6:fefe", false, "Teredo with client 8.9.1.1 (public)"}, + {"2607:f8b0:4004:800::200e", false, "public IPv6 (Google)"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP: %s", tt.ip) + } + got := isPrivateOrRestrictedIP(ip) + if got != tt.blocked { + t.Errorf("isPrivateOrRestrictedIP(%s) = %v, want %v", tt.ip, got, tt.blocked) + } + }) + } +} + // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) 
{ tool, err := NewWebFetchTool(50000, testFetchLimit) diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index e949d7a22..5b18612b1 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -166,11 +166,7 @@ func (t *GroqTranscriber) Name() string { // DetectTranscriber inspects cfg and returns the appropriate Transcriber, or // nil if no supported transcription provider is configured. func DetectTranscriber(cfg *config.Config) Transcriber { - // Direct Groq provider config takes priority. - if key := cfg.Providers.Groq.APIKey; key != "" { - return NewGroqTranscriber(key) - } - // Fall back to any model-list entry that uses the groq/ protocol. + // return any model-list entry that uses the groq/ protocol. for _, mc := range cfg.ModelList { if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { return NewGroqTranscriber(mc.APIKey) diff --git a/pkg/voice/transcriber_test.go b/pkg/voice/transcriber_test.go index 9b6add333..e7d10c40f 100644 --- a/pkg/voice/transcriber_test.go +++ b/pkg/voice/transcriber_test.go @@ -34,15 +34,6 @@ func TestDetectTranscriber(t *testing.T) { cfg: &config.Config{}, wantNil: true, }, - { - name: "groq provider key", - cfg: &config.Config{ - Providers: config.ProvidersConfig{ - Groq: config.ProviderConfig{APIKey: "sk-groq-direct"}, - }, - }, - wantName: "groq", - }, { name: "groq via model list", cfg: &config.Config{ @@ -65,9 +56,6 @@ func TestDetectTranscriber(t *testing.T) { { name: "provider key takes priority over model list", cfg: &config.Config{ - Providers: config.ProvidersConfig{ - Groq: config.ProviderConfig{APIKey: "sk-groq-direct"}, - }, ModelList: []config.ModelConfig{ {Model: "groq/llama-3.3-70b", APIKey: "sk-groq-model"}, }, diff --git a/web/backend/api/config.go b/web/backend/api/config.go index f160b42b6..091e3fbae 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "os" "github.com/sipeed/picoclaw/pkg/config" ) @@ -17,36 
+16,11 @@ func (h *Handler) registerConfigRoutes(mux *http.ServeMux) { mux.HandleFunc("PATCH /api/config", h.handlePatchConfig) } -// loadFilteredConfig loads the configuration and filters out default placeholder credentials -// (like API limits/keys) if the configuration file has not been created yet by the user. -func (h *Handler) loadFilteredConfig() (*config.Config, error) { - cfg, err := config.LoadConfig(h.configPath) - if err != nil { - return nil, err - } - - configExists := false - if h.configPath != "" { - if _, err := os.Stat(h.configPath); err == nil { - configExists = true - } - } - - if !configExists { - for i := range cfg.ModelList { - cfg.ModelList[i].APIKey = "" - cfg.ModelList[i].AuthMethod = "" - } - } - - return cfg, nil -} - // handleGetConfig returns the complete system configuration. // // GET /api/config func (h *Handler) handleGetConfig(w http.ResponseWriter, r *http.Request) { - cfg, err := h.loadFilteredConfig() + cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return @@ -74,6 +48,9 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) return } + if execAllowRemoteOmitted(body) { + cfg.Tools.Exec.AllowRemote = config.DefaultConfig().Tools.Exec.AllowRemote + } if errs := validateConfig(&cfg); len(errs) > 0 { w.Header().Set("Content-Type", "application/json") @@ -94,6 +71,20 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) } +func execAllowRemoteOmitted(body []byte) bool { + var raw struct { + Tools *struct { + Exec *struct { + AllowRemote *bool `json:"allow_remote"` + } `json:"exec"` + } `json:"tools"` + } + if err := json.Unmarshal(body, &raw); err != nil { + return false + } + return raw.Tools == nil || raw.Tools.Exec == nil || 
raw.Tools.Exec.AllowRemote == nil +} + // handlePatchConfig partially updates the system configuration using JSON Merge Patch (RFC 7396). // Only the fields present in the request body will be updated; all other fields remain unchanged. // diff --git a/web/backend/api/config_test.go b/web/backend/api/config_test.go new file mode 100644 index 000000000..29811e37e --- /dev/null +++ b/web/backend/api/config_test.go @@ -0,0 +1,88 @@ +package api + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleUpdateConfig_PreservesExecAllowRemoteDefaultWhenOmitted(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest(http.MethodPut, "/api/config", bytes.NewBufferString(`{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace" + } + }, + "model_list": [ + { + "model_name": "custom-default", + "model": "openai/gpt-4o", + "api_key": "sk-default" + } + ] + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if !cfg.Tools.Exec.AllowRemote { + t.Fatal("tools.exec.allow_remote should remain true when omitted from PUT /api/config") + } +} + +func TestHandleUpdateConfig_DoesNotInheritDefaultModelFields(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest(http.MethodPut, "/api/config", bytes.NewBufferString(`{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace" + } + }, + "model_list": [ + { + 
"model_name": "custom-default", + "model": "openai/gpt-4o", + "api_key": "sk-default" + } + ] + }`)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + if got := cfg.ModelList[0].APIBase; got != "" { + t.Fatalf("model_list[0].api_base = %q, want empty string", got) + } +} diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 8f86dd73d..41f702e32 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -10,7 +10,6 @@ import ( "net/http" "os" "os/exec" - "path/filepath" "runtime" "strconv" "strings" @@ -19,6 +18,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/web/backend/utils" ) // gateway holds the state for the managed gateway process. 
@@ -36,6 +36,7 @@ var gateway = struct { func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus) mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents) + mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs) mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart) mux.HandleFunc("POST /api/gateway/stop", h.handleGatewayStop) mux.HandleFunc("POST /api/gateway/restart", h.handleGatewayRestart) @@ -89,11 +90,12 @@ func (h *Handler) gatewayStartReady() (bool, string, error) { return false, fmt.Sprintf("default model %q is invalid", modelName), nil } - hasCredential := strings.TrimSpace(modelCfg.APIKey) != "" || - strings.TrimSpace(modelCfg.AuthMethod) != "" - if !hasCredential { + if !hasModelConfiguration(*modelCfg) { return false, fmt.Sprintf("default model %q has no credentials configured", modelName), nil } + if requiresRuntimeProbe(*modelCfg) && !probeLocalModelAvailability(*modelCfg) { + return false, fmt.Sprintf("default model %q is not reachable", modelName), nil + } return true, "", nil } @@ -131,14 +133,18 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool { func (h *Handler) startGatewayLocked() (int, error) { // Locate the picoclaw executable - execPath := findPicoclawBinary() + execPath := utils.FindPicoclawBinary() cmd := exec.Command(execPath, "gateway") + cmd.Env = os.Environ() // Forward the launcher's config path via the environment variable that // GetConfigPath() already reads, so the gateway sub-process uses the same // config file without requiring a --config flag on the gateway subcommand. 
if h.configPath != "" { - cmd.Env = append(os.Environ(), "PICOCLAW_CONFIG="+h.configPath) + cmd.Env = append(cmd.Env, "PICOCLAW_CONFIG="+h.configPath) + } + if host := h.gatewayHostOverride(); host != "" { + cmd.Env = append(cmd.Env, "PICOCLAW_GATEWAY_HOST="+host) } stdoutPipe, err := cmd.StdoutPipe() @@ -207,10 +213,7 @@ func (h *Handler) startGatewayLocked() (int, error) { if err != nil { continue } - healthHost := "127.0.0.1" - if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" { - healthHost = cfg.Gateway.Host - } + healthHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) healthPort := cfg.Gateway.Port if healthPort == 0 { healthPort = 18790 @@ -353,6 +356,20 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) { h.handleGatewayStart(w, r) } +// handleGatewayClearLogs clears the in-memory gateway log buffer. +// +// POST /api/gateway/logs/clear +func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request) { + gateway.logs.Clear() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "status": "cleared", + "log_total": 0, + "log_run_id": gateway.logs.RunID(), + }) +} + // handleGatewayStatus returns the gateway run status, health info, and logs. // // GET /api/gateway/status @@ -375,9 +392,7 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { host := "127.0.0.1" port := 18790 if err == nil && cfg != nil { - if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" { - host = cfg.Gateway.Host - } + host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) if cfg.Gateway.Port != 0 { port = cfg.Gateway.Port } @@ -535,36 +550,6 @@ func (h *Handler) currentGatewayStatus() string { return string(encoded) } -// findPicoclawBinary locates the picoclaw executable. -// Search order: -// 1. PICOCLAW_BINARY environment variable (explicit override) -// 2. Same directory as the current executable -// 3. 
Falls back to "picoclaw" and relies on $PATH -func findPicoclawBinary() string { - binaryName := "picoclaw" - if runtime.GOOS == "windows" { - binaryName = "picoclaw.exe" - } - - // 1. Explicit override via environment variable - if p := os.Getenv("PICOCLAW_BINARY"); p != "" { - if info, _ := os.Stat(p); info != nil && !info.IsDir() { - return p - } - } - - // 2. Same directory as the launcher executable - if exe, err := os.Executable(); err == nil { - candidate := filepath.Join(filepath.Dir(exe), binaryName) - if info, err := os.Stat(candidate); err == nil && !info.IsDir() { - return candidate - } - } - - // 3. Fall back to PATH lookup - return "picoclaw" -} - // scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF. func scanPipe(r io.Reader, buf *LogBuffer) { scanner := bufio.NewScanner(r) diff --git a/web/backend/api/gateway_host.go b/web/backend/api/gateway_host.go new file mode 100644 index 000000000..a499c1ea2 --- /dev/null +++ b/web/backend/api/gateway_host.go @@ -0,0 +1,66 @@ +package api + +import ( + "net" + "net/http" + "strconv" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func (h *Handler) effectiveLauncherPublic() bool { + if h.serverPublicExplicit { + return h.serverPublic + } + + cfg, err := h.loadLauncherConfig() + if err == nil { + return cfg.Public + } + + return h.serverPublic +} + +func (h *Handler) gatewayHostOverride() string { + if h.effectiveLauncherPublic() { + return "0.0.0.0" + } + return "" +} + +func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string { + if override := h.gatewayHostOverride(); override != "" { + return override + } + if cfg == nil { + return "" + } + return strings.TrimSpace(cfg.Gateway.Host) +} + +func gatewayProbeHost(bindHost string) string { + if bindHost == "" || bindHost == "0.0.0.0" { + return "127.0.0.1" + } + return bindHost +} + +func requestHostName(r *http.Request) string { + reqHost, _, err := net.SplitHostPort(r.Host) + if err == nil { + return 
reqHost + } + if strings.TrimSpace(r.Host) != "" { + return r.Host + } + return "127.0.0.1" +} + +func (h *Handler) buildWsURL(r *http.Request, cfg *config.Config) string { + host := h.effectiveGatewayBindHost(cfg) + if host == "" || host == "0.0.0.0" { + host = requestHostName(r) + } + return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws" +} diff --git a/web/backend/api/gateway_host_test.go b/web/backend/api/gateway_host_test.go new file mode 100644 index 000000000..afd600359 --- /dev/null +++ b/web/backend/api/gateway_host_test.go @@ -0,0 +1,59 @@ +package api + +import ( + "net/http/httptest" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/web/backend/launcherconfig" +) + +func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + launcherPath := launcherconfig.PathForAppConfig(configPath) + if err := launcherconfig.Save(launcherPath, launcherconfig.Config{ + Port: 18800, + Public: false, + }); err != nil { + t.Fatalf("launcherconfig.Save() error = %v", err) + } + + h := NewHandler(configPath) + h.SetServerOptions(18800, true, true, nil) + + if got := h.gatewayHostOverride(); got != "0.0.0.0" { + t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0") + } +} + +func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + launcherPath := launcherconfig.PathForAppConfig(configPath) + if err := launcherconfig.Save(launcherPath, launcherconfig.Config{ + Port: 18800, + Public: true, + }); err != nil { + t.Fatalf("launcherconfig.Save() error = %v", err) + } + + h := NewHandler(configPath) + h.SetServerOptions(18800, false, false, nil) + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = 18790 + + req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil) + req.Host = 
"192.168.1.9:18800" + + if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18790/pico/ws" { + t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18790/pico/ws") + } +} + +func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) { + if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" { + t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1") + } +} diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index 998c133b5..c7fb4dbc8 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -6,10 +6,13 @@ import ( "net/http/httptest" "os" "path/filepath" + "strconv" "strings" "testing" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/web/backend/utils" ) func TestGatewayStartReady_NoDefaultModel(t *testing.T) { @@ -31,8 +34,9 @@ func TestGatewayStartReady_NoDefaultModel(t *testing.T) { func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") cfg := config.DefaultConfig() - cfg.Agents.Defaults.Model = "missing-model" - if err := config.SaveConfig(configPath, cfg); err != nil { + cfg.Agents.Defaults.ModelName = "missing-model" + err := config.SaveConfig(configPath, cfg) + if err != nil { t.Fatalf("SaveConfig() error = %v", err) } @@ -54,7 +58,8 @@ func TestGatewayStartReady_ValidDefaultModel(t *testing.T) { cfg := config.DefaultConfig() cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName cfg.ModelList[0].APIKey = "test-key" - if err := config.SaveConfig(configPath, cfg); err != nil { + err := config.SaveConfig(configPath, cfg) + if err != nil { t.Fatalf("SaveConfig() error = %v", err) } @@ -74,7 +79,8 @@ func TestGatewayStartReady_DefaultModelWithoutCredential(t *testing.T) { cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName cfg.ModelList[0].APIKey = "" cfg.ModelList[0].AuthMethod = "" - if err := config.SaveConfig(configPath, cfg); err != nil 
{ + err := config.SaveConfig(configPath, cfg) + if err != nil { t.Fatalf("SaveConfig() error = %v", err) } @@ -91,6 +97,195 @@ func TestGatewayStartReady_DefaultModelWithoutCredential(t *testing.T) { } } +func TestGatewayStartReady_LocalModelWithoutAPIKey(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + resetModelProbeHooks(t) + + probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool { + return false + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.ModelList = []config.ModelConfig{{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://localhost:8000/v1", + }} + cfg.Agents.Defaults.ModelName = "local-vllm" + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + ready, reason, err := h.gatewayStartReady() + if err != nil { + t.Fatalf("gatewayStartReady() error = %v", err) + } + if ready { + t.Fatalf("gatewayStartReady() ready = true, want false without a running local service") + } + if !strings.Contains(reason, "not reachable") { + t.Fatalf("gatewayStartReady() reason = %q, want contains %q", reason, "not reachable") + } +} + +func TestGatewayStartReady_LocalModelWithRunningService(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + resetModelProbeHooks(t) + + probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool { + return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model" + } + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.ModelList = []config.ModelConfig{{ + ModelName: "local-vllm", + Model: "vllm/custom-model", + APIBase: "http://127.0.0.1:8000/v1", + }} + cfg.Agents.Defaults.ModelName = "local-vllm" + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h 
:= NewHandler(configPath)
+	ready, reason, err := h.gatewayStartReady()
+	if err != nil {
+		t.Fatalf("gatewayStartReady() error = %v", err)
+	}
+	if !ready {
+		t.Fatalf("gatewayStartReady() ready = false, want true with a running local service (reason=%q)", reason)
+	}
+}
+
+func TestGatewayStartReady_RemoteVLLMWithAPIKeyDoesNotProbe(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetModelProbeHooks(t)
+
+	probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
+		t.Fatalf("unexpected OpenAI-compatible probe for %q (%q)", apiBase, modelID)
+		return false
+	}
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{{
+		ModelName: "remote-vllm",
+		Model: "vllm/custom-model",
+		APIBase: "https://models.example.com/v1",
+		APIKey: "remote-key",
+	}}
+	cfg.Agents.Defaults.ModelName = "remote-vllm"
+	err = config.SaveConfig(configPath, cfg)
+	if err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	ready, reason, err := h.gatewayStartReady()
+	if err != nil {
+		t.Fatalf("gatewayStartReady() error = %v", err)
+	}
+	if !ready {
+		t.Fatalf("gatewayStartReady() ready = false, want true for remote vllm with api key (reason=%q)", reason)
+	}
+}
+
+func TestGatewayStartReady_LocalOllamaUsesDefaultProbeBase(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetModelProbeHooks(t)
+
+	probeOllamaModelFunc = func(apiBase, modelID string) bool {
+		return apiBase == "http://localhost:11434/v1" && modelID == "llama3"
+	}
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{{
+		ModelName: "local-ollama",
+		Model: "ollama/llama3",
+	}}
+	cfg.Agents.Defaults.ModelName = "local-ollama"
+	err = config.SaveConfig(configPath, cfg)
+	if err != nil {
+		t.Fatalf("SaveConfig() error = %v",
err)
+	}
+
+	h := NewHandler(configPath)
+	ready, reason, err := h.gatewayStartReady()
+	if err != nil {
+		t.Fatalf("gatewayStartReady() error = %v", err)
+	}
+	if !ready {
+		t.Fatalf("gatewayStartReady() ready = false, want true with default Ollama probe base (reason=%q)", reason)
+	}
+}
+
+func TestGatewayStartReady_OAuthModelRequiresStoredCredential(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{{
+		ModelName: "openai-oauth",
+		Model: "openai/gpt-5.2",
+		AuthMethod: "oauth",
+	}}
+	cfg.Agents.Defaults.ModelName = "openai-oauth"
+	err = config.SaveConfig(configPath, cfg)
+	if err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	ready, reason, err := h.gatewayStartReady()
+	if err != nil {
+		t.Fatalf("gatewayStartReady() error = %v", err)
+	}
+	if ready {
+		t.Fatalf("gatewayStartReady() ready = true, want false without stored credential")
+	}
+	if !strings.Contains(reason, "no credentials configured") {
+		t.Fatalf("gatewayStartReady() reason = %q, want contains %q", reason, "no credentials configured")
+	}
+
+	err = auth.SetCredential(oauthProviderOpenAI, &auth.AuthCredential{
+		AccessToken: "openai-token",
+		Provider: oauthProviderOpenAI,
+		AuthMethod: "oauth",
+	})
+	if err != nil {
+		t.Fatalf("SetCredential() error = %v", err)
+	}
+
+	ready, reason, err = h.gatewayStartReady()
+	if err != nil {
+		t.Fatalf("gatewayStartReady() error = %v", err)
+	}
+	if !ready {
+		t.Fatalf("gatewayStartReady() ready = false, want true with stored credential (reason=%q)", reason)
+	}
+}
+
 func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
 	configPath := filepath.Join(t.TempDir(), "config.json")
 	h := NewHandler(configPath)
@@ -122,6 +317,71 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
 	}
 }
+func
TestGatewayClearLogsResetsBufferedHistory(t *testing.T) {
+	configPath := filepath.Join(t.TempDir(), "config.json")
+	h := NewHandler(configPath)
+	mux := http.NewServeMux()
+	h.RegisterRoutes(mux)
+
+	gateway.logs.Clear()
+	gateway.logs.Append("first line")
+	gateway.logs.Append("second line")
+	previousRunID := gateway.logs.RunID()
+
+	clearRec := httptest.NewRecorder()
+	clearReq := httptest.NewRequest(http.MethodPost, "/api/gateway/logs/clear", nil)
+	mux.ServeHTTP(clearRec, clearReq)
+
+	if clearRec.Code != http.StatusOK {
+		t.Fatalf("clear status = %d, want %d", clearRec.Code, http.StatusOK)
+	}
+
+	var clearBody map[string]any
+	if err := json.Unmarshal(clearRec.Body.Bytes(), &clearBody); err != nil {
+		t.Fatalf("unmarshal clear response: %v", err)
+	}
+
+	if got := clearBody["status"]; got != "cleared" {
+		t.Fatalf("clear status body = %#v, want %q", got, "cleared")
+	}
+
+	clearRunID, ok := clearBody["log_run_id"].(float64)
+	if !ok {
+		t.Fatalf("log_run_id missing or not number: %#v", clearBody["log_run_id"])
+	}
+	if int(clearRunID) <= previousRunID {
+		t.Fatalf("log_run_id = %d, want > %d", int(clearRunID), previousRunID)
+	}
+
+	statusRec := httptest.NewRecorder()
+	statusReq := httptest.NewRequest(
+		http.MethodGet,
+		"/api/gateway/status?log_offset=0&log_run_id="+strconv.Itoa(previousRunID),
+		nil,
+	)
+	mux.ServeHTTP(statusRec, statusReq)
+
+	if statusRec.Code != http.StatusOK {
+		t.Fatalf("status code = %d, want %d", statusRec.Code, http.StatusOK)
+	}
+
+	var statusBody map[string]any
+	if err := json.Unmarshal(statusRec.Body.Bytes(), &statusBody); err != nil {
+		t.Fatalf("unmarshal status response: %v", err)
+	}
+
+	logs, ok := statusBody["logs"].([]any)
+	if !ok {
+		t.Fatalf("logs missing or not array: %#v", statusBody["logs"])
+	}
+	if len(logs) != 0 {
+		t.Fatalf("logs len = %d, want 0", len(logs))
+	}
+	if got := statusBody["log_total"]; got != float64(0) {
+		t.Fatalf("log_total = %#v, want 0", got)
+	}
+}
+
 func
TestFindPicoclawBinary_EnvOverride(t *testing.T) {
 	// Create a temporary file to act as the mock binary
 	tmpDir := t.TempDir()
@@ -132,9 +392,9 @@ func TestFindPicoclawBinary_EnvOverride(t *testing.T) {
 	t.Setenv("PICOCLAW_BINARY", mockBinary)
-	got := findPicoclawBinary()
+	got := utils.FindPicoclawBinary()
 	if got != mockBinary {
-		t.Errorf("findPicoclawBinary() = %q, want %q", got, mockBinary)
+		t.Errorf("FindPicoclawBinary() = %q, want %q", got, mockBinary)
 	}
 }
@@ -142,9 +402,9 @@ func TestFindPicoclawBinary_EnvOverride_InvalidPath(t *testing.T) {
 	// When PICOCLAW_BINARY points to a non-existent path, fall through to next strategy
 	t.Setenv("PICOCLAW_BINARY", "/nonexistent/picoclaw-binary")
-	got := findPicoclawBinary()
+	got := utils.FindPicoclawBinary()
 	// Should not return the invalid path; falls back to "picoclaw" or another found path
 	if got == "/nonexistent/picoclaw-binary" {
-		t.Errorf("findPicoclawBinary() returned invalid env path %q, expected fallback", got)
+		t.Errorf("FindPicoclawBinary() returned invalid env path %q, expected fallback", got)
 	}
 }
diff --git a/web/backend/api/launcher_config_test.go b/web/backend/api/launcher_config_test.go
index 5049dd88f..0d6af823c 100644
--- a/web/backend/api/launcher_config_test.go
+++ b/web/backend/api/launcher_config_test.go
@@ -14,7 +14,7 @@ import (
 func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
 	configPath := filepath.Join(t.TempDir(), "config.json")
 	h := NewHandler(configPath)
-	h.SetServerOptions(19999, true, []string{"192.168.1.0/24"})
+	h.SetServerOptions(19999, true, false, []string{"192.168.1.0/24"})
 	mux := http.NewServeMux()
 	h.RegisterRoutes(mux)
diff --git a/web/backend/api/log.go b/web/backend/api/log.go
index ecf7d422f..f83f6f34c 100644
--- a/web/backend/api/log.go
+++ b/web/backend/api/log.go
@@ -4,7 +4,7 @@ import "sync"
 // LogBuffer is a thread-safe ring buffer that stores the most recent N log lines.
 // It supports incremental reads via LinesSince and tracks a runID that increments
-// on each Reset (used to detect gateway restarts).
+// whenever the buffer is reset or cleared so clients can detect log history resets.
 type LogBuffer struct {
 	mu sync.RWMutex
 	lines []string
@@ -45,6 +45,12 @@ func (b *LogBuffer) Reset() {
 	b.runID++
 }
+// Clear removes all buffered lines and increments the runID so clients treat
+// subsequent reads as a new log stream.
+func (b *LogBuffer) Clear() {
+	b.Reset()
+}
+
 // LinesSince returns lines appended after the given offset, the current total count, and the runID.
 // If offset >= total, no lines are returned. If offset is too old (evicted), all buffered lines are returned.
 func (b *LogBuffer) LinesSince(offset int) (lines []string, total int, runID int) {
diff --git a/web/backend/api/model_status.go b/web/backend/api/model_status.go
new file mode 100644
index 000000000..22bf5c15b
--- /dev/null
+++ b/web/backend/api/model_status.go
@@ -0,0 +1,324 @@
+package api
+
+import (
+	"encoding/json"
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"github.com/sipeed/picoclaw/pkg/config"
+)
+
+const modelProbeTimeout = 800 * time.Millisecond
+
+var (
+	probeTCPServiceFunc = probeTCPService
+	probeOllamaModelFunc = probeOllamaModel
+	probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
+)
+
+func hasModelConfiguration(m config.ModelConfig) bool {
+	authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
+	apiKey := strings.TrimSpace(m.APIKey)
+
+	if authMethod == "oauth" || authMethod == "token" {
+		if provider, ok := oauthProviderForModel(m.Model); ok {
+			cred, err := oauthGetCredential(provider)
+			if err != nil || cred == nil {
+				return false
+			}
+			return strings.TrimSpace(cred.AccessToken) != "" || strings.TrimSpace(cred.RefreshToken) != ""
+		}
+		return true
+	}
+
+	if requiresRuntimeProbe(m) {
+		return true
+	}
+
+	return apiKey != ""
+}
+
+// isModelConfigured reports whether a model is currently
available to use.
+// Local models must be reachable; remote/API-key models only need saved config.
+func isModelConfigured(m config.ModelConfig) bool {
+	if !hasModelConfiguration(m) {
+		return false
+	}
+	if requiresRuntimeProbe(m) {
+		return probeLocalModelAvailability(m)
+	}
+	return true
+}
+
+func requiresRuntimeProbe(m config.ModelConfig) bool {
+	authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
+	if authMethod == "local" {
+		return true
+	}
+
+	switch modelProtocol(m.Model) {
+	case "claude-cli", "claudecli", "codex-cli", "codexcli", "github-copilot", "copilot":
+		return true
+	case "ollama", "vllm":
+		apiBase := strings.TrimSpace(m.APIBase)
+		return apiBase == "" || hasLocalAPIBase(apiBase)
+	}
+
+	if hasLocalAPIBase(m.APIBase) {
+		return true
+	}
+
+	return false
+}
+
+func probeLocalModelAvailability(m config.ModelConfig) bool {
+	apiBase := modelProbeAPIBase(m)
+	protocol, modelID := splitModel(m.Model)
+	switch protocol {
+	case "ollama":
+		return probeOllamaModelFunc(apiBase, modelID)
+	case "vllm":
+		return probeOpenAICompatibleModelFunc(apiBase, modelID)
+	case "github-copilot", "copilot":
+		return probeTCPServiceFunc(apiBase)
+	case "claude-cli", "claudecli", "codex-cli", "codexcli":
+		return true
+	default:
+		if hasLocalAPIBase(apiBase) {
+			return probeOpenAICompatibleModelFunc(apiBase, modelID)
+		}
+		return false
+	}
+}
+
+func modelProbeAPIBase(m config.ModelConfig) string {
+	if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
+		return normalizeModelProbeAPIBase(apiBase)
+	}
+
+	switch modelProtocol(m.Model) {
+	case "ollama":
+		return "http://localhost:11434/v1"
+	case "vllm":
+		return "http://localhost:8000/v1"
+	case "github-copilot", "copilot":
+		return "localhost:4321"
+	default:
+		return ""
+	}
+}
+
+func normalizeModelProbeAPIBase(raw string) string {
+	u, err := parseAPIBase(raw)
+	if err != nil {
+		return strings.TrimSpace(raw)
+	}
+
+	switch strings.ToLower(u.Hostname()) {
+	case "0.0.0.0":
+		u.Host =
net.JoinHostPort("127.0.0.1", u.Port())
+	case "::":
+		u.Host = net.JoinHostPort("::1", u.Port())
+	default:
+		return strings.TrimSpace(raw)
+	}
+
+	if u.Port() == "" {
+		u.Host = u.Hostname()
+	}
+
+	return u.String()
+}
+
+func oauthProviderForModel(model string) (string, bool) {
+	switch modelProtocol(model) {
+	case "openai":
+		return oauthProviderOpenAI, true
+	case "anthropic":
+		return oauthProviderAnthropic, true
+	case "antigravity", "google-antigravity":
+		return oauthProviderGoogleAntigravity, true
+	default:
+		return "", false
+	}
+}
+
+func modelProtocol(model string) string {
+	protocol, _ := splitModel(model)
+	return protocol
+}
+
+func splitModel(model string) (protocol, modelID string) {
+	model = strings.ToLower(strings.TrimSpace(model))
+	protocol, _, found := strings.Cut(model, "/")
+	if !found {
+		return "openai", model
+	}
+	return protocol, strings.TrimSpace(model[strings.Index(model, "/")+1:])
+}
+
+func hasLocalAPIBase(raw string) bool {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return false
+	}
+
+	u, err := url.Parse(raw)
+	if err != nil || u.Hostname() == "" {
+		u, err = url.Parse("//" + raw)
+		if err != nil {
+			return false
+		}
+	}
+
+	switch strings.ToLower(u.Hostname()) {
+	case "localhost", "127.0.0.1", "::1", "0.0.0.0":
+		return true
+	default:
+		return false
+	}
+}
+
+func probeTCPService(raw string) bool {
+	hostPort, err := hostPortFromAPIBase(raw)
+	if err != nil {
+		return false
+	}
+
+	conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
+	if err != nil {
+		return false
+	}
+	_ = conn.Close()
+	return true
+}
+
+func probeOllamaModel(apiBase, modelID string) bool {
+	root, err := apiRootFromAPIBase(apiBase)
+	if err != nil {
+		return false
+	}
+
+	var resp struct {
+		Models []struct {
+			Name string `json:"name"`
+			Model string `json:"model"`
+		} `json:"models"`
+	}
+	if err := getJSON(root+"/api/tags", &resp); err != nil {
+		return false
+	}
+
+	for _, model := range resp.Models {
+		if
ollamaModelMatches(model.Name, modelID) || ollamaModelMatches(model.Model, modelID) {
+			return true
+		}
+	}
+	return false
+}
+
+func probeOpenAICompatibleModel(apiBase, modelID string) bool {
+	if strings.TrimSpace(apiBase) == "" {
+		return false
+	}
+
+	var resp struct {
+		Data []struct {
+			ID string `json:"id"`
+		} `json:"data"`
+	}
+	if err := getJSON(strings.TrimRight(strings.TrimSpace(apiBase), "/")+"/models", &resp); err != nil {
+		return false
+	}
+
+	for _, model := range resp.Data {
+		if strings.EqualFold(strings.TrimSpace(model.ID), modelID) {
+			return true
+		}
+	}
+	return false
+}
+
+func getJSON(rawURL string, out any) error {
+	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+	if err != nil {
+		return err
+	}
+
+	client := &http.Client{Timeout: modelProbeTimeout}
+	resp, err := client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("unexpected status %d", resp.StatusCode)
+	}
+
+	return json.NewDecoder(resp.Body).Decode(out)
+}
+
+func apiRootFromAPIBase(raw string) (string, error) {
+	u, err := parseAPIBase(raw)
+	if err != nil {
+		return "", err
+	}
+	return (&url.URL{Scheme: u.Scheme, Host: u.Host}).String(), nil
+}
+
+func hostPortFromAPIBase(raw string) (string, error) {
+	u, err := parseAPIBase(raw)
+	if err != nil {
+		return "", err
+	}
+
+	if port := u.Port(); port != "" {
+		return u.Host, nil
+	}
+	switch strings.ToLower(u.Scheme) {
+	case "https":
+		return net.JoinHostPort(u.Hostname(), "443"), nil
+	default:
+		return net.JoinHostPort(u.Hostname(), "80"), nil
+	}
+}
+
+func parseAPIBase(raw string) (*url.URL, error) {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return nil, fmt.Errorf("empty api base")
+	}
+
+	u, err := url.Parse(raw)
+	if err == nil && u.Hostname() != "" {
+		return u, nil
+	}
+
+	u, err = url.Parse("//" + raw)
+	if err != nil || u.Hostname() == "" {
+		return nil, fmt.Errorf("invalid api base %q", raw)
+	}
+	if u.Scheme == "" {
+		u.Scheme = "http"
+	}
+	return u, nil
+}
+
+func ollamaModelMatches(candidate, want string) bool {
+	candidate = strings.TrimSpace(candidate)
+	want = strings.TrimSpace(want)
+	if candidate == "" || want == "" {
+		return false
+	}
+	if strings.EqualFold(candidate, want) {
+		return true
+	}
+
+	base, _, _ := strings.Cut(candidate, ":")
+	return strings.EqualFold(base, want)
+}
diff --git a/web/backend/api/models.go b/web/backend/api/models.go
index cb57d6f2e..2e3f3dd55 100644
--- a/web/backend/api/models.go
+++ b/web/backend/api/models.go
@@ -6,6 +6,7 @@ import (
 	"io"
 	"net/http"
 	"strconv"
+	"sync"
 	"github.com/sipeed/picoclaw/pkg/config"
 )
@@ -45,13 +46,24 @@ type modelResponse struct {
 //
 // GET /api/models
 func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
-	cfg, err := h.loadFilteredConfig()
+	cfg, err := config.LoadConfig(h.configPath)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
 		return
 	}
 	defaultModel := cfg.Agents.Defaults.GetModelName()
+	configured := make([]bool, len(cfg.ModelList))
+
+	var wg sync.WaitGroup
+	wg.Add(len(cfg.ModelList))
+	for i, m := range cfg.ModelList {
+		go func(i int, m config.ModelConfig) {
+			defer wg.Done()
+			configured[i] = isModelConfigured(m)
+		}(i, m)
+	}
+	wg.Wait()
 	models := make([]modelResponse, 0, len(cfg.ModelList))
 	for i, m := range cfg.ModelList {
@@ -69,7 +81,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
 			MaxTokensField: m.MaxTokensField,
 			RequestTimeout: m.RequestTimeout,
 			ThinkingLevel: m.ThinkingLevel,
-			Configured: m.APIKey != "" || m.AuthMethod != "",
+			Configured: configured[i],
 			IsDefault: m.ModelName == defaultModel,
 		})
 	}
@@ -212,9 +224,6 @@ func (h *Handler) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
 	if cfg.Agents.Defaults.ModelName == deletedModelName {
 		cfg.Agents.Defaults.ModelName = ""
 	}
-	if cfg.Agents.Defaults.Model == deletedModelName {
-		cfg.Agents.Defaults.Model = ""
-
}
 	if err := config.SaveConfig(h.configPath, cfg); err != nil {
 		http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
diff --git a/web/backend/api/models_test.go b/web/backend/api/models_test.go
new file mode 100644
index 000000000..7061eb3f7
--- /dev/null
+++ b/web/backend/api/models_test.go
@@ -0,0 +1,313 @@
+package api
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/sipeed/picoclaw/pkg/auth"
+	"github.com/sipeed/picoclaw/pkg/config"
+)
+
+func resetModelProbeHooks(t *testing.T) {
+	t.Helper()
+
+	origTCPProbe := probeTCPServiceFunc
+	origOllamaProbe := probeOllamaModelFunc
+	origOpenAIProbe := probeOpenAICompatibleModelFunc
+	t.Cleanup(func() {
+		probeTCPServiceFunc = origTCPProbe
+		probeOllamaModelFunc = origOllamaProbe
+		probeOpenAICompatibleModelFunc = origOpenAIProbe
+	})
+}
+
+func TestHandleListModels_ConfiguredStatusUsesRuntimeProbesForLocalModels(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetOAuthHooks(t)
+	resetModelProbeHooks(t)
+
+	var mu sync.Mutex
+	var openAIProbes []string
+	var ollamaProbes []string
+	var tcpProbes []string
+
+	probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
+		mu.Lock()
+		openAIProbes = append(openAIProbes, apiBase+"|"+modelID)
+		mu.Unlock()
+		return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model"
+	}
+	probeOllamaModelFunc = func(apiBase, modelID string) bool {
+		mu.Lock()
+		ollamaProbes = append(ollamaProbes, apiBase+"|"+modelID)
+		mu.Unlock()
+		return apiBase == "http://localhost:11434/v1" && modelID == "llama3"
+	}
+	probeTCPServiceFunc = func(apiBase string) bool {
+		mu.Lock()
+		tcpProbes = append(tcpProbes, apiBase)
+		mu.Unlock()
+		return apiBase == "http://127.0.0.1:4321"
+	}
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{
+		{
+			ModelName:
"openai-oauth",
+			Model: "openai/gpt-5.2",
+			AuthMethod: "oauth",
+		},
+		{
+			ModelName: "vllm-local",
+			Model: "vllm/custom-model",
+			APIBase: "http://127.0.0.1:8000/v1",
+		},
+		{
+			ModelName: "ollama-default",
+			Model: "ollama/llama3",
+		},
+		{
+			ModelName: "vllm-remote",
+			Model: "vllm/custom-model",
+			APIBase: "https://models.example.com/v1",
+			APIKey: "remote-key",
+		},
+		{
+			ModelName: "copilot-gpt-5.2",
+			Model: "github-copilot/gpt-5.2",
+			APIBase: "http://127.0.0.1:4321",
+			AuthMethod: "oauth",
+		},
+	}
+	cfg.Agents.Defaults.ModelName = "openai-oauth"
+	if err := config.SaveConfig(configPath, cfg); err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	mux := http.NewServeMux()
+	h.RegisterRoutes(mux)
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
+	mux.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+	}
+
+	var resp struct {
+		Models []modelResponse `json:"models"`
+	}
+	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("Unmarshal() error = %v", err)
+	}
+
+	got := make(map[string]bool, len(resp.Models))
+	for _, model := range resp.Models {
+		got[model.ModelName] = model.Configured
+	}
+
+	if got["openai-oauth"] {
+		t.Fatalf("openai oauth model configured = true, want false without stored credential")
+	}
+	if !got["vllm-local"] {
+		t.Fatalf("vllm local model configured = false, want true when local probe succeeds")
+	}
+	if !got["ollama-default"] {
+		t.Fatalf("ollama default model configured = false, want true when default local probe succeeds")
+	}
+	if !got["vllm-remote"] {
+		t.Fatalf("remote vllm model configured = false, want true with api_key")
+	}
+	if !got["copilot-gpt-5.2"] {
+		t.Fatalf("copilot model configured = false, want true when local bridge probe succeeds")
+	}
+	if len(openAIProbes) != 1 || openAIProbes[0] !=
"http://127.0.0.1:8000/v1|custom-model" {
+		t.Fatalf("openAI probes = %#v, want only local vllm probe", openAIProbes)
+	}
+	if len(ollamaProbes) != 1 || ollamaProbes[0] != "http://localhost:11434/v1|llama3" {
+		t.Fatalf("ollama probes = %#v, want default local probe", ollamaProbes)
+	}
+	if len(tcpProbes) != 1 || tcpProbes[0] != "http://127.0.0.1:4321" {
+		t.Fatalf("tcp probes = %#v, want only local copilot probe", tcpProbes)
+	}
+}
+
+func TestHandleListModels_ConfiguredStatusForOAuthModelWithCredential(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetOAuthHooks(t)
+	resetModelProbeHooks(t)
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{{
+		ModelName: "claude-oauth",
+		Model: "anthropic/claude-sonnet-4.6",
+		AuthMethod: "oauth",
+	}}
+	cfg.Agents.Defaults.ModelName = "claude-oauth"
+	if err := config.SaveConfig(configPath, cfg); err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	if err := auth.SetCredential(oauthProviderAnthropic, &auth.AuthCredential{
+		AccessToken: "anthropic-token",
+		Provider: oauthProviderAnthropic,
+		AuthMethod: "oauth",
+	}); err != nil {
+		t.Fatalf("SetCredential() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	mux := http.NewServeMux()
+	h.RegisterRoutes(mux)
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
+	mux.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+	}
+
+	var resp struct {
+		Models []modelResponse `json:"models"`
+	}
+	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("Unmarshal() error = %v", err)
+	}
+	if len(resp.Models) != 1 {
+		t.Fatalf("len(models) = %d, want 1", len(resp.Models))
+	}
+	if !resp.Models[0].Configured {
+		t.Fatalf("oauth model configured = false, want true with stored
credential")
+	}
+}
+
+func TestHandleListModels_ProbesLocalModelsConcurrently(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetOAuthHooks(t)
+	resetModelProbeHooks(t)
+
+	started := make(chan string, 2)
+	release := make(chan struct{})
+
+	probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
+		started <- apiBase + "|" + modelID
+		<-release
+		return true
+	}
+
+	cfg, err := config.LoadConfig(configPath)
+	if err != nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{
+		{
+			ModelName: "local-vllm-a",
+			Model: "vllm/custom-a",
+			APIBase: "http://127.0.0.1:8000/v1",
+		},
+		{
+			ModelName: "local-vllm-b",
+			Model: "vllm/custom-b",
+			APIBase: "http://127.0.0.1:8001/v1",
+		},
+	}
+	if err := config.SaveConfig(configPath, cfg); err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	mux := http.NewServeMux()
+	h.RegisterRoutes(mux)
+
+	recCh := make(chan *httptest.ResponseRecorder, 1)
+	go func() {
+		rec := httptest.NewRecorder()
+		req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
+		mux.ServeHTTP(rec, req)
+		recCh <- rec
+	}()
+
+	for i := 0; i < 2; i++ {
+		select {
+		case <-started:
+		case <-time.After(200 * time.Millisecond):
+			t.Fatal("expected both local probes to start before the first one completed")
+		}
+	}
+	close(release)
+
+	rec := <-recCh
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+	}
+}
+
+func TestHandleListModels_NormalizesWildcardLocalAPIBaseForProbe(t *testing.T) {
+	configPath, cleanup := setupOAuthTestEnv(t)
+	defer cleanup()
+	resetOAuthHooks(t)
+	resetModelProbeHooks(t)
+
+	var gotProbe string
+	probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
+		gotProbe = apiBase + "|" + modelID
+		return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model"
+	}
+
+	cfg, err := config.LoadConfig(configPath)
+	if err
!= nil {
+		t.Fatalf("LoadConfig() error = %v", err)
+	}
+	cfg.ModelList = []config.ModelConfig{{
+		ModelName: "vllm-local",
+		Model: "vllm/custom-model",
+		APIBase: "http://0.0.0.0:8000/v1",
+	}}
+	if err := config.SaveConfig(configPath, cfg); err != nil {
+		t.Fatalf("SaveConfig() error = %v", err)
+	}
+
+	h := NewHandler(configPath)
+	mux := http.NewServeMux()
+	h.RegisterRoutes(mux)
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
+	mux.ServeHTTP(rec, req)
+
+	if rec.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
+	}
+
+	var resp struct {
+		Models []modelResponse `json:"models"`
+	}
+	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
+		t.Fatalf("Unmarshal() error = %v", err)
+	}
+	if len(resp.Models) != 1 {
+		t.Fatalf("len(models) = %d, want 1", len(resp.Models))
+	}
+	if !resp.Models[0].Configured {
+		t.Fatal("wildcard-bound local model configured = false, want true after probe host normalization")
+	}
+	if gotProbe != "http://127.0.0.1:8000/v1|custom-model" {
+		t.Fatalf("probe api base = %q, want %q", gotProbe, "http://127.0.0.1:8000/v1|custom-model")
+	}
+}
diff --git a/web/backend/api/oauth.go b/web/backend/api/oauth.go
index 04cd595f2..e264c2900 100644
--- a/web/backend/api/oauth.go
+++ b/web/backend/api/oauth.go
@@ -744,17 +744,6 @@ func (h *Handler) syncProviderAuthMethod(provider, authMethod string) error {
 		return err
 	}
-	switch provider {
-	case oauthProviderOpenAI:
-		cfg.Providers.OpenAI.AuthMethod = authMethod
-	case oauthProviderAnthropic:
-		cfg.Providers.Anthropic.AuthMethod = authMethod
-	case oauthProviderGoogleAntigravity:
-		cfg.Providers.Antigravity.AuthMethod = authMethod
-	default:
-		return fmt.Errorf("unsupported provider %q", provider)
-	}
-
 	found := false
 	for i := range cfg.ModelList {
 		if modelBelongsToProvider(provider, cfg.ModelList[i].Model) {
diff --git a/web/backend/api/oauth_test.go
b/web/backend/api/oauth_test.go
index 2103e1efc..78249be40 100644
--- a/web/backend/api/oauth_test.go
+++ b/web/backend/api/oauth_test.go
@@ -166,7 +166,6 @@ func TestOAuthLogoutClearsCredentialAndConfig(t *testing.T) {
 	if err != nil {
 		t.Fatalf("LoadConfig error: %v", err)
 	}
-	cfg.Providers.OpenAI.AuthMethod = "oauth"
 	cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
 		ModelName: "gpt-5.2",
 		Model: "openai/gpt-5.2",
@@ -208,9 +207,6 @@ func TestOAuthLogoutClearsCredentialAndConfig(t *testing.T) {
 	if err != nil {
 		t.Fatalf("LoadConfig error: %v", err)
 	}
-	if updated.Providers.OpenAI.AuthMethod != "" {
-		t.Fatalf("providers.openai.auth_method = %q, want empty", updated.Providers.OpenAI.AuthMethod)
-	}
 	for _, m := range updated.ModelList {
 		if strings.HasPrefix(m.Model, "openai/") && m.AuthMethod != "" {
 			t.Fatalf("openai model auth_method = %q, want empty", m.AuthMethod)
diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go
index fc942d51c..a4590dcde 100644
--- a/web/backend/api/pico.go
+++ b/web/backend/api/pico.go
@@ -5,9 +5,7 @@ import (
 	"encoding/hex"
 	"encoding/json"
 	"fmt"
-	"net"
 	"net/http"
-	"strconv"
 	"time"
 	"github.com/sipeed/picoclaw/pkg/config"
@@ -30,7 +28,7 @@ func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
 		return
 	}
-	wsURL := buildWsURL(r, cfg)
+	wsURL := h.buildWsURL(r, cfg)
 	w.Header().Set("Content-Type", "application/json")
 	json.NewEncoder(w).Encode(map[string]any{
@@ -58,7 +56,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
 		return
 	}
-	wsURL := fmt.Sprintf("ws://%s/pico/ws", net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port)))
+	wsURL := h.buildWsURL(r, cfg)
 	w.Header().Set("Content-Type", "application/json")
 	json.NewEncoder(w).Encode(map[string]any{
@@ -123,7 +121,7 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
 		return
 	}
-	wsURL := buildWsURL(r, cfg)
+	wsURL := h.buildWsURL(r, cfg)
 	w.Header().Set("Content-Type",
"application/json")
 	json.NewEncoder(w).Encode(map[string]any{
@@ -134,22 +132,6 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
 	})
 }
-// buildWsURL creates a WebSocket URL for the Pico Channel.
-// When the gateway host is "0.0.0.0" or empty, it uses the hostname from the
-// incoming HTTP request so the browser gets a connectable address.
-func buildWsURL(r *http.Request, cfg *config.Config) string {
-	host := cfg.Gateway.Host
-	if host == "" || host == "0.0.0.0" {
-		// Use the hostname the browser used to reach this backend
-		reqHost, _, err := net.SplitHostPort(r.Host)
-		if err != nil {
-			reqHost = r.Host // r.Host might not have a port
-		}
-		host = reqHost
-	}
-	return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
-}
-
 // generateSecureToken creates a random 32-character hex string.
 func generateSecureToken() string {
 	b := make([]byte, 16)
diff --git a/web/backend/api/router.go b/web/backend/api/router.go
index c250724d1..5f081dee9 100644
--- a/web/backend/api/router.go
+++ b/web/backend/api/router.go
@@ -9,13 +9,14 @@ import (
 // Handler serves HTTP API requests.
 type Handler struct {
-	configPath string
-	serverPort int
-	serverPublic bool
-	serverCIDRs []string
-	oauthMu sync.Mutex
-	oauthFlows map[string]*oauthFlow
-	oauthState map[string]string
+	configPath string
+	serverPort int
+	serverPublic bool
+	serverPublicExplicit bool
+	serverCIDRs []string
+	oauthMu sync.Mutex
+	oauthFlows map[string]*oauthFlow
+	oauthState map[string]string
 }
 // NewHandler creates an instance of the API handler.
@@ -29,9 +30,10 @@ func NewHandler(configPath string) *Handler {
 }
 // SetServerOptions stores current backend listen options for fallback behavior.
-func (h *Handler) SetServerOptions(port int, public bool, allowedCIDRs []string) {
+func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, allowedCIDRs []string) {
 	h.serverPort = port
 	h.serverPublic = public
+	h.serverPublicExplicit = publicExplicit
 	h.serverCIDRs = append([]string(nil), allowedCIDRs...)
 }
@@ -58,6 +60,10 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
 	// Channel catalog (for frontend navigation/config pages)
 	h.registerChannelRoutes(mux)
+	// Skills and tools support/actions
+	h.registerSkillRoutes(mux)
+	h.registerToolRoutes(mux)
+
 	// OS startup / launch-at-login
 	h.registerStartupRoutes(mux)
diff --git a/web/backend/api/session.go b/web/backend/api/session.go
index e3cf674fc..42d451a05 100644
--- a/web/backend/api/session.go
+++ b/web/backend/api/session.go
@@ -1,7 +1,9 @@
 package api
 import (
+	"bufio"
 	"encoding/json"
+	"errors"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -33,12 +35,22 @@ type sessionFile struct {
 // sessionListItem is a lightweight summary returned by GET /api/sessions.
 type sessionListItem struct {
 	ID string `json:"id"`
+	Title string `json:"title"`
 	Preview string `json:"preview"`
 	MessageCount int `json:"message_count"`
 	Created string `json:"created"`
 	Updated string `json:"updated"`
 }
+type sessionMetaFile struct {
+	Key string `json:"key"`
+	Summary string `json:"summary"`
+	Skip int `json:"skip"`
+	Count int `json:"count"`
+	CreatedAt time.Time `json:"created_at"`
+	UpdatedAt time.Time `json:"updated_at"`
+}
+
 // picoSessionPrefix is the key prefix used by the gateway's routing for Pico
 // channel sessions.
The full key format is: // @@ -47,7 +59,12 @@ type sessionListItem struct { // The sanitized filename replaces ':' with '_', so on disk it becomes: // // agent_main_pico_direct_pico_.json -const picoSessionPrefix = "agent:main:pico:direct:pico:" +const ( + picoSessionPrefix = "agent:main:pico:direct:pico:" + sanitizedPicoSessionPrefix = "agent_main_pico_direct_pico_" + maxSessionJSONLLineSize = 10 * 1024 * 1024 // 10 MB + maxSessionTitleRunes = 60 +) // extractPicoSessionID extracts the session UUID from a full session key. // Returns the UUID and true if the key matches the Pico session pattern. @@ -58,6 +75,178 @@ func extractPicoSessionID(key string) (string, bool) { return "", false } +func extractPicoSessionIDFromSanitizedKey(key string) (string, bool) { + if strings.HasPrefix(key, sanitizedPicoSessionPrefix) { + return strings.TrimPrefix(key, sanitizedPicoSessionPrefix), true + } + return "", false +} + +func sanitizeSessionKey(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + +func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) { + path := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)+".json") + data, err := os.ReadFile(path) + if err != nil { + return sessionFile{}, err + } + + var sess sessionFile + if err := json.Unmarshal(data, &sess); err != nil { + return sessionFile{}, err + } + return sess, nil +} + +func (h *Handler) readSessionMeta(path, sessionKey string) (sessionMetaFile, error) { + data, err := os.ReadFile(path) + if os.IsNotExist(err) { + return sessionMetaFile{Key: sessionKey}, nil + } + if err != nil { + return sessionMetaFile{}, err + } + + var meta sessionMetaFile + if err := json.Unmarshal(data, &meta); err != nil { + return sessionMetaFile{}, err + } + if meta.Key == "" { + meta.Key = sessionKey + } + return meta, nil +} + +func (h *Handler) readSessionMessages(path string, skip int) ([]providers.Message, error) { + f, err := os.Open(path) + if err != nil { + return nil, 
err + } + defer f.Close() + + msgs := make([]providers.Message, 0) + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), maxSessionJSONLLineSize) + + seen := 0 + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + seen++ + if seen <= skip { + continue + } + + var msg providers.Message + if err := json.Unmarshal(line, &msg); err != nil { + continue + } + msgs = append(msgs, msg) + } + if err := scanner.Err(); err != nil { + return nil, err + } + return msgs, nil +} + +func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) { + sessionKey := picoSessionPrefix + sessionID + base := filepath.Join(dir, sanitizeSessionKey(sessionKey)) + jsonlPath := base + ".jsonl" + metaPath := base + ".meta.json" + + meta, err := h.readSessionMeta(metaPath, sessionKey) + if err != nil { + return sessionFile{}, err + } + + messages, err := h.readSessionMessages(jsonlPath, meta.Skip) + if err != nil { + return sessionFile{}, err + } + + updated := meta.UpdatedAt + created := meta.CreatedAt + if created.IsZero() || updated.IsZero() { + if info, statErr := os.Stat(jsonlPath); statErr == nil { + if created.IsZero() { + created = info.ModTime() + } + if updated.IsZero() { + updated = info.ModTime() + } + } + } + + return sessionFile{ + Key: meta.Key, + Messages: messages, + Summary: meta.Summary, + Created: created, + Updated: updated, + }, nil +} + +func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem { + preview := "" + for _, msg := range sess.Messages { + if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" { + preview = msg.Content + break + } + } + title := strings.TrimSpace(sess.Summary) + if title == "" { + title = preview + } + + title = truncateRunes(title, maxSessionTitleRunes) + preview = truncateRunes(preview, maxSessionTitleRunes) + + if preview == "" { + preview = "(empty)" + } + if title == "" { + title = preview + } + + validMessageCount := 0 + for _, msg := 
range sess.Messages { + if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" { + validMessageCount++ + } + } + + return sessionListItem{ + ID: sessionID, + Title: title, + Preview: preview, + MessageCount: validMessageCount, + Created: sess.Created.Format(time.RFC3339), + Updated: sess.Updated.Format(time.RFC3339), + } +} + +func isEmptySession(sess sessionFile) bool { + return len(sess.Messages) == 0 && strings.TrimSpace(sess.Summary) == "" +} + +func truncateRunes(s string, maxLen int) string { + if maxLen <= 0 { + return "" + } + runes := []rune(strings.TrimSpace(s)) + if len(runes) <= maxLen { + return string(runes) + } + return string(runes[:maxLen]) + "..." +} + // sessionsDir resolves the path to the gateway's session storage directory. // It reads the workspace from config, falling back to ~/.picoclaw/workspace. func (h *Handler) sessionsDir() (string, error) { @@ -104,58 +293,76 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) { } items := []sessionListItem{} + seen := make(map[string]struct{}) for _, entry := range entries { - if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + if entry.IsDir() { continue } - data, err := os.ReadFile(filepath.Join(dir, entry.Name())) - if err != nil { - continue - } + name := entry.Name() + var ( + sessionID string + sess sessionFile + loadErr error + ok bool + ) - var sess sessionFile - if err := json.Unmarshal(data, &sess); err != nil { - continue - } - - // Only include Pico channel sessions - sessionID, ok := extractPicoSessionID(sess.Key) - if !ok { - continue - } - - // Build a preview from the first user message - preview := "" - for _, msg := range sess.Messages { - if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" { - preview = msg.Content - break + switch { + case strings.HasSuffix(name, ".jsonl"): + sessionID, ok = extractPicoSessionIDFromSanitizedKey(strings.TrimSuffix(name, ".jsonl")) + if !ok { + continue } - } - if 
len([]rune(preview)) > 60 { - preview = string([]rune(preview)[:60]) + "..." - } - if preview == "" { - preview = "(empty)" - } - - // Only count non-empty user and assistant messages - validMessageCount := 0 - for _, msg := range sess.Messages { - if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" { - validMessageCount++ + sess, loadErr = h.readJSONLSession(dir, sessionID) + if loadErr == nil && isEmptySession(sess) { + continue } + case strings.HasSuffix(name, ".meta.json"): + continue + case filepath.Ext(name) == ".json": + base := strings.TrimSuffix(name, ".json") + if _, statErr := os.Stat(filepath.Join(dir, base+".jsonl")); statErr == nil { + if jsonlSessionID, found := extractPicoSessionIDFromSanitizedKey(base); found { + if jsonlSess, jsonlErr := h.readJSONLSession( + dir, + jsonlSessionID, + ); jsonlErr == nil && + !isEmptySession(jsonlSess) { + continue + } + } + } + data, err := os.ReadFile(filepath.Join(dir, name)) + if err != nil { + continue + } + if err := json.Unmarshal(data, &sess); err != nil { + continue + } + if isEmptySession(sess) { + continue + } + sessionID, ok = extractPicoSessionID(sess.Key) + if !ok { + continue + } + if _, exists := seen[sessionID]; exists { + continue + } + default: + continue } - items = append(items, sessionListItem{ - ID: sessionID, - Preview: preview, - MessageCount: validMessageCount, - Created: sess.Created.Format(time.RFC3339), - Updated: sess.Updated.Format(time.RFC3339), - }) + if loadErr != nil { + continue + } + if _, exists := seen[sessionID]; exists { + continue + } + + seen[sessionID] = struct{}{} + items = append(items, buildSessionListItem(sessionID, sess)) } // Sort by updated descending (most recent first) @@ -209,20 +416,25 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) { return } - // The sanitized filename replaces ':' with '_': - // agent:main:pico:direct:pico: -> agent_main_pico_direct_pico_.json - filename := 
strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json" - - data, err := os.ReadFile(filepath.Join(dir, filename)) - if err != nil { - http.Error(w, "session not found", http.StatusNotFound) - return + sess, err := h.readJSONLSession(dir, sessionID) + if err == nil && isEmptySession(sess) { + err = os.ErrNotExist } - - var sess sessionFile - if err := json.Unmarshal(data, &sess); err != nil { - http.Error(w, "failed to parse session", http.StatusInternalServerError) - return + if err != nil { + if errors.Is(err, os.ErrNotExist) { + sess, err = h.readLegacySession(dir, sessionID) + if err == nil && isEmptySession(sess) { + err = os.ErrNotExist + } + } + if err != nil { + if errors.Is(err, os.ErrNotExist) { + http.Error(w, "session not found", http.StatusNotFound) + } else { + http.Error(w, "failed to parse session", http.StatusInternalServerError) + } + return + } } // Convert to a simpler format for the frontend @@ -268,17 +480,25 @@ func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) { return } - // The sanitized filename replaces ':' with '_': - // agent:main:pico:direct:pico: -> agent_main_pico_direct_pico_.json - filename := strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json" - filePath := filepath.Join(dir, filename) + base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)) + jsonlPath := base + ".jsonl" + metaPath := base + ".meta.json" + legacyPath := base + ".json" - if err := os.Remove(filePath); err != nil { - if os.IsNotExist(err) { - http.Error(w, "session not found", http.StatusNotFound) - } else { + removed := false + for _, path := range []string{jsonlPath, metaPath, legacyPath} { + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + continue + } http.Error(w, "failed to delete session", http.StatusInternalServerError) + return } + removed = true + } + + if !removed { + http.Error(w, "session not found", http.StatusNotFound) return } diff --git 
a/web/backend/api/session_test.go b/web/backend/api/session_test.go new file mode 100644 index 000000000..21ef5b5b8 --- /dev/null +++ b/web/backend/api/session_test.go @@ -0,0 +1,322 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/memory" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/session" +) + +func sessionsTestDir(t *testing.T, configPath string) string { + t.Helper() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + dir := filepath.Join(cfg.Agents.Defaults.Workspace, "sessions") + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + return dir +} + +func TestHandleListSessions_JSONLStorage(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "history-jsonl" + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "user", + Content: "Explain why the history API is empty after migration.", + }); err != nil { + t.Fatalf("AddFullMessage(user) error = %v", err) + } + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "assistant", + Content: "Because the API still reads only legacy JSON session files.", + }); err != nil { + t.Fatalf("AddFullMessage(assistant) error = %v", err) + } + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "tool", + Content: "ignored", + }); err != nil { + t.Fatalf("AddFullMessage(tool) error = %v", err) + } + if err := store.SetSummary(nil, sessionKey, "JSONL-backed session"); err != nil { + t.Fatalf("SetSummary() error = %v", err) + } + + h := NewHandler(configPath) + mux := 
http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if items[0].ID != "history-jsonl" { + t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "history-jsonl") + } + if items[0].MessageCount != 2 { + t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount) + } + if items[0].Title != "JSONL-backed session" { + t.Fatalf("items[0].Title = %q, want %q", items[0].Title, "JSONL-backed session") + } + if items[0].Preview != "Explain why the history API is empty after migration." { + t.Fatalf("items[0].Preview = %q", items[0].Preview) + } +} + +func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "summary-title" + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "user", + Content: "fallback preview", + }); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + if err := store.SetSummary( + nil, + sessionKey, + " This summary is intentionally longer than sixty characters so it must be truncated in the history menu. 
", + ); err != nil { + t.Fatalf("SetSummary() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + expectedTitle := truncateRunes( + "This summary is intentionally longer than sixty characters so it must be truncated in the history menu.", + maxSessionTitleRunes, + ) + if items[0].Title != expectedTitle { + t.Fatalf("items[0].Title = %q", items[0].Title) + } + if items[0].Preview != "fallback preview" { + t.Fatalf("items[0].Preview = %q, want %q", items[0].Preview, "fallback preview") + } +} + +func TestHandleGetSession_JSONLStorage(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "detail-jsonl" + for _, msg := range []providers.Message{ + {Role: "user", Content: "first"}, + {Role: "assistant", Content: "second"}, + {Role: "tool", Content: "ignored"}, + } { + if err := store.AddFullMessage(nil, sessionKey, msg); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + } + if err := store.SetSummary(nil, sessionKey, "detail summary"); err != nil { + t.Fatalf("SetSummary() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-jsonl", nil) + mux.ServeHTTP(rec, req) + + if 
rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp struct { + ID string `json:"id"` + Summary string `json:"summary"` + Messages []struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if resp.ID != "detail-jsonl" { + t.Fatalf("resp.ID = %q, want %q", resp.ID, "detail-jsonl") + } + if resp.Summary != "detail summary" { + t.Fatalf("resp.Summary = %q, want %q", resp.Summary, "detail summary") + } + if len(resp.Messages) != 2 { + t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages)) + } + if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "first" { + t.Fatalf("first message = %#v, want user/first", resp.Messages[0]) + } + if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "second" { + t.Fatalf("second message = %#v, want assistant/second", resp.Messages[1]) + } +} + +func TestHandleDeleteSession_JSONLStorage(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, err := memory.NewJSONLStore(dir) + if err != nil { + t.Fatalf("NewJSONLStore() error = %v", err) + } + + sessionKey := picoSessionPrefix + "delete-jsonl" + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "user", + Content: "delete me", + }); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + if err := store.SetSummary(nil, sessionKey, "delete summary"); err != nil { + t.Fatalf("SetSummary() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/api/sessions/delete-jsonl", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d, 
body=%s", rec.Code, http.StatusNoContent, rec.Body.String()) + } + + base := filepath.Join(dir, sanitizeSessionKey(sessionKey)) + for _, path := range []string{base + ".jsonl", base + ".meta.json"} { + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("expected %s to be removed, stat err = %v", path, err) + } + } +} + +func TestHandleGetSession_LegacyJSONFallback(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + manager := session.NewSessionManager(dir) + sessionKey := picoSessionPrefix + "legacy-json" + manager.AddMessage(sessionKey, "user", "legacy user") + manager.AddMessage(sessionKey, "assistant", "legacy assistant") + if err := manager.Save(sessionKey); err != nil { + t.Fatalf("Save() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/sessions/legacy-json", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + +func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+"empty-jsonl")) + if err := os.WriteFile(base+".jsonl", []byte{}, 0o644); err != nil { + t.Fatalf("WriteFile(jsonl) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); 
err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 0 { + t.Fatalf("len(items) = %d, want 0", len(items)) + } + + detailRec := httptest.NewRecorder() + detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/empty-jsonl", nil) + mux.ServeHTTP(detailRec, detailReq) + + if detailRec.Code != http.StatusNotFound { + t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String()) + } +} diff --git a/web/backend/api/skills.go b/web/backend/api/skills.go new file mode 100644 index 000000000..936074fee --- /dev/null +++ b/web/backend/api/skills.go @@ -0,0 +1,331 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/skills" +) + +type skillSupportResponse struct { + Skills []skills.SkillInfo `json:"skills"` +} + +type skillDetailResponse struct { + Name string `json:"name"` + Path string `json:"path"` + Source string `json:"source"` + Description string `json:"description"` + Content string `json:"content"` +} + +var ( + skillNameSanitizer = regexp.MustCompile(`[^a-z0-9-]+`) + importedSkillFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) + skillFrontmatterStripper = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) +) + +func (h *Handler) registerSkillRoutes(mux *http.ServeMux) { + mux.HandleFunc("GET /api/skills", h.handleListSkills) + mux.HandleFunc("GET /api/skills/{name}", h.handleGetSkill) + mux.HandleFunc("POST /api/skills/import", h.handleImportSkill) + mux.HandleFunc("DELETE /api/skills/{name}", h.handleDeleteSkill) +} + +func (h *Handler) handleListSkills(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return 
+ } + + loader := newSkillsLoader(cfg.WorkspacePath()) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(skillSupportResponse{ + Skills: loader.ListSkills(), + }) +} + +func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + + loader := newSkillsLoader(cfg.WorkspacePath()) + name := r.PathValue("name") + allSkills := loader.ListSkills() + + for _, skill := range allSkills { + if skill.Name != name { + continue + } + + content, err := loadSkillContent(skill.Path) + if err != nil { + http.Error(w, "Skill content not found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(skillDetailResponse{ + Name: skill.Name, + Path: skill.Path, + Source: skill.Source, + Description: skill.Description, + Content: content, + }) + return + } + + http.Error(w, "Skill not found", http.StatusNotFound) +} + +func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + + err = r.ParseMultipartForm(2 << 20) + if err != nil { + http.Error(w, fmt.Sprintf("Invalid multipart form: %v", err), http.StatusBadRequest) + return + } + + uploadedFile, fileHeader, err := r.FormFile("file") + if err != nil { + http.Error(w, "file is required", http.StatusBadRequest) + return + } + defer uploadedFile.Close() + + content, err := io.ReadAll(io.LimitReader(uploadedFile, (1<<20)+1)) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to read file: %v", err), http.StatusBadRequest) + return + } + if len(content) > 1<<20 { + http.Error(w, "file exceeds 1MB limit", http.StatusBadRequest) + return + } + + skillName, err := 
normalizeImportedSkillName(fileHeader.Filename, content) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + content = normalizeImportedSkillContent(content, skillName) + + workspace := cfg.WorkspacePath() + skillDir := filepath.Join(workspace, "skills", skillName) + skillFile := filepath.Join(skillDir, "SKILL.md") + if _, err := os.Stat(skillDir); err == nil { + http.Error(w, "skill already exists", http.StatusConflict) + return + } + + if err := os.MkdirAll(skillDir, 0o755); err != nil { + http.Error(w, fmt.Sprintf("Failed to create skill directory: %v", err), http.StatusInternalServerError) + return + } + if err := os.WriteFile(skillFile, content, 0o644); err != nil { + http.Error(w, fmt.Sprintf("Failed to save skill: %v", err), http.StatusInternalServerError) + return + } + + loader := newSkillsLoader(workspace) + for _, skill := range loader.ListSkills() { + if skill.Path == skillFile || (skill.Name == skillName && skill.Source == "workspace") { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(skill) + return + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "name": skillName, + "path": skillFile, + }) +} + +func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + + loader := newSkillsLoader(cfg.WorkspacePath()) + name := r.PathValue("name") + for _, skill := range loader.ListSkills() { + if skill.Name != name { + continue + } + if skill.Source != "workspace" { + http.Error(w, "only workspace skills can be deleted", http.StatusBadRequest) + return + } + if err := os.RemoveAll(filepath.Dir(skill.Path)); err != nil { + http.Error(w, fmt.Sprintf("Failed to delete skill: %v", err), http.StatusInternalServerError) + return + } + 
w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + return + } + + http.Error(w, "Skill not found", http.StatusNotFound) +} + +func newSkillsLoader(workspace string) *skills.SkillsLoader { + return skills.NewSkillsLoader( + workspace, + filepath.Join(globalConfigDir(), "skills"), + builtinSkillsDir(), + ) +} + +func normalizeImportedSkillName(filename string, content []byte) (string, error) { + rawContent := strings.ReplaceAll(string(content), "\r\n", "\n") + rawContent = strings.ReplaceAll(rawContent, "\r", "\n") + metadata, _ := extractImportedSkillMetadata(rawContent) + + raw := strings.TrimSpace(metadata["name"]) + if raw == "" { + raw = strings.TrimSpace(strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename))) + } + raw = strings.ToLower(raw) + raw = strings.ReplaceAll(raw, "_", "-") + raw = strings.ReplaceAll(raw, " ", "-") + raw = skillNameSanitizer.ReplaceAllString(raw, "-") + raw = strings.Trim(raw, "-") + raw = strings.Join(strings.FieldsFunc(raw, func(r rune) bool { return r == '-' }), "-") + + if raw == "" { + return "", fmt.Errorf("skill name is required in frontmatter or filename") + } + if len(raw) > 64 { + return "", fmt.Errorf("skill name exceeds 64 characters") + } + matched, err := regexp.MatchString(`^[a-z0-9]+(-[a-z0-9]+)*$`, raw) + if err != nil || !matched { + return "", fmt.Errorf("skill name must be alphanumeric with hyphens") + } + return raw, nil +} + +func normalizeImportedSkillContent(content []byte, skillName string) []byte { + raw := strings.ReplaceAll(string(content), "\r\n", "\n") + raw = strings.ReplaceAll(raw, "\r", "\n") + + metadata, body := extractImportedSkillMetadata(raw) + description := strings.TrimSpace(metadata["description"]) + if description == "" { + description = inferImportedSkillDescription(body) + } + if description == "" { + description = "Imported skill" + } + if len(description) > 1024 { + description = 
strings.TrimSpace(description[:1024]) + } + + body = strings.TrimLeft(body, "\n") + var builder strings.Builder + builder.WriteString("---\n") + builder.WriteString("name: ") + builder.WriteString(skillName) + builder.WriteString("\n") + builder.WriteString("description: ") + builder.WriteString(description) + builder.WriteString("\n") + builder.WriteString("---\n\n") + builder.WriteString(body) + if !strings.HasSuffix(builder.String(), "\n") { + builder.WriteString("\n") + } + return []byte(builder.String()) +} + +func extractImportedSkillMetadata(raw string) (map[string]string, string) { + matches := importedSkillFrontmatter.FindStringSubmatch(raw) + if len(matches) != 2 { + return map[string]string{}, raw + } + meta := parseImportedSkillYAML(matches[1]) + body := importedSkillFrontmatter.ReplaceAllString(raw, "") + return meta, body +} + +func parseImportedSkillYAML(frontmatter string) map[string]string { + result := make(map[string]string) + for _, line := range strings.Split(frontmatter, "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + key, value, ok := strings.Cut(line, ":") + if !ok { + continue + } + result[strings.TrimSpace(key)] = strings.Trim(strings.TrimSpace(value), `"'`) + } + return result +} + +func inferImportedSkillDescription(body string) string { + for _, line := range strings.Split(body, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + line = strings.TrimLeft(line, "#-*0123456789. 
") + line = strings.TrimSpace(line) + if line != "" { + return line + } + } + return "" +} + +func loadSkillContent(path string) (string, error) { + content, err := os.ReadFile(path) + if err != nil { + return "", err + } + return skillFrontmatterStripper.ReplaceAllString(string(content), ""), nil +} + +func globalConfigDir() string { + if home := os.Getenv("PICOCLAW_HOME"); home != "" { + return home + } + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".picoclaw") +} + +func builtinSkillsDir() string { + if path := os.Getenv("PICOCLAW_BUILTIN_SKILLS"); path != "" { + return path + } + wd, err := os.Getwd() + if err != nil { + return "" + } + return filepath.Join(wd, "skills") +} diff --git a/web/backend/api/skills_test.go b/web/backend/api/skills_test.go new file mode 100644 index 000000000..3289d5b33 --- /dev/null +++ b/web/backend/api/skills_test.go @@ -0,0 +1,336 @@ +package api + +import ( + "bytes" + "encoding/json" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleListSkills(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + workspace := filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.Workspace = workspace + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + if err := os.MkdirAll(filepath.Join(workspace, "skills", "workspace-skill"), 0o755); err != nil { + t.Fatalf("MkdirAll(workspace skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(workspace, "skills", "workspace-skill", "SKILL.md"), + []byte("---\nname: workspace-skill\ndescription: Workspace skill\n---\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(workspace skill) error = %v", err) + } + + globalSkillDir 
:= filepath.Join(globalConfigDir(), "skills", "global-skill") + if err := os.MkdirAll(globalSkillDir, 0o755); err != nil { + t.Fatalf("MkdirAll(global skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(globalSkillDir, "SKILL.md"), + []byte("---\nname: global-skill\ndescription: Global skill\n---\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(global skill) error = %v", err) + } + + builtinRoot := filepath.Join(t.TempDir(), "builtin-skills") + oldBuiltin := os.Getenv("PICOCLAW_BUILTIN_SKILLS") + if err := os.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot); err != nil { + t.Fatalf("Setenv(PICOCLAW_BUILTIN_SKILLS) error = %v", err) + } + defer func() { + if oldBuiltin == "" { + _ = os.Unsetenv("PICOCLAW_BUILTIN_SKILLS") + } else { + _ = os.Setenv("PICOCLAW_BUILTIN_SKILLS", oldBuiltin) + } + }() + + builtinSkillDir := filepath.Join(builtinRoot, "builtin-skill") + if err := os.MkdirAll(builtinSkillDir, 0o755); err != nil { + t.Fatalf("MkdirAll(builtin skill) error = %v", err) + } + if err := os.WriteFile( + filepath.Join(builtinSkillDir, "SKILL.md"), + []byte("---\nname: builtin-skill\ndescription: Builtin skill\n---\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(builtin skill) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/skills", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp skillSupportResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if len(resp.Skills) != 3 { + t.Fatalf("skills count = %d, want 3", len(resp.Skills)) + } + + gotSkills := make(map[string]string, len(resp.Skills)) + for _, skill := range resp.Skills { + gotSkills[skill.Name] = skill.Source + } + if gotSkills["workspace-skill"] != 
"workspace" { + t.Fatalf("workspace-skill source = %q, want workspace", gotSkills["workspace-skill"]) + } + if gotSkills["global-skill"] != "global" { + t.Fatalf("global-skill source = %q, want global", gotSkills["global-skill"]) + } + if gotSkills["builtin-skill"] != "builtin" { + t.Fatalf("builtin-skill source = %q, want builtin", gotSkills["builtin-skill"]) + } +} + +func TestHandleGetSkill(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + workspace := filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.Workspace = workspace + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + skillDir := filepath.Join(workspace, "skills", "viewer-skill") + if err := os.MkdirAll(skillDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile( + filepath.Join(skillDir, "SKILL.md"), + []byte( + "---\nname: viewer-skill\ndescription: Viewable skill\n---\n# Viewer Skill\n\nThis is visible content.\n", + ), + 0o644, + ); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/skills/viewer-skill", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp skillDetailResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if resp.Name != "viewer-skill" || resp.Source != "workspace" || resp.Description != "Viewable skill" { + t.Fatalf("unexpected response: %#v", resp) + } + if resp.Content != "# Viewer Skill\n\nThis is visible content.\n" { + t.Fatalf("content = %q", resp.Content) + } +} + +func 
TestHandleGetSkillUsesResolvedPath(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + workspace := filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.Workspace = workspace + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + skillDir := filepath.Join(workspace, "skills", "folder-name") + if err := os.MkdirAll(skillDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile( + filepath.Join(skillDir, "SKILL.md"), + []byte("---\nname: display-name\ndescription: Mismatched path skill\n---\n# Display Name\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/skills/display-name", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp skillDetailResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + if resp.Name != "display-name" { + t.Fatalf("resp.Name = %q, want display-name", resp.Name) + } + if resp.Content != "# Display Name\n" { + t.Fatalf("content = %q", resp.Content) + } +} + +func TestHandleImportSkill(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + workspace := filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.Workspace = workspace + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + var body bytes.Buffer + writer := 
multipart.NewWriter(&body) + part, err := writer.CreateFormFile("file", "Plain Skill.md") + if err != nil { + t.Fatalf("CreateFormFile() error = %v", err) + } + _, err = io.WriteString(part, "# Plain Skill\n\nUse this skill to test imports.\n") + if err != nil { + t.Fatalf("WriteString() error = %v", err) + } + err = writer.Close() + if err != nil { + t.Fatalf("Close() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + skillFile := filepath.Join(workspace, "skills", "plain-skill", "SKILL.md") + content, err := os.ReadFile(skillFile) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + expected := "---\nname: plain-skill\ndescription: Plain Skill\n---\n\n# Plain Skill\n\nUse this skill to test imports.\n" + if string(content) != expected { + t.Fatalf("saved skill content mismatch:\n%s", string(content)) + } + + rec2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodGet, "/api/skills", nil) + mux.ServeHTTP(rec2, req2) + if rec2.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", rec2.Code, http.StatusOK, rec2.Body.String()) + } + var listResp skillSupportResponse + if err := json.Unmarshal(rec2.Body.Bytes(), &listResp); err != nil { + t.Fatalf("Unmarshal list response error = %v", err) + } + found := false + for _, skill := range listResp.Skills { + if skill.Name == "plain-skill" && skill.Source == "workspace" && skill.Description == "Plain Skill" { + found = true + } + } + if !found { + t.Fatalf("plain-skill should be listed after import, got %#v", listResp.Skills) + } +} + +func TestHandleDeleteSkill(t *testing.T) { + configPath, cleanup 
:= setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + workspace := filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.Workspace = workspace + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + skillDir := filepath.Join(workspace, "skills", "delete-me") + if err := os.MkdirAll(skillDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile( + filepath.Join(skillDir, "SKILL.md"), + []byte("---\nname: delete-me\ndescription: delete me\n---\n"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/api/skills/delete-me", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + if _, err := os.Stat(skillDir); !os.IsNotExist(err) { + t.Fatalf("skill directory should be removed, stat err=%v", err) + } +} diff --git a/web/backend/api/tools.go b/web/backend/api/tools.go new file mode 100644 index 000000000..373a3be12 --- /dev/null +++ b/web/backend/api/tools.go @@ -0,0 +1,323 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "runtime" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type toolCatalogEntry struct { + Name string + Description string + Category string + ConfigKey string +} + +type toolSupportItem struct { + Name string `json:"name"` + Description string `json:"description"` + Category string `json:"category"` + ConfigKey string `json:"config_key"` + Status string `json:"status"` + ReasonCode string `json:"reason_code,omitempty"` +} + +type toolSupportResponse struct { + Tools []toolSupportItem `json:"tools"` +} + +type toolStateRequest struct { + 
Enabled bool `json:"enabled"` +} + +var toolCatalog = []toolCatalogEntry{ + { + Name: "read_file", + Description: "Read file content from the workspace or explicitly allowed paths.", + Category: "filesystem", + ConfigKey: "read_file", + }, + { + Name: "write_file", + Description: "Create or overwrite files within the writable workspace scope.", + Category: "filesystem", + ConfigKey: "write_file", + }, + { + Name: "list_dir", + Description: "Inspect directories and enumerate files available to the agent.", + Category: "filesystem", + ConfigKey: "list_dir", + }, + { + Name: "edit_file", + Description: "Apply targeted edits to existing files without rewriting everything.", + Category: "filesystem", + ConfigKey: "edit_file", + }, + { + Name: "append_file", + Description: "Append content to the end of an existing file.", + Category: "filesystem", + ConfigKey: "append_file", + }, + { + Name: "exec", + Description: "Run shell commands inside the configured workspace sandbox.", + Category: "filesystem", + ConfigKey: "exec", + }, + { + Name: "cron", + Description: "Schedule one-time or recurring reminders, jobs, and shell commands.", + Category: "automation", + ConfigKey: "cron", + }, + { + Name: "web_search", + Description: "Search the web using the configured providers.", + Category: "web", + ConfigKey: "web", + }, + { + Name: "web_fetch", + Description: "Fetch and summarize the contents of a webpage.", + Category: "web", + ConfigKey: "web_fetch", + }, + { + Name: "message", + Description: "Send a follow-up message back to the active user or chat.", + Category: "communication", + ConfigKey: "message", + }, + { + Name: "send_file", + Description: "Send an outbound file or media attachment to the active chat.", + Category: "communication", + ConfigKey: "send_file", + }, + { + Name: "find_skills", + Description: "Search external skill registries for installable skills.", + Category: "skills", + ConfigKey: "find_skills", + }, + { + Name: "install_skill", + Description: 
"Install a skill into the current workspace from a registry.", + Category: "skills", + ConfigKey: "install_skill", + }, + { + Name: "spawn", + Description: "Launch a background subagent for long-running or delegated work.", + Category: "agents", + ConfigKey: "spawn", + }, + { + Name: "i2c", + Description: "Interact with I2C hardware devices exposed on the host.", + Category: "hardware", + ConfigKey: "i2c", + }, + { + Name: "spi", + Description: "Interact with SPI hardware devices exposed on the host.", + Category: "hardware", + ConfigKey: "spi", + }, + { + Name: "tool_search_tool_regex", + Description: "Discover hidden MCP tools by regex search when tool discovery is enabled.", + Category: "discovery", + ConfigKey: "mcp.discovery.use_regex", + }, + { + Name: "tool_search_tool_bm25", + Description: "Discover hidden MCP tools by semantic ranking when tool discovery is enabled.", + Category: "discovery", + ConfigKey: "mcp.discovery.use_bm25", + }, +} + +func (h *Handler) registerToolRoutes(mux *http.ServeMux) { + mux.HandleFunc("GET /api/tools", h.handleListTools) + mux.HandleFunc("PUT /api/tools/{name}/state", h.handleUpdateToolState) +} + +func (h *Handler) handleListTools(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toolSupportResponse{ + Tools: buildToolSupport(cfg), + }) +} + +func (h *Handler) handleUpdateToolState(w http.ResponseWriter, r *http.Request) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) + return + } + + var req toolStateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + if err 
:= applyToolState(cfg, r.PathValue("name"), req.Enabled); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := config.SaveConfig(h.configPath, cfg); err != nil { + http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) +} + +func buildToolSupport(cfg *config.Config) []toolSupportItem { + items := make([]toolSupportItem, 0, len(toolCatalog)) + for _, entry := range toolCatalog { + status := "disabled" + reasonCode := "" + + switch entry.Name { + case "find_skills", "install_skill": + if cfg.Tools.IsToolEnabled(entry.ConfigKey) { + if cfg.Tools.IsToolEnabled("skills") { + status = "enabled" + } else { + status = "blocked" + reasonCode = "requires_skills" + } + } + case "spawn": + if cfg.Tools.IsToolEnabled(entry.ConfigKey) { + if cfg.Tools.IsToolEnabled("subagent") { + status = "enabled" + } else { + status = "blocked" + reasonCode = "requires_subagent" + } + } + case "tool_search_tool_regex": + status, reasonCode = resolveDiscoveryToolSupport(cfg, cfg.Tools.MCP.Discovery.UseRegex) + case "tool_search_tool_bm25": + status, reasonCode = resolveDiscoveryToolSupport(cfg, cfg.Tools.MCP.Discovery.UseBM25) + case "i2c", "spi": + status, reasonCode = resolveHardwareToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey)) + default: + if cfg.Tools.IsToolEnabled(entry.ConfigKey) { + status = "enabled" + } + } + + items = append(items, toolSupportItem{ + Name: entry.Name, + Description: entry.Description, + Category: entry.Category, + ConfigKey: entry.ConfigKey, + Status: status, + ReasonCode: reasonCode, + }) + } + return items +} + +func resolveHardwareToolSupport(enabled bool) (string, string) { + if !enabled { + return "disabled", "" + } + if runtime.GOOS != "linux" { + return "blocked", "requires_linux" + } + return "enabled", "" +} + +func 
resolveDiscoveryToolSupport(cfg *config.Config, methodEnabled bool) (string, string) { + if !cfg.Tools.IsToolEnabled("mcp") { + return "disabled", "" + } + if !cfg.Tools.MCP.Discovery.Enabled { + return "blocked", "requires_mcp_discovery" + } + if !methodEnabled { + return "disabled", "" + } + return "enabled", "" +} + +func applyToolState(cfg *config.Config, toolName string, enabled bool) error { + switch toolName { + case "read_file": + cfg.Tools.ReadFile.Enabled = enabled + case "write_file": + cfg.Tools.WriteFile.Enabled = enabled + case "list_dir": + cfg.Tools.ListDir.Enabled = enabled + case "edit_file": + cfg.Tools.EditFile.Enabled = enabled + case "append_file": + cfg.Tools.AppendFile.Enabled = enabled + case "exec": + cfg.Tools.Exec.Enabled = enabled + case "cron": + cfg.Tools.Cron.Enabled = enabled + case "web_search": + cfg.Tools.Web.Enabled = enabled + case "web_fetch": + cfg.Tools.WebFetch.Enabled = enabled + case "message": + cfg.Tools.Message.Enabled = enabled + case "send_file": + cfg.Tools.SendFile.Enabled = enabled + case "find_skills": + cfg.Tools.FindSkills.Enabled = enabled + if enabled { + cfg.Tools.Skills.Enabled = true + } + case "install_skill": + cfg.Tools.InstallSkill.Enabled = enabled + if enabled { + cfg.Tools.Skills.Enabled = true + } + case "spawn": + cfg.Tools.Spawn.Enabled = enabled + if enabled { + cfg.Tools.Subagent.Enabled = true + } + case "i2c": + cfg.Tools.I2C.Enabled = enabled + case "spi": + cfg.Tools.SPI.Enabled = enabled + case "tool_search_tool_regex": + cfg.Tools.MCP.Discovery.UseRegex = enabled + if enabled { + cfg.Tools.MCP.Enabled = true + cfg.Tools.MCP.Discovery.Enabled = true + } + case "tool_search_tool_bm25": + cfg.Tools.MCP.Discovery.UseBM25 = enabled + if enabled { + cfg.Tools.MCP.Enabled = true + cfg.Tools.MCP.Discovery.Enabled = true + } + default: + return fmt.Errorf("tool %q cannot be updated", toolName) + } + return nil +} diff --git a/web/backend/api/tools_test.go b/web/backend/api/tools_test.go new file 
mode 100644 index 000000000..646cefbe2 --- /dev/null +++ b/web/backend/api/tools_test.go @@ -0,0 +1,198 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "runtime" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestHandleListTools(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.Tools.ReadFile.Enabled = true + cfg.Tools.WriteFile.Enabled = false + cfg.Tools.Cron.Enabled = true + cfg.Tools.FindSkills.Enabled = true + cfg.Tools.Skills.Enabled = true + cfg.Tools.Spawn.Enabled = true + cfg.Tools.Subagent.Enabled = false + cfg.Tools.MCP.Enabled = true + cfg.Tools.MCP.Discovery.Enabled = true + cfg.Tools.MCP.Discovery.UseRegex = true + cfg.Tools.MCP.Discovery.UseBM25 = false + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/tools", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var resp toolSupportResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + gotTools := make(map[string]toolSupportItem, len(resp.Tools)) + for _, tool := range resp.Tools { + gotTools[tool.Name] = tool + } + if gotTools["read_file"].Status != "enabled" { + t.Fatalf("read_file status = %q, want enabled", gotTools["read_file"].Status) + } + if gotTools["write_file"].Status != "disabled" { + t.Fatalf("write_file status = %q, want disabled", gotTools["write_file"].Status) + } + if gotTools["cron"].Status != "enabled" { + t.Fatalf("cron status = %q, want enabled", gotTools["cron"].Status) + } 
+ if gotTools["spawn"].Status != "blocked" || gotTools["spawn"].ReasonCode != "requires_subagent" { + t.Fatalf("spawn = %#v, want blocked/requires_subagent", gotTools["spawn"]) + } + if gotTools["find_skills"].Status != "enabled" { + t.Fatalf("find_skills status = %q, want enabled", gotTools["find_skills"].Status) + } + if gotTools["tool_search_tool_regex"].Status != "enabled" { + t.Fatalf("tool_search_tool_regex status = %q, want enabled", gotTools["tool_search_tool_regex"].Status) + } + if gotTools["tool_search_tool_regex"].ConfigKey != "mcp.discovery.use_regex" { + t.Fatalf( + "tool_search_tool_regex config_key = %q, want mcp.discovery.use_regex", + gotTools["tool_search_tool_regex"].ConfigKey, + ) + } + if gotTools["tool_search_tool_bm25"].Status != "disabled" { + t.Fatalf("tool_search_tool_bm25 status = %q, want disabled", gotTools["tool_search_tool_bm25"].Status) + } + if gotTools["tool_search_tool_bm25"].ConfigKey != "mcp.discovery.use_bm25" { + t.Fatalf( + "tool_search_tool_bm25 config_key = %q, want mcp.discovery.use_bm25", + gotTools["tool_search_tool_bm25"].ConfigKey, + ) + } + if runtime.GOOS == "linux" { + if gotTools["i2c"].Status != "disabled" { + t.Fatalf("i2c status = %q, want disabled on linux when config is off", gotTools["i2c"].Status) + } + } else { + cfg.Tools.I2C.Enabled = true + cfg.Tools.SPI.Enabled = true + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/tools", nil) + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + gotTools = make(map[string]toolSupportItem, len(resp.Tools)) + for _, tool := range resp.Tools { + gotTools[tool.Name] = tool + } + + if gotTools["i2c"].Status != "blocked" || 
gotTools["i2c"].ReasonCode != "requires_linux" { + t.Fatalf("i2c = %#v, want blocked/requires_linux", gotTools["i2c"]) + } + if gotTools["spi"].Status != "blocked" || gotTools["spi"].ReasonCode != "requires_linux" { + t.Fatalf("spi = %#v, want blocked/requires_linux", gotTools["spi"]) + } + } +} + +func TestHandleUpdateToolState(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + cfg, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + cfg.Tools.Spawn.Enabled = false + cfg.Tools.Subagent.Enabled = false + cfg.Tools.Cron.Enabled = false + cfg.Tools.MCP.Enabled = false + cfg.Tools.MCP.Discovery.Enabled = false + cfg.Tools.MCP.Discovery.UseRegex = false + err = config.SaveConfig(configPath, cfg) + if err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodPut, + "/api/tools/spawn/state", + bytes.NewBufferString(`{"enabled":true}`), + ) + req.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("spawn status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + + rec2 := httptest.NewRecorder() + req2 := httptest.NewRequest( + http.MethodPut, + "/api/tools/tool_search_tool_regex/state", + bytes.NewBufferString(`{"enabled":true}`), + ) + req2.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec2, req2) + if rec2.Code != http.StatusOK { + t.Fatalf("regex status = %d, want %d, body=%s", rec2.Code, http.StatusOK, rec2.Body.String()) + } + + rec3 := httptest.NewRecorder() + req3 := httptest.NewRequest( + http.MethodPut, + "/api/tools/cron/state", + bytes.NewBufferString(`{"enabled":true}`), + ) + req3.Header.Set("Content-Type", "application/json") + mux.ServeHTTP(rec3, req3) + if rec3.Code != http.StatusOK { + t.Fatalf("cron status = 
%d, want %d, body=%s", rec3.Code, http.StatusOK, rec3.Body.String()) + } + + updated, err := config.LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig(updated) error = %v", err) + } + if !updated.Tools.Spawn.Enabled || !updated.Tools.Subagent.Enabled { + t.Fatalf("spawn/subagent should both be enabled: %#v", updated.Tools) + } + if !updated.Tools.MCP.Enabled || !updated.Tools.MCP.Discovery.Enabled || !updated.Tools.MCP.Discovery.UseRegex { + t.Fatalf("mcp regex discovery should be enabled: %#v", updated.Tools.MCP) + } + if !updated.Tools.Cron.Enabled { + t.Fatalf("cron should be enabled: %#v", updated.Tools.Cron) + } +} diff --git a/web/backend/dist/.gitkeep b/web/backend/dist/.gitkeep index e69de29bb..4b533f03a 100644 --- a/web/backend/dist/.gitkeep +++ b/web/backend/dist/.gitkeep @@ -0,0 +1 @@ +# Keep the embedded web backend dist directory in version control. diff --git a/web/backend/main.go b/web/backend/main.go index b8c4dc2bb..650540ea8 100644 --- a/web/backend/main.go +++ b/web/backend/main.go @@ -25,6 +25,7 @@ import ( "github.com/sipeed/picoclaw/web/backend/api" "github.com/sipeed/picoclaw/web/backend/launcherconfig" "github.com/sipeed/picoclaw/web/backend/middleware" + "github.com/sipeed/picoclaw/web/backend/utils" ) func main() { @@ -51,7 +52,7 @@ func main() { flag.Parse() // Resolve config path - configPath := getDefaultConfigPath() + configPath := utils.GetDefaultConfigPath() if flag.NArg() > 0 { configPath = flag.Arg(0) } @@ -60,6 +61,10 @@ func main() { if err != nil { log.Fatalf("Failed to resolve config path: %v", err) } + err = utils.EnsureOnboarded(absPath) + if err != nil { + log.Printf("Warning: Failed to initialize PicoClaw config automatically: %v", err) + } var explicitPort bool var explicitPublic bool @@ -109,7 +114,7 @@ func main() { // API Routes (e.g. 
/api/status) apiHandler := api.NewHandler(absPath) - apiHandler.SetServerOptions(portNum, effectivePublic, launcherCfg.AllowedCIDRs) + apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs) apiHandler.RegisterRoutes(mux) // Frontend Embedded Assets @@ -128,13 +133,13 @@ func main() { ) // Print startup banner - fmt.Print(banner) + fmt.Print(utils.Banner) fmt.Println() fmt.Println(" Open the following URL in your browser:") fmt.Println() fmt.Printf(" >> http://localhost:%s <<\n", effectivePort) if effectivePublic { - if ip := getLocalIP(); ip != "" { + if ip := utils.GetLocalIP(); ip != "" { fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort) } } @@ -145,7 +150,7 @@ func main() { go func() { time.Sleep(500 * time.Millisecond) url := "http://localhost:" + effectivePort - if err := openBrowser(url); err != nil { + if err := utils.OpenBrowser(url); err != nil { log.Printf("Warning: Failed to auto-open browser: %v", err) } }() diff --git a/web/backend/utils.go b/web/backend/utils/banner.go similarity index 54% rename from web/backend/utils.go rename to web/backend/utils/banner.go index 6fa734aeb..a64ea6390 100644 --- a/web/backend/utils.go +++ b/web/backend/utils/banner.go @@ -1,19 +1,10 @@ -package main - -import ( - "fmt" - "net" - "os" - "os/exec" - "path/filepath" - "runtime" -) +package utils const ( colorBlue = "\x1b[38;2;62;93;185m" colorRed = "\x1b[38;2;213;70;70m" colorReset = "\x1b[0m" - banner = "\r\n" + + Banner = "\r\n" + colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" + colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" + colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" + @@ -22,40 +13,3 @@ const ( colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n" + colorReset ) - -// getDefaultConfigPath returns the default path to the picoclaw config file. 
-func getDefaultConfigPath() string { - home, err := os.UserHomeDir() - if err != nil { - return "config.json" - } - return filepath.Join(home, ".picoclaw", "config.json") -} - -// getLocalIP returns the local IP address of the machine. -func getLocalIP() string { - addrs, err := net.InterfaceAddrs() - if err != nil { - return "" - } - for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { - return ipnet.IP.String() - } - } - return "" -} - -// openBrowser automatically opens the given URL in the default browser. -func openBrowser(url string) error { - switch runtime.GOOS { - case "linux": - return exec.Command("xdg-open", url).Start() - case "windows": - return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - return exec.Command("open", url).Start() - default: - return fmt.Errorf("unsupported platform") - } -} diff --git a/web/backend/utils/onboard.go b/web/backend/utils/onboard.go new file mode 100644 index 000000000..fbe34f220 --- /dev/null +++ b/web/backend/utils/onboard.go @@ -0,0 +1,42 @@ +package utils + +import ( + "fmt" + "os" + "os/exec" + "strings" +) + +var execCommand = exec.Command + +func EnsureOnboarded(configPath string) error { + _, err := os.Stat(configPath) + if err == nil { + return nil + } + if !os.IsNotExist(err) { + return fmt.Errorf("stat config: %w", err) + } + + cmd := execCommand(FindPicoclawBinary(), "onboard") + cmd.Env = append(os.Environ(), "PICOCLAW_CONFIG="+configPath) + cmd.Stdin = strings.NewReader("n\n") + + output, err := cmd.CombinedOutput() + if err != nil { + trimmed := strings.TrimSpace(string(output)) + if trimmed == "" { + return fmt.Errorf("run onboard: %w", err) + } + return fmt.Errorf("run onboard: %w: %s", err, trimmed) + } + + if _, err := os.Stat(configPath); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("onboard completed but did not create config %s", configPath) + } + return fmt.Errorf("verify config 
after onboard: %w", err) + } + + return nil +} diff --git a/web/backend/utils/onboard_test.go b/web/backend/utils/onboard_test.go new file mode 100644 index 000000000..06f967e76 --- /dev/null +++ b/web/backend/utils/onboard_test.go @@ -0,0 +1,101 @@ +package utils + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestEnsureOnboardedSkipsWhenConfigExists(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + if err := os.WriteFile(configPath, []byte(`{}`), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + called := false + execCommand = func(name string, args ...string) *exec.Cmd { + called = true + return exec.Command("sh", "-c", "exit 1") + } + + if err := EnsureOnboarded(configPath); err != nil { + t.Fatalf("EnsureOnboarded() error = %v", err) + } + if called { + t.Fatal("expected onboard command not to run when config already exists") + } +} + +func TestEnsureOnboardedRunsOnboardWhenConfigMissing(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + t.Setenv("EXPECTED_CONFIG_PATH", configPath) + + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + var gotName string + var gotArgs []string + execCommand = func(name string, args ...string) *exec.Cmd { + gotName = name + gotArgs = append([]string(nil), args...) 
+ return exec.Command( + "sh", + "-c", + `test "$PICOCLAW_CONFIG" = "$EXPECTED_CONFIG_PATH" && +mkdir -p "$(dirname "$PICOCLAW_CONFIG")" && +printf '{}' > "$PICOCLAW_CONFIG"`, + ) + } + + if err := EnsureOnboarded(configPath); err != nil { + t.Fatalf("EnsureOnboarded() error = %v", err) + } + if gotName == "" { + t.Fatal("expected onboard command to run") + } + if len(gotArgs) != 1 || gotArgs[0] != "onboard" { + t.Fatalf("command args = %#v, want []string{\"onboard\"}", gotArgs) + } + if _, err := os.Stat(configPath); err != nil { + t.Fatalf("expected config to be created: %v", err) + } +} + +func TestEnsureOnboardedFailsWhenOnboardDoesNotCreateConfig(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + execCommand = func(name string, args ...string) *exec.Cmd { + return exec.Command("sh", "-c", "exit 0") + } + + if err := EnsureOnboarded(configPath); err == nil { + t.Fatal("EnsureOnboarded() error = nil, want failure when onboard does not create config") + } +} + +func TestEnsureOnboardedIncludesOnboardOutputOnFailure(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + + origExecCommand := execCommand + defer func() { execCommand = origExecCommand }() + + execCommand = func(name string, args ...string) *exec.Cmd { + return exec.Command("sh", "-c", "echo onboarding failed >&2; exit 2") + } + + err := EnsureOnboarded(configPath) + if err == nil { + t.Fatal("EnsureOnboarded() error = nil, want failure") + } + if !strings.Contains(err.Error(), "onboarding failed") { + t.Fatalf("error = %q, want onboard output included", err) + } +} diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go new file mode 100644 index 000000000..4e6c32c56 --- /dev/null +++ b/web/backend/utils/runtime.go @@ -0,0 +1,80 @@ +package utils + +import ( + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "runtime" +) + +// GetDefaultConfigPath 
returns the default path to the picoclaw config file. +func GetDefaultConfigPath() string { + if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" { + return configPath + } + if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" { + return filepath.Join(picoclawHome, "config.json") + } + home, err := os.UserHomeDir() + if err != nil { + return "config.json" + } + return filepath.Join(home, ".picoclaw", "config.json") +} + +// FindPicoclawBinary locates the picoclaw executable. +// Search order: +// 1. PICOCLAW_BINARY environment variable (explicit override) +// 2. Same directory as the current executable +// 3. Falls back to "picoclaw" and relies on $PATH +func FindPicoclawBinary() string { + binaryName := "picoclaw" + if runtime.GOOS == "windows" { + binaryName = "picoclaw.exe" + } + + if p := os.Getenv("PICOCLAW_BINARY"); p != "" { + if info, _ := os.Stat(p); info != nil && !info.IsDir() { + return p + } + } + + if exe, err := os.Executable(); err == nil { + candidate := filepath.Join(filepath.Dir(exe), binaryName) + if info, err := os.Stat(candidate); err == nil && !info.IsDir() { + return candidate + } + } + + return "picoclaw" +} + +// GetLocalIP returns the local IP address of the machine. +func GetLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && ipnet.IP.To4() != nil { + return ipnet.IP.String() + } + } + return "" +} + +// OpenBrowser automatically opens the given URL in the default browser. 
+func OpenBrowser(url string) error { + switch runtime.GOOS { + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + return exec.Command("open", url).Start() + default: + return fmt.Errorf("unsupported platform") + } +} diff --git a/web/frontend/package.json b/web/frontend/package.json index ee46cdcda..687fd5771 100644 --- a/web/frontend/package.json +++ b/web/frontend/package.json @@ -32,7 +32,7 @@ "react-markdown": "^10.1.0", "react-textarea-autosize": "^8.5.9", "remark-gfm": "^4.0.1", - "shadcn": "^3.8.5", + "shadcn": "^4.0.5", "sonner": "^2.0.7", "tailwind-merge": "^3.5.0", "tailwindcss": "^4.2.1", diff --git a/web/frontend/pnpm-lock.yaml b/web/frontend/pnpm-lock.yaml index 8e89cbbe5..9de3354a1 100644 --- a/web/frontend/pnpm-lock.yaml +++ b/web/frontend/pnpm-lock.yaml @@ -66,8 +66,8 @@ importers: specifier: ^4.0.1 version: 4.0.1 shadcn: - specifier: ^3.8.5 - version: 3.8.5(@types/node@24.11.0)(typescript@5.9.3) + specifier: ^4.0.5 + version: 4.0.5(@types/node@24.11.0)(typescript@5.9.3) sonner: specifier: ^2.0.7 version: 2.0.7(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -512,8 +512,8 @@ packages: '@fontsource-variable/inter@5.2.8': resolution: {integrity: sha512-kOfP2D+ykbcX/P3IFnokOhVRNoTozo5/JxhAIVYLpea/UBmCQ/YWPBfWIDuBImXX/15KH+eKh4xpEUyS2sQQGQ==} - '@hono/node-server@1.19.9': - resolution: {integrity: sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==} + '@hono/node-server@1.19.11': + resolution: {integrity: sha512-dr8/3zEaB+p0D2n/IUrlPF1HZm586qgJNXK1a9fhg/PzdtkK7Ksd5l312tJX2yBuALqDYBlG20QEbayqPyxn+g==} engines: {node: '>=18.14.1'} peerDependencies: hono: ^4 @@ -1359,79 +1359,66 @@ packages: resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] - libc: [glibc] 
'@rollup/rollup-linux-arm-musleabihf@4.59.0': resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] - libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.59.0': resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.59.0': resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] - libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.59.0': resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.59.0': resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} cpu: [loong64] os: [linux] - libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.59.0': resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.59.0': resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} cpu: [ppc64] os: [linux] - libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.59.0': resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.59.0': resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] - libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.59.0': resolution: {integrity: 
sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.59.0': resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-musl@4.59.0': resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] - libc: [musl] '@rollup/rollup-openbsd-x64@4.59.0': resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} @@ -1516,28 +1503,24 @@ packages: engines: {node: '>= 20'} cpu: [arm64] os: [linux] - libc: [glibc] '@tailwindcss/oxide-linux-arm64-musl@4.2.1': resolution: {integrity: sha512-WZA0CHRL/SP1TRbA5mp9htsppSEkWuQ4KsSUumYQnyl8ZdT39ntwqmz4IUHGN6p4XdSlYfJwM4rRzZLShHsGAQ==} engines: {node: '>= 20'} cpu: [arm64] os: [linux] - libc: [musl] '@tailwindcss/oxide-linux-x64-gnu@4.2.1': resolution: {integrity: sha512-qMFzxI2YlBOLW5PhblzuSWlWfwLHaneBE0xHzLrBgNtqN6mWfs+qYbhryGSXQjFYB1Dzf5w+LN5qbUTPhW7Y5g==} engines: {node: '>= 20'} cpu: [x64] os: [linux] - libc: [glibc] '@tailwindcss/oxide-linux-x64-musl@4.2.1': resolution: {integrity: sha512-5r1X2FKnCMUPlXTWRYpHdPYUY6a1Ar/t7P24OuiEdEOmms5lyqjDRvVY1yy9Rmioh+AunQ0rWiOTPE8F9A3v5g==} engines: {node: '>= 20'} cpu: [x64] os: [linux] - libc: [musl] '@tailwindcss/oxide-wasm32-wasi@4.2.1': resolution: {integrity: sha512-MGFB5cVPvshR85MTJkEvqDUnuNoysrsRxd6vnk1Lf2tbiqNlXpHYZqkqOQalydienEWOHHFyyuTSYRsLfxFJ2Q==} @@ -2296,8 +2279,8 @@ packages: resolution: {integrity: sha512-9Be3ZoN4LmYR90tUoVu2te2BsbzHfhJyfEiAVfz7N5/zv+jduIfLrV2xdQXOHbaD6KgpGdO9PRPM1Y4Q9QkPkA==} engines: {node: ^18.19.0 || >=20.5.0} - express-rate-limit@8.2.1: - resolution: {integrity: sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==} + 
express-rate-limit@8.3.1: + resolution: {integrity: sha512-D1dKN+cmyPWuvB+G2SREQDzPY1agpBIcTa9sJxOPMCNeH3gwzhqJRDWCXW3gg0y//+LQ/8j52JbMROWyrKdMdw==} engines: {node: '>= 16'} peerDependencies: express: '>= 4.11' @@ -2496,8 +2479,8 @@ packages: hermes-parser@0.25.1: resolution: {integrity: sha512-6pEjquH3rqaI6cYAXYPcz9MS4rY6R4ngRgrgfDshRptUZIc3lw0MCIJIGDj9++mfySOuPTHB4nrSW99BCvOPIA==} - hono@4.12.3: - resolution: {integrity: sha512-SFsVSjp8sj5UumXOOFlkZOG6XS9SJDKw0TbwFeV+AJ8xlST8kxK5Z/5EYa111UY8732lK2S/xB653ceuaoGwpg==} + hono@4.12.7: + resolution: {integrity: sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==} engines: {node: '>=16.9.0'} html-parse-stringify@3.0.1: @@ -2559,8 +2542,8 @@ packages: inline-style-parser@0.2.7: resolution: {integrity: sha512-Nb2ctOyNR8DqQoR0OwRG95uNWIC0C1lCgf5Naz5H6Ji72KZ8OcFZLz2P5sNgwlyoJ8Yif11oMuYs5pBQa86csA==} - ip-address@10.0.1: - resolution: {integrity: sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} + ip-address@10.1.0: + resolution: {integrity: sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==} engines: {node: '>= 12'} ipaddr.js@1.9.1: @@ -2785,28 +2768,24 @@ packages: engines: {node: '>= 12.0.0'} cpu: [arm64] os: [linux] - libc: [glibc] lightningcss-linux-arm64-musl@1.31.1: resolution: {integrity: sha512-mVZ7Pg2zIbe3XlNbZJdjs86YViQFoJSpc41CbVmKBPiGmC4YrfeOyz65ms2qpAobVd7WQsbW4PdsSJEMymyIMg==} engines: {node: '>= 12.0.0'} cpu: [arm64] os: [linux] - libc: [musl] lightningcss-linux-x64-gnu@1.31.1: resolution: {integrity: sha512-xGlFWRMl+0KvUhgySdIaReQdB4FNudfUTARn7q0hh/V67PVGCs3ADFjw+6++kG1RNd0zdGRlEKa+T13/tQjPMA==} engines: {node: '>= 12.0.0'} cpu: [x64] os: [linux] - libc: [glibc] lightningcss-linux-x64-musl@1.31.1: resolution: {integrity: sha512-eowF8PrKHw9LpoZii5tdZwnBcYDxRw2rRCyvAXLi34iyeYfqCQNA9rmUM0ce62NlPhCvof1+9ivRaTY6pSKDaA==} engines: {node: '>= 12.0.0'} cpu: [x64] os: [linux] - 
libc: [musl] lightningcss-win32-arm64-msvc@1.31.1: resolution: {integrity: sha512-aJReEbSEQzx1uBlQizAOBSjcmr9dCdL3XuC/6HLXAxmtErsj2ICo5yYggg1qOODQMtnjNQv2UHb9NpOuFtYe4w==} @@ -3501,8 +3480,8 @@ packages: setprototypeof@1.2.0: resolution: {integrity: sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==} - shadcn@3.8.5: - resolution: {integrity: sha512-jPRx44e+eyeV7xwY3BLJXcfrks00+M0h5BGB9l6DdcBW4BpAj4x3lVmVy0TXPEs2iHEisxejr62sZAAw6B1EVA==} + shadcn@4.0.5: + resolution: {integrity: sha512-z0SOHEU1+ADam1UJHrgxJhUsOb0/jBoYc+u9mhWs071KrnORq48X7uCwG3mD2ysQEBtOfeK/MxMGsmzL5Jt+Jg==} hasBin: true shebang-command@2.0.0: @@ -4332,9 +4311,9 @@ snapshots: '@fontsource-variable/inter@5.2.8': {} - '@hono/node-server@1.19.9(hono@4.12.3)': + '@hono/node-server@1.19.11(hono@4.12.7)': dependencies: - hono: 4.12.3 + hono: 4.12.7 '@humanfs/core@0.19.1': {} @@ -4396,7 +4375,7 @@ snapshots: '@modelcontextprotocol/sdk@1.27.1(zod@3.25.76)': dependencies: - '@hono/node-server': 1.19.9(hono@4.12.3) + '@hono/node-server': 1.19.11(hono@4.12.7) ajv: 8.18.0 ajv-formats: 3.0.1(ajv@8.18.0) content-type: 1.0.5 @@ -4405,8 +4384,8 @@ snapshots: eventsource: 3.0.7 eventsource-parser: 3.0.6 express: 5.2.1 - express-rate-limit: 8.2.1(express@5.2.1) - hono: 4.12.3 + express-rate-limit: 8.3.1(express@5.2.1) + hono: 4.12.7 jose: 6.1.3 json-schema-typed: 8.0.2 pkce-challenge: 5.0.1 @@ -6146,10 +6125,10 @@ snapshots: strip-final-newline: 4.0.0 yoctocolors: 2.1.2 - express-rate-limit@8.2.1(express@5.2.1): + express-rate-limit@8.3.1(express@5.2.1): dependencies: express: 5.2.1 - ip-address: 10.0.1 + ip-address: 10.1.0 express@5.2.1: dependencies: @@ -6374,7 +6353,7 @@ snapshots: dependencies: hermes-estree: 0.25.1 - hono@4.12.3: {} + hono@4.12.7: {} html-parse-stringify@3.0.1: dependencies: @@ -6430,7 +6409,7 @@ snapshots: inline-style-parser@0.2.7: {} - ip-address@10.0.1: {} + ip-address@10.1.0: {} ipaddr.js@1.9.1: {} @@ -7534,7 +7513,7 @@ snapshots: 
setprototypeof@1.2.0: {} - shadcn@3.8.5(@types/node@24.11.0)(typescript@5.9.3): + shadcn@4.0.5(@types/node@24.11.0)(typescript@5.9.3): dependencies: '@antfu/ni': 25.0.0 '@babel/core': 7.29.0 diff --git a/web/frontend/src/api/gateway.ts b/web/frontend/src/api/gateway.ts index 5a58d48f0..020e92e3a 100644 --- a/web/frontend/src/api/gateway.ts +++ b/web/frontend/src/api/gateway.ts @@ -14,6 +14,8 @@ interface GatewayStatusResponse { interface GatewayActionResponse { status: string pid?: number + log_total?: number + log_run_id?: number } const BASE_URL = "" @@ -59,4 +61,10 @@ export async function restartGateway(): Promise<GatewayActionResponse> { }) } +export async function clearGatewayLogs(): Promise<GatewayActionResponse> { + return request<GatewayActionResponse>("/api/gateway/logs/clear", { + method: "POST", + }) +} + export type { GatewayStatusResponse, GatewayActionResponse } diff --git a/web/frontend/src/api/sessions.ts b/web/frontend/src/api/sessions.ts index 56ef148db..10b0d28fd 100644 --- a/web/frontend/src/api/sessions.ts +++ b/web/frontend/src/api/sessions.ts @@ -2,6 +2,7 @@ export interface SessionSummary { id: string + title: string preview: string message_count: number created: string diff --git a/web/frontend/src/api/skills.ts b/web/frontend/src/api/skills.ts new file mode 100644 index 000000000..307cbd788 --- /dev/null +++ b/web/frontend/src/api/skills.ts @@ -0,0 +1,79 @@ +export interface SkillSupportItem { + name: string + path: string + source: "workspace" | "global" | "builtin" | string + description: string +} + +export interface SkillDetailResponse extends SkillSupportItem { + content: string +} + +interface SkillsResponse { + skills: SkillSupportItem[] +} + +interface SkillActionResponse { + status?: string + name?: string + path?: string + source?: string + description?: string +} + +async function request<T>(path: string, options?: RequestInit): Promise<T> { + const res = await fetch(path, options) + if (!res.ok) { + throw new Error(await extractErrorMessage(res)) + } + return res.json() as Promise<T> +} + +export async
function getSkills(): Promise<SkillsResponse> { + return request<SkillsResponse>("/api/skills") +} + +export async function getSkill(name: string): Promise<SkillDetailResponse> { + return request<SkillDetailResponse>(`/api/skills/${encodeURIComponent(name)}`) +} + +export async function importSkill(file: File): Promise<SkillActionResponse> { + const formData = new FormData() + formData.set("file", file) + + const res = await fetch("/api/skills/import", { + method: "POST", + body: formData, + }) + if (!res.ok) { + throw new Error(await extractErrorMessage(res)) + } + return res.json() as Promise<SkillActionResponse> +} + +export async function deleteSkill(name: string): Promise<SkillActionResponse> { + return request<SkillActionResponse>( + `/api/skills/${encodeURIComponent(name)}`, + { + method: "DELETE", + }, + ) +} + +async function extractErrorMessage(res: Response): Promise<string> { + try { + const body = (await res.json()) as { + error?: string + errors?: string[] + } + if (Array.isArray(body.errors) && body.errors.length > 0) { + return body.errors.join("; ") + } + if (typeof body.error === "string" && body.error.trim() !== "") { + return body.error + } + } catch { + // ignore invalid body + } + return `API error: ${res.status} ${res.statusText}` +} diff --git a/web/frontend/src/api/tools.ts b/web/frontend/src/api/tools.ts new file mode 100644 index 000000000..9f09efbfd --- /dev/null +++ b/web/frontend/src/api/tools.ts @@ -0,0 +1,56 @@ +export interface ToolSupportItem { + name: string + description: string + category: string + config_key: string + status: "enabled" | "disabled" | "blocked" + reason_code?: string +} + +interface ToolsResponse { + tools: ToolSupportItem[] +} + +interface ToolActionResponse { + status: string +} + +async function request<T>(path: string, options?: RequestInit): Promise<T> { + const res = await fetch(path, options) + if (!res.ok) { + let message = `API error: ${res.status} ${res.statusText}` + try { + const body = (await res.json()) as { + error?: string + errors?: string[] + } + if (Array.isArray(body.errors) && body.errors.length > 0) { + message = body.errors.join("; ") + } else if (typeof
body.error === "string" && body.error.trim() !== "") { + message = body.error + } + } catch { + // ignore invalid body + } + throw new Error(message) + } + return res.json() as Promise<T> +} + +export async function getTools(): Promise<ToolsResponse> { + return request<ToolsResponse>("/api/tools") +} + +export async function setToolEnabled( + name: string, + enabled: boolean, +): Promise<ToolActionResponse> { + return request<ToolActionResponse>( + `/api/tools/${encodeURIComponent(name)}/state`, + { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ enabled }), + }, + ) +} diff --git a/web/frontend/src/components/app-sidebar.tsx b/web/frontend/src/components/app-sidebar.tsx index dc24f8781..702212857 100644 --- a/web/frontend/src/components/app-sidebar.tsx +++ b/web/frontend/src/components/app-sidebar.tsx @@ -7,6 +7,8 @@ import { IconListDetails, IconMessageCircle, IconSettings, + IconSparkles, + IconTools, } from "@tabler/icons-react" import { Link, useRouterState } from "@tanstack/react-router" import * as React from "react" @@ -53,6 +55,10 @@ const baseNavGroups: Omit[] = [ label: "navigation.model_group", defaultOpen: true, }, + { + label: "navigation.agent_group", + defaultOpen: true, + }, { label: "navigation.services", defaultOpen: true, @@ -113,6 +119,23 @@ export function AppSidebar({ ...props }: React.ComponentProps) { }, { ...baseNavGroups[2], + items: [ + { + title: "navigation.skills", + url: "/agent/skills", + icon: IconSparkles, + translateTitle: true, + }, + { + title: "navigation.tools", + url: "/agent/tools", + icon: IconTools, + translateTitle: true, + }, + ], + }, + { + ...baseNavGroups[3], items: [ { title: "navigation.config", diff --git a/web/frontend/src/components/chat/chat-page.tsx b/web/frontend/src/components/chat/chat-page.tsx index 0fd23a6a5..a3ab843b4 100644 --- a/web/frontend/src/components/chat/chat-page.tsx +++ b/web/frontend/src/components/chat/chat-page.tsx @@ -43,11 +43,18 @@ export function ChatPage() { handleSetDefault, } = useChatModels({ isConnected }) -
const { sessions, hasMore, observerRef, loadSessions, handleDeleteSession } = - useSessionHistory({ - activeSessionId, - onDeletedActiveSession: newChat, - }) + const { + sessions, + hasMore, + loadError, + loadErrorMessage, + observerRef, + loadSessions, + handleDeleteSession, + } = useSessionHistory({ + activeSessionId, + onDeletedActiveSession: newChat, + }) const handleScroll = (e: React.UIEvent) => { const { scrollTop, scrollHeight, clientHeight } = e.currentTarget @@ -96,6 +103,8 @@ export function ChatPage() { sessions={sessions} activeSessionId={activeSessionId} hasMore={hasMore} + loadError={loadError} + loadErrorMessage={loadErrorMessage} observerRef={observerRef} onOpenChange={(open) => { if (open) { diff --git a/web/frontend/src/components/chat/session-history-menu.tsx b/web/frontend/src/components/chat/session-history-menu.tsx index f2e93295c..3f293e353 100644 --- a/web/frontend/src/components/chat/session-history-menu.tsx +++ b/web/frontend/src/components/chat/session-history-menu.tsx @@ -17,6 +17,8 @@ interface SessionHistoryMenuProps { sessions: SessionSummary[] activeSessionId: string hasMore: boolean + loadError: boolean + loadErrorMessage: string observerRef: RefObject onOpenChange: (open: boolean) => void onSwitchSession: (sessionId: string) => void @@ -27,6 +29,8 @@ export function SessionHistoryMenu({ sessions, activeSessionId, hasMore, + loadError, + loadErrorMessage, observerRef, onOpenChange, onSwitchSession, @@ -44,7 +48,14 @@ export function SessionHistoryMenu({ - {sessions.length === 0 ? ( + {loadError && ( + + + {loadErrorMessage} + + + )} + {sessions.length === 0 && !loadError ? 
( {t("chat.noHistory")} @@ -60,7 +71,7 @@ export function SessionHistoryMenu({ onClick={() => onSwitchSession(session.id)} > - {session.preview} + {session.title || session.preview} {t("chat.messagesCount", { diff --git a/web/frontend/src/components/config/config-page.tsx b/web/frontend/src/components/config/config-page.tsx index c2d502079..d7e1aa1b5 100644 --- a/web/frontend/src/components/config/config-page.tsx +++ b/web/frontend/src/components/config/config-page.tsx @@ -189,6 +189,11 @@ export function ConfigPage() { session: { dm_scope: dmScope, }, + tools: { + exec: { + allow_remote: form.allowRemote, + }, + }, heartbeat: { enabled: form.heartbeatEnabled, interval: heartbeatInterval, diff --git a/web/frontend/src/components/config/config-sections.tsx b/web/frontend/src/components/config/config-sections.tsx index 340ece333..90813be2a 100644 --- a/web/frontend/src/components/config/config-sections.tsx +++ b/web/frontend/src/components/config/config-sections.tsx @@ -63,6 +63,13 @@ export function AgentDefaultsSection({ } /> + onFieldChange("allowRemote", checked)} + /> + export interface CoreConfigForm { workspace: string restrictToWorkspace: boolean + allowRemote: boolean maxTokens: string maxToolIterations: string summarizeMessageThreshold: string @@ -54,6 +55,7 @@ export const DM_SCOPE_OPTIONS = [ export const EMPTY_FORM: CoreConfigForm = { workspace: "", restrictToWorkspace: true, + allowRemote: true, maxTokens: "32768", maxToolIterations: "50", summarizeMessageThreshold: "20", @@ -103,6 +105,8 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm { const session = asRecord(root.session) const heartbeat = asRecord(root.heartbeat) const devices = asRecord(root.devices) + const tools = asRecord(root.tools) + const exec = asRecord(tools.exec) return { workspace: asString(defaults.workspace) || EMPTY_FORM.workspace, @@ -110,6 +114,10 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm { defaults.restrict_to_workspace === 
undefined ? EMPTY_FORM.restrictToWorkspace : asBool(defaults.restrict_to_workspace), + allowRemote: + exec.allow_remote === undefined + ? EMPTY_FORM.allowRemote + : asBool(exec.allow_remote), maxTokens: asNumberString(defaults.max_tokens, EMPTY_FORM.maxTokens), maxToolIterations: asNumberString( defaults.max_tool_iterations, diff --git a/web/frontend/src/components/skills/skills-page.tsx b/web/frontend/src/components/skills/skills-page.tsx new file mode 100644 index 000000000..3b5c5acb4 --- /dev/null +++ b/web/frontend/src/components/skills/skills-page.tsx @@ -0,0 +1,314 @@ +import { + IconFileInfo, + IconLoader2, + IconPlus, + IconTrash, +} from "@tabler/icons-react" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { type ChangeEvent, useRef, useState } from "react" +import { useTranslation } from "react-i18next" +import ReactMarkdown from "react-markdown" +import remarkGfm from "remark-gfm" +import { toast } from "sonner" + +import { + type SkillSupportItem, + deleteSkill, + getSkill, + getSkills, + importSkill, +} from "@/api/skills" +import { PageHeader } from "@/components/page-header" +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog" +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card" +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetTitle, +} from "@/components/ui/sheet" + +export function SkillsPage() { + const { t } = useTranslation() + const queryClient = useQueryClient() + const importInputRef = useRef<HTMLInputElement>(null) + const [selectedSkill, setSelectedSkill] = useState<SkillSupportItem | null>( + null, + ) + const [skillPendingDelete, setSkillPendingDelete] = + useState<SkillSupportItem | null>(null) + + const { data, isLoading, error } = useQuery({ + queryKey: ["skills"], +
queryFn: getSkills, + }) + const { + data: selectedSkillDetail, + isLoading: isSkillDetailLoading, + error: skillDetailError, + } = useQuery({ + queryKey: ["skills", selectedSkill?.name], + queryFn: () => getSkill(selectedSkill!.name), + enabled: selectedSkill !== null, + }) + + const importMutation = useMutation({ + mutationFn: async (file: File) => importSkill(file), + onSuccess: () => { + toast.success(t("pages.agent.skills.import_success")) + void queryClient.invalidateQueries({ queryKey: ["skills"] }) + }, + onError: (err) => { + toast.error( + err instanceof Error + ? err.message + : t("pages.agent.skills.import_error"), + ) + }, + }) + + const deleteMutation = useMutation({ + mutationFn: async (name: string) => deleteSkill(name), + onSuccess: (_, deletedName) => { + toast.success(t("pages.agent.skills.delete_success")) + setSkillPendingDelete(null) + if ( + selectedSkill?.name === deletedName && + selectedSkill.source === "workspace" + ) { + setSelectedSkill(null) + } + void queryClient.invalidateQueries({ queryKey: ["skills"] }) + }, + onError: (err) => { + toast.error( + err instanceof Error + ? err.message + : t("pages.agent.skills.delete_error"), + ) + }, + }) + + const handleImportClick = () => { + importInputRef.current?.click() + } + + const handleImportFileChange = (event: ChangeEvent<HTMLInputElement>) => { + const file = event.target.files?.[0] + if (!file) return + importMutation.mutate(file) + event.target.value = "" + } + + return ( +
+ + + + + } + /> + +
+
+ {isLoading ? ( +
+ {t("labels.loading")} +
+ ) : error ? ( +
+ {t("pages.agent.load_error")} +
+ ) : ( +
+

+ {t("pages.agent.skills.description")} +

+ + {data?.skills.length ? ( +
+ {data.skills.map((skill) => ( + + +
+
+ + {skill.name} + + + {skill.description || + t("pages.agent.skills.no_description")} + +
+
+ + {skill.source === "workspace" ? ( + + ) : null} +
+
+
+ +
+ {t("pages.agent.skills.path")} +
+
+ {skill.path} +
+
+
+ ))} +
+ ) : ( + + + {t("pages.agent.skills.empty")} + + + )} +
+ )} +
+
+ + { + if (!open) setSelectedSkill(null) + }} + > + + + + {selectedSkill?.name || t("pages.agent.skills.viewer_title")} + + + {selectedSkill?.description || + t("pages.agent.skills.viewer_description")} + + + +
+ {isSkillDetailLoading ? ( +
+ {t("pages.agent.skills.loading_detail")} +
+ ) : skillDetailError ? ( +
+ {t("pages.agent.skills.load_detail_error")} +
+ ) : selectedSkillDetail ? ( +
+
+ + {selectedSkillDetail.content} + +
+
+ ) : null} +
+
+
+ + { + if (!open) setSkillPendingDelete(null) + }} + > + + + + {t("pages.agent.skills.delete_title")} + + + {t("pages.agent.skills.delete_description", { + name: skillPendingDelete?.name, + })} + + + + + {t("common.cancel")} + + { + if (skillPendingDelete) + deleteMutation.mutate(skillPendingDelete.name) + }} + > + {deleteMutation.isPending ? ( + + ) : ( + + )} + {t("pages.agent.skills.delete_confirm")} + + + + +
+ ) +} diff --git a/web/frontend/src/components/tools/tools-page.tsx b/web/frontend/src/components/tools/tools-page.tsx new file mode 100644 index 000000000..05aa42122 --- /dev/null +++ b/web/frontend/src/components/tools/tools-page.tsx @@ -0,0 +1,190 @@ +import { IconLoader2 } from "@tabler/icons-react" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query" +import { useTranslation } from "react-i18next" +import { toast } from "sonner" + +import { type ToolSupportItem, getTools, setToolEnabled } from "@/api/tools" +import { PageHeader } from "@/components/page-header" +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card" +import { cn } from "@/lib/utils" + +export function ToolsPage() { + const { t } = useTranslation() + const queryClient = useQueryClient() + const { data, isLoading, error } = useQuery({ + queryKey: ["tools"], + queryFn: getTools, + }) + + const toggleMutation = useMutation({ + mutationFn: async ({ name, enabled }: { name: string; enabled: boolean }) => + setToolEnabled(name, enabled), + onSuccess: (_, variables) => { + toast.success( + variables.enabled + ? t("pages.agent.tools.enable_success") + : t("pages.agent.tools.disable_success"), + ) + void queryClient.invalidateQueries({ queryKey: ["tools"] }) + }, + onError: (err) => { + toast.error( + err instanceof Error + ? err.message + : t("pages.agent.tools.toggle_error"), + ) + }, + }) + + const groupedTools = (() => { + if (!data) return [] as Array<[string, ToolSupportItem[]]> + const buckets = new Map<string, ToolSupportItem[]>() + for (const item of data.tools) { + const list = buckets.get(item.category) ?? [] + list.push(item) + buckets.set(item.category, list) + } + return Array.from(buckets.entries()) + })() + + return ( +
+ + +
+
+ {isLoading ? ( +
+ {t("labels.loading")} +
+ ) : error ? ( +
+ {t("pages.agent.load_error")} +
+ ) : ( +
+

+ {t("pages.agent.tools.description")} +

+ + {data?.tools.length ? ( + groupedTools.map(([category, items]) => ( +
+
+ {t(`pages.agent.tools.categories.${category}`)} +
+
+ {items.map((tool) => { + const reasonText = tool.reason_code + ? t(`pages.agent.tools.reasons.${tool.reason_code}`) + : "" + const isPending = + toggleMutation.isPending && + toggleMutation.variables?.name === tool.name + const nextEnabled = tool.status !== "enabled" + + return ( + + +
+
+ + {tool.name} + + + {tool.description} + +
+
+ + +
+
+
+ +
+ {t("pages.agent.tools.config_key", { + key: tool.config_key, + })} +
+ {reasonText ? ( +
+ {reasonText} +
+ ) : null} +
+
+ ) + })} +
+
+ )) + ) : ( + + + {t("pages.agent.tools.empty")} + + + )} +
+ )} +
+
+
+ ) +} + +function ToolStatusBadge({ status }: { status: ToolSupportItem["status"] }) { + const { t } = useTranslation() + + return ( + + {t(`pages.agent.tools.status.${status}`)} + + ) +} diff --git a/web/frontend/src/hooks/use-pico-chat.ts b/web/frontend/src/hooks/use-pico-chat.ts index 7735ad928..4ce615dcf 100644 --- a/web/frontend/src/hooks/use-pico-chat.ts +++ b/web/frontend/src/hooks/use-pico-chat.ts @@ -1,6 +1,8 @@ import dayjs from "dayjs" import { useAtomValue } from "jotai" import { useCallback, useEffect, useRef, useState } from "react" +import { useTranslation } from "react-i18next" +import { toast } from "sonner" import { getPicoToken } from "@/api/pico" import { getSessionHistory } from "@/api/sessions" @@ -100,6 +102,7 @@ export function formatMessageTime(dateRaw: number | string | Date): string { } export function usePicoChat() { + const { t } = useTranslation() const { status: gatewayState } = useAtomValue(gatewayAtom) const [messages, setMessages] = useState([]) const [connectionState, setConnectionState] = @@ -317,43 +320,38 @@ export function usePicoChat() { // Switch to a historical session const switchSession = useCallback( async (sessionId: string) => { - // Disconnect current WebSocket - disconnect() - - // Set new session ID - setActiveSessionId(sessionId) - setIsTyping(false) - - // Load history from backend - try { - const detail = await getSessionHistory(sessionId) - // Set all history messages timestamp from the session updated time as fallback, - // since currently the backend doesn't return per-message timestamp in the history API. - // We'll use the session's updated time for now. 
- const fallbackTime = detail.updated - - setMessages( - detail.messages.map((m, i) => ({ - id: `hist-${i}-${Date.now()}`, - role: m.role as "user" | "assistant", - content: m.content, - timestamp: fallbackTime, - })), - ) - } catch (err) { - console.error("Failed to load session history:", err) - setMessages([]) + if (sessionId === activeSessionIdRef.current) { + return + } + + try { + const detail = await getSessionHistory(sessionId) + const fallbackTime = detail.updated + const historyMessages = detail.messages.map((m, i) => ({ + id: `hist-${i}-${Date.now()}`, + role: m.role as "user" | "assistant", + content: m.content, + timestamp: fallbackTime, + })) + + // Only switch the active websocket session after history has loaded successfully. + disconnect() + setActiveSessionId(sessionId) + setIsTyping(false) + setMessages(historyMessages) + } catch (err) { + console.error("Failed to load session history:", err) + toast.error(t("chat.historyOpenFailed")) + return } - // Reconnect with new session ID (will use the updated ref) - // Small delay to ensure state has settled setTimeout(() => { if (gatewayState === "running") { connect() } }, 100) }, - [disconnect, connect, gatewayState], + [connect, disconnect, gatewayState, t], ) // Start a new empty chat diff --git a/web/frontend/src/hooks/use-session-history.ts b/web/frontend/src/hooks/use-session-history.ts index 1a6d5c956..790339dba 100644 --- a/web/frontend/src/hooks/use-session-history.ts +++ b/web/frontend/src/hooks/use-session-history.ts @@ -1,4 +1,5 @@ import { useCallback, useEffect, useRef, useState } from "react" +import { useTranslation } from "react-i18next" import { type SessionSummary, deleteSession, getSessions } from "@/api/sessions" @@ -13,22 +14,26 @@ export function useSessionHistory({ activeSessionId, onDeletedActiveSession, }: UseSessionHistoryOptions) { + const { t } = useTranslation() const observerRef = useRef(null) const [sessions, setSessions] = useState([]) const [offset, setOffset] = 
useState(0) const [hasMore, setHasMore] = useState(true) const [isLoadingMore, setIsLoadingMore] = useState(false) + const [loadError, setLoadError] = useState(false) const loadSessions = useCallback( async (reset = true) => { try { const currentOffset = reset ? 0 : offset if (reset) { + setLoadError(false) setHasMore(true) setOffset(0) } const data = await getSessions(currentOffset, LIMIT) + setLoadError(false) if (data.length < LIMIT) { setHasMore(false) @@ -45,8 +50,12 @@ export function useSessionHistory({ } setOffset(currentOffset + data.length) - } catch { - // silently fail + } catch (err) { + console.error("Failed to fetch session history:", err) + setLoadError(true) + if (!reset) { + setHasMore(false) + } } finally { setIsLoadingMore(false) } @@ -55,11 +64,16 @@ export function useSessionHistory({ ) useEffect(() => { - if (!observerRef.current || !hasMore || isLoadingMore) return + if (!observerRef.current || !hasMore || isLoadingMore || loadError) return const observer = new IntersectionObserver( (entries) => { - if (entries[0].isIntersecting && hasMore && !isLoadingMore) { + if ( + entries[0].isIntersecting && + hasMore && + !isLoadingMore && + !loadError + ) { setIsLoadingMore(true) void loadSessions(false) } @@ -69,7 +83,7 @@ export function useSessionHistory({ observer.observe(observerRef.current) return () => observer.disconnect() - }, [hasMore, isLoadingMore, loadSessions]) + }, [hasMore, isLoadingMore, loadError, loadSessions]) const handleDeleteSession = useCallback( async (id: string) => { @@ -89,6 +103,8 @@ export function useSessionHistory({ return { sessions, hasMore, + loadError, + loadErrorMessage: t("chat.historyLoadFailed"), observerRef, loadSessions, handleDeleteSession, diff --git a/web/frontend/src/hooks/use-sidebar-channels.ts b/web/frontend/src/hooks/use-sidebar-channels.ts index 0848af468..5579a955b 100644 --- a/web/frontend/src/hooks/use-sidebar-channels.ts +++ b/web/frontend/src/hooks/use-sidebar-channels.ts @@ -27,7 +27,7 @@ 
import { import { getChannelDisplayName } from "@/components/channels/channel-display-name" import { gatewayAtom } from "@/store/gateway" -const DEFAULT_VISIBLE_CHANNELS = 5 +const DEFAULT_VISIBLE_CHANNELS = 4 const CHANNEL_IMPORTANCE_ORDER = [ "discord", "feishu", diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index f1ed0ac16..b88b5c924 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -4,6 +4,9 @@ "model_group": "Models", "models": "Models", "credentials": "Credentials", + "agent_group": "Agent", + "skills": "Skills", + "tools": "Tools", "services": "Services", "channels_group": "Channels", "show_more_channels": "More", @@ -25,6 +28,8 @@ }, "history": "History", "noHistory": "No chat history yet", + "historyLoadFailed": "Failed to load chat history", + "historyOpenFailed": "Failed to open this chat history", "loadingMore": "Loading more...", "deleteSession": "Delete session", "messagesCount": "{{count}} messages", @@ -324,12 +329,108 @@ } }, "pages": { + "agent": { + "load_error": "Failed to load agent support information.", + "stats": { + "workspace": "Workspace", + "workspace_hint": "The default agent workspace used for runtime files and workspace skills.", + "skills": "Available Skills", + "skills_hint": "Skills discovered from workspace, global, and builtin roots.", + "tools": "Enabled Tools", + "tools_hint": "{{blocked}} blocked by missing dependencies." 
+ }, + "skills": { + "title": "Skills", + "description": "Skills are loaded from the workspace, global PicoClaw home, and builtin directories.", + "hero_title": "Skill Library", + "hero_description": "Browse every capability package the agent can load, then drill straight into the effective SKILL.md without leaving the page.", + "stats": { + "total": "Total Skills", + "workspace": "Workspace", + "shared": "Shared" + }, + "empty": "No skills are currently available.", + "import": "Import Skill", + "import_title": "Import Skill", + "import_description": "Create a workspace skill by uploading a markdown file as the new SKILL.md.", + "import_name": "Skill Name", + "import_name_placeholder": "e.g. my-workflow", + "import_file": "Markdown File", + "import_file_hint": "Upload a .md file. The backend stores it as workspace/skills//SKILL.md.", + "import_confirm": "Import Skill", + "import_success": "Skill imported.", + "import_error": "Failed to import skill.", + "view": "View", + "delete": "Delete", + "delete_title": "Delete Skill?", + "delete_description": "\"{{name}}\" will be removed from workspace skills.", + "delete_confirm": "Delete", + "delete_success": "Skill deleted.", + "delete_error": "Failed to delete skill.", + "viewer_title": "Skill Content", + "viewer_description": "Read the current effective SKILL.md content here.", + "loading_detail": "Loading skill content...", + "load_detail_error": "Failed to load skill content.", + "source": "Source", + "path": "Skill Path", + "no_description": "No description provided.", + "sources": { + "workspace": "Workspace", + "global": "Global", + "builtin": "Builtin" + }, + "errors": { + "file_required": "Please choose a markdown file to import." 
+ } + }, + "tools": { + "title": "Tools", + "description": "This view reflects whether each agent tool is enabled, disabled, or blocked by a missing prerequisite.", + "hero_title": "Tool Surface", + "hero_description": "Inspect what the agent can actually call right now, which capabilities are blocked, and where each tool is controlled in config.", + "stats": { + "enabled": "Enabled", + "blocked": "Blocked", + "categories": "Categories" + }, + "empty": "No tools are available.", + "enable": "Enable", + "disable": "Disable", + "enable_success": "Tool enabled.", + "disable_success": "Tool disabled.", + "toggle_error": "Failed to update tool state.", + "config_key": "Controlled by tools.{{key}}", + "status": { + "enabled": "Enabled", + "disabled": "Disabled", + "blocked": "Blocked" + }, + "categories": { + "automation": "Automation", + "filesystem": "Filesystem", + "web": "Web", + "communication": "Communication", + "skills": "Skills", + "agents": "Agents", + "hardware": "Hardware", + "discovery": "Discovery" + }, + "reasons": { + "requires_linux": "This tool only works on Linux hosts with the required device files exposed.", + "requires_skills": "Enable `tools.skills` before this skill-registry tool can be used.", + "requires_subagent": "Enable `tools.subagent` before the spawn tool can delegate work.", + "requires_mcp_discovery": "Enable `tools.mcp.discovery` before MCP discovery tools become available." + } + } + }, "config": { "load_error": "Failed to load configuration. Please refresh and try again.", "workspace": "Workspace Directory", "workspace_hint": "Base directory for agent file operations.", "restrict_workspace": "Restrict to Workspace", "restrict_workspace_hint": "Only allow file operations inside workspace.", + "allow_remote": "Allow Remote Shell Execution", + "allow_remote_hint": "When enabled, shell commands can also run for remote sessions or non-local contexts. 
When disabled, shell execution stays limited to local safe contexts.", "max_tokens": "Max Tokens", "max_tokens_hint": "Upper token limit per model response.", "max_tool_iterations": "Max Tool Iterations", @@ -387,7 +488,9 @@ "unsaved_changes": "You have unsaved changes." }, "logs": { - "description": "System logs and monitoring." + "description": "System logs and monitoring.", + "clear": "Clear logs", + "empty": "Waiting for logs..." } } } diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index b66f0f03d..12833cbf5 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -4,6 +4,9 @@ "model_group": "模型", "models": "模型", "credentials": "凭据", + "agent_group": "智能体", + "skills": "技能", + "tools": "工具", "services": "服务", "channels_group": "频道", "show_more_channels": "更多", @@ -25,6 +28,8 @@ }, "history": "历史记录", "noHistory": "暂无对话历史", + "historyLoadFailed": "加载历史记录失败", + "historyOpenFailed": "打开该历史会话失败", "loadingMore": "加载更多...", "deleteSession": "删除会话", "messagesCount": "{{count}} 条消息", @@ -324,12 +329,108 @@ } }, "pages": { + "agent": { + "load_error": "加载 Agent 支持信息失败。", + "stats": { + "workspace": "工作目录", + "workspace_hint": "默认 Agent 运行时使用的工作目录,也用于加载工作区技能。", + "skills": "可用技能数", + "skills_hint": "从工作区、全局目录和内置目录发现的技能。", + "tools": "已启用工具", + "tools_hint": "其中 {{blocked}} 个因依赖未满足而不可用。" + }, + "skills": { + "title": "技能", + "description": "技能会从工作区、PicoClaw 全局目录和内置目录中加载。", + "hero_title": "技能库", + "hero_description": "在这里查看 Agent 当前可加载的能力包,并且不离开页面就能直接阅读生效后的 SKILL.md。", + "stats": { + "total": "技能总数", + "workspace": "工作区技能", + "shared": "共享技能" + }, + "empty": "当前没有可用技能。", + "import": "导入技能", + "import_title": "导入技能", + "import_description": "通过上传 Markdown 文件创建工作区技能,文件会保存为新的 SKILL.md。", + "import_name": "技能名称", + "import_name_placeholder": "例如 my-workflow", + "import_file": "Markdown 文件", + "import_file_hint": "上传一个 .md 文件。后端会保存到 workspace/skills//SKILL.md。", + "import_confirm": "导入技能", + 
"import_success": "技能导入成功。", + "import_error": "导入技能失败。", + "view": "查看", + "delete": "删除", + "delete_title": "删除技能?", + "delete_description": "将从工作区技能中移除「{{name}}」。", + "delete_confirm": "删除", + "delete_success": "技能已删除。", + "delete_error": "删除技能失败。", + "viewer_title": "技能内容", + "viewer_description": "这里展示当前生效的 SKILL.md 内容。", + "loading_detail": "正在加载技能内容...", + "load_detail_error": "加载技能内容失败。", + "source": "来源", + "path": "技能路径", + "no_description": "未提供描述。", + "sources": { + "workspace": "工作区", + "global": "全局", + "builtin": "内置" + }, + "errors": { + "file_required": "请先选择要导入的 Markdown 文件。" + } + }, + "tools": { + "title": "工具", + "description": "这里展示每个 Agent 工具当前是已启用、已禁用,还是被依赖条件阻塞。", + "hero_title": "工具面板", + "hero_description": "集中查看 Agent 现在真正可调用的工具、被阻塞的能力,以及它们分别受哪项配置控制。", + "stats": { + "enabled": "已启用", + "blocked": "被阻塞", + "categories": "分类数" + }, + "empty": "当前没有可用工具。", + "enable": "启用", + "disable": "禁用", + "enable_success": "工具已启用。", + "disable_success": "工具已禁用。", + "toggle_error": "更新工具状态失败。", + "config_key": "由 tools.{{key}} 控制", + "status": { + "enabled": "已启用", + "disabled": "已禁用", + "blocked": "被阻塞" + }, + "categories": { + "automation": "自动化", + "filesystem": "文件系统", + "web": "网页", + "communication": "通信", + "skills": "技能", + "agents": "Agent", + "hardware": "硬件", + "discovery": "发现" + }, + "reasons": { + "requires_linux": "该工具仅在 Linux 主机上可用,并且需要暴露对应的设备文件。", + "requires_skills": "需要先启用 `tools.skills`,该技能注册表工具才能使用。", + "requires_subagent": "需要先启用 `tools.subagent`,`spawn` 才能委派任务。", + "requires_mcp_discovery": "需要先启用 `tools.mcp.discovery`,MCP 发现工具才会可用。" + } + } + }, "config": { "load_error": "加载配置失败,请刷新后重试。", "workspace": "工作目录", "workspace_hint": "智能体执行文件读写操作时使用的基础目录。", "restrict_workspace": "限制工作目录访问", "restrict_workspace_hint": "仅允许在工作目录内执行文件操作。", + "allow_remote": "允许远程执行 Shell 命令", + "allow_remote_hint": "开启后,来自远程会话或非本地上下文的请求也可以执行 shell 命令;关闭后,仅允许本地安全上下文执行。", "max_tokens": "最大 Token 数", "max_tokens_hint": "单次模型响应允许的最大 Token 数。", 
"max_tool_iterations": "最大工具迭代次数", @@ -387,7 +488,9 @@ "unsaved_changes": "您有未保存的更改。" }, "logs": { - "description": "系统日志和监控。" + "description": "系统日志和监控。", + "clear": "清空日志", + "empty": "等待日志中..." } } } diff --git a/web/frontend/src/routeTree.gen.ts b/web/frontend/src/routeTree.gen.ts index 336504075..60f19ab53 100644 --- a/web/frontend/src/routeTree.gen.ts +++ b/web/frontend/src/routeTree.gen.ts @@ -13,10 +13,13 @@ import { Route as ModelsRouteImport } from './routes/models' import { Route as LogsRouteImport } from './routes/logs' import { Route as CredentialsRouteImport } from './routes/credentials' import { Route as ConfigRouteImport } from './routes/config' +import { Route as AgentRouteImport } from './routes/agent' import { Route as ChannelsRouteRouteImport } from './routes/channels/route' import { Route as IndexRouteImport } from './routes/index' import { Route as ConfigRawRouteImport } from './routes/config.raw' import { Route as ChannelsNameRouteImport } from './routes/channels/$name' +import { Route as AgentToolsRouteImport } from './routes/agent/tools' +import { Route as AgentSkillsRouteImport } from './routes/agent/skills' const ModelsRoute = ModelsRouteImport.update({ id: '/models', @@ -38,6 +41,11 @@ const ConfigRoute = ConfigRouteImport.update({ path: '/config', getParentRoute: () => rootRouteImport, } as any) +const AgentRoute = AgentRouteImport.update({ + id: '/agent', + path: '/agent', + getParentRoute: () => rootRouteImport, +} as any) const ChannelsRouteRoute = ChannelsRouteRouteImport.update({ id: '/channels', path: '/channels', @@ -58,24 +66,40 @@ const ChannelsNameRoute = ChannelsNameRouteImport.update({ path: '/$name', getParentRoute: () => ChannelsRouteRoute, } as any) +const AgentToolsRoute = AgentToolsRouteImport.update({ + id: '/tools', + path: '/tools', + getParentRoute: () => AgentRoute, +} as any) +const AgentSkillsRoute = AgentSkillsRouteImport.update({ + id: '/skills', + path: '/skills', + getParentRoute: () => AgentRoute, +} as any) 
export interface FileRoutesByFullPath { '/': typeof IndexRoute '/channels': typeof ChannelsRouteRouteWithChildren + '/agent': typeof AgentRouteWithChildren '/config': typeof ConfigRouteWithChildren '/credentials': typeof CredentialsRoute '/logs': typeof LogsRoute '/models': typeof ModelsRoute + '/agent/skills': typeof AgentSkillsRoute + '/agent/tools': typeof AgentToolsRoute '/channels/$name': typeof ChannelsNameRoute '/config/raw': typeof ConfigRawRoute } export interface FileRoutesByTo { '/': typeof IndexRoute '/channels': typeof ChannelsRouteRouteWithChildren + '/agent': typeof AgentRouteWithChildren '/config': typeof ConfigRouteWithChildren '/credentials': typeof CredentialsRoute '/logs': typeof LogsRoute '/models': typeof ModelsRoute + '/agent/skills': typeof AgentSkillsRoute + '/agent/tools': typeof AgentToolsRoute '/channels/$name': typeof ChannelsNameRoute '/config/raw': typeof ConfigRawRoute } @@ -83,10 +107,13 @@ export interface FileRoutesById { __root__: typeof rootRouteImport '/': typeof IndexRoute '/channels': typeof ChannelsRouteRouteWithChildren + '/agent': typeof AgentRouteWithChildren '/config': typeof ConfigRouteWithChildren '/credentials': typeof CredentialsRoute '/logs': typeof LogsRoute '/models': typeof ModelsRoute + '/agent/skills': typeof AgentSkillsRoute + '/agent/tools': typeof AgentToolsRoute '/channels/$name': typeof ChannelsNameRoute '/config/raw': typeof ConfigRawRoute } @@ -95,30 +122,39 @@ export interface FileRouteTypes { fullPaths: | '/' | '/channels' + | '/agent' | '/config' | '/credentials' | '/logs' | '/models' + | '/agent/skills' + | '/agent/tools' | '/channels/$name' | '/config/raw' fileRoutesByTo: FileRoutesByTo to: | '/' | '/channels' + | '/agent' | '/config' | '/credentials' | '/logs' | '/models' + | '/agent/skills' + | '/agent/tools' | '/channels/$name' | '/config/raw' id: | '__root__' | '/' | '/channels' + | '/agent' | '/config' | '/credentials' | '/logs' | '/models' + | '/agent/skills' + | '/agent/tools' | 
'/channels/$name' | '/config/raw' fileRoutesById: FileRoutesById @@ -126,6 +162,7 @@ export interface FileRouteTypes { export interface RootRouteChildren { IndexRoute: typeof IndexRoute ChannelsRouteRoute: typeof ChannelsRouteRouteWithChildren + AgentRoute: typeof AgentRouteWithChildren ConfigRoute: typeof ConfigRouteWithChildren CredentialsRoute: typeof CredentialsRoute LogsRoute: typeof LogsRoute @@ -162,6 +199,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof ConfigRouteImport parentRoute: typeof rootRouteImport } + '/agent': { + id: '/agent' + path: '/agent' + fullPath: '/agent' + preLoaderRoute: typeof AgentRouteImport + parentRoute: typeof rootRouteImport + } '/channels': { id: '/channels' path: '/channels' @@ -190,6 +234,20 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof ChannelsNameRouteImport parentRoute: typeof ChannelsRouteRoute } + '/agent/tools': { + id: '/agent/tools' + path: '/tools' + fullPath: '/agent/tools' + preLoaderRoute: typeof AgentToolsRouteImport + parentRoute: typeof AgentRoute + } + '/agent/skills': { + id: '/agent/skills' + path: '/skills' + fullPath: '/agent/skills' + preLoaderRoute: typeof AgentSkillsRouteImport + parentRoute: typeof AgentRoute + } } } @@ -205,6 +263,18 @@ const ChannelsRouteRouteWithChildren = ChannelsRouteRoute._addFileChildren( ChannelsRouteRouteChildren, ) +interface AgentRouteChildren { + AgentSkillsRoute: typeof AgentSkillsRoute + AgentToolsRoute: typeof AgentToolsRoute +} + +const AgentRouteChildren: AgentRouteChildren = { + AgentSkillsRoute: AgentSkillsRoute, + AgentToolsRoute: AgentToolsRoute, +} + +const AgentRouteWithChildren = AgentRoute._addFileChildren(AgentRouteChildren) + interface ConfigRouteChildren { ConfigRawRoute: typeof ConfigRawRoute } @@ -219,6 +289,7 @@ const ConfigRouteWithChildren = const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, ChannelsRouteRoute: ChannelsRouteRouteWithChildren, + AgentRoute: AgentRouteWithChildren, 
ConfigRoute: ConfigRouteWithChildren, CredentialsRoute: CredentialsRoute, LogsRoute: LogsRoute, diff --git a/web/frontend/src/routes/agent.tsx b/web/frontend/src/routes/agent.tsx new file mode 100644 index 000000000..78104de5b --- /dev/null +++ b/web/frontend/src/routes/agent.tsx @@ -0,0 +1,22 @@ +import { + Navigate, + Outlet, + createFileRoute, + useRouterState, +} from "@tanstack/react-router" + +export const Route = createFileRoute("/agent")({ + component: AgentLayout, +}) + +function AgentLayout() { + const pathname = useRouterState({ + select: (state) => state.location.pathname, + }) + + if (pathname === "/agent") { + return + } + + return <Outlet /> +} diff --git a/web/frontend/src/routes/agent/skills.tsx b/web/frontend/src/routes/agent/skills.tsx new file mode 100644 index 000000000..bbe396bdb --- /dev/null +++ b/web/frontend/src/routes/agent/skills.tsx @@ -0,0 +1,11 @@ +import { createFileRoute } from "@tanstack/react-router" + +import { SkillsPage } from "@/components/skills/skills-page" + +export const Route = createFileRoute("/agent/skills")({ + component: AgentSkillsRoute, +}) + +function AgentSkillsRoute() { + return <SkillsPage /> +} diff --git a/web/frontend/src/routes/agent/tools.tsx b/web/frontend/src/routes/agent/tools.tsx new file mode 100644 index 000000000..ac8738a8f --- /dev/null +++ b/web/frontend/src/routes/agent/tools.tsx @@ -0,0 +1,11 @@ +import { createFileRoute } from "@tanstack/react-router" + +import { ToolsPage } from "@/components/tools/tools-page" + +export const Route = createFileRoute("/agent/tools")({ + component: AgentToolsRoute, +}) + +function AgentToolsRoute() { + return <ToolsPage /> +} diff --git a/web/frontend/src/routes/logs.tsx b/web/frontend/src/routes/logs.tsx index 39688bd84..ef39e0bdf 100644 --- a/web/frontend/src/routes/logs.tsx +++ b/web/frontend/src/routes/logs.tsx @@ -1,10 +1,12 @@ +import { IconTrash } from "@tabler/icons-react" import { createFileRoute } from "@tanstack/react-router" import { useAtomValue } from "jotai" import { useEffect, useRef, 
useState } from "react" import { useTranslation } from "react-i18next" -import { getGatewayStatus } from "@/api/gateway" +import { clearGatewayLogs, getGatewayStatus } from "@/api/gateway" import { PageHeader } from "@/components/page-header" +import { Button } from "@/components/ui/button" import { ScrollArea } from "@/components/ui/scroll-area" import { gatewayAtom } from "@/store/gateway" @@ -15,12 +17,31 @@ export const Route = createFileRoute("/logs")({ function LogsPage() { const { t } = useTranslation() const [logs, setLogs] = useState([]) + const [clearing, setClearing] = useState(false) const logOffsetRef = useRef(0) const logRunIdRef = useRef(-1) + const syncTokenRef = useRef(0) const scrollRef = useRef(null) const gateway = useAtomValue(gatewayAtom) + const handleClearLogs = async () => { + setClearing(true) + try { + const data = await clearGatewayLogs() + syncTokenRef.current += 1 + setLogs([]) + logOffsetRef.current = data.log_total ?? 0 + if (data.log_run_id !== undefined) { + logRunIdRef.current = data.log_run_id + } + } catch { + // Ignore clear failures silently to avoid noisy transient errors. + } finally { + setClearing(false) + } + } + useEffect(() => { let mounted = true let timeout: ReturnType<typeof setTimeout> @@ -40,17 +61,17 @@ function LogsPage() { } try { + const requestToken = syncTokenRef.current + const requestOffset = logOffsetRef.current + const requestRunId = logRunIdRef.current const data = await getGatewayStatus({ - log_offset: logOffsetRef.current, - log_run_id: logRunIdRef.current, + log_offset: requestOffset, + log_run_id: requestRunId, }) - if (!mounted) return + if (!mounted || requestToken !== syncTokenRef.current) return - if ( - data.log_run_id !== undefined && - data.log_run_id !== logRunIdRef.current - ) { + if (data.log_run_id !== undefined && data.log_run_id !== requestRunId) { logRunIdRef.current = data.log_run_id logOffsetRef.current = 0 if (data.logs) { @@ -90,13 +111,25 @@ function LogsPage() {
- {t("navigation.logs")}
- {t("pages.logs.description")}
+ {t("navigation.logs")}
+ {t("pages.logs.description")}
@@ -104,7 +137,7 @@ function LogsPage() {
{logs.length === 0 ? (
- Waiting for logs... + {t("pages.logs.empty")}
) : ( logs.map((log, i) => (