diff --git a/Makefile b/Makefile index c7375a544..afc76a6ad 100644 --- a/Makefile +++ b/Makefile @@ -168,11 +168,11 @@ clean: @echo "Clean complete" ## vet: Run go vet for static analysis -vet: +vet: generate @$(GO) vet ./... ## test: Test Go code -test: +test: generate @$(GO) test ./... ## fmt: Format Go code @@ -204,6 +204,44 @@ check: deps fmt vet test run: build @$(BUILD_DIR)/$(BINARY_NAME) $(ARGS) +## docker-build: Build Docker image (minimal Alpine-based) +docker-build: + @echo "Building minimal Docker image (Alpine-based)..." + docker compose -f docker/docker-compose.yml build picoclaw-agent picoclaw-gateway + +## docker-build-full: Build Docker image with full MCP support (Node.js 24) +docker-build-full: + @echo "Building full-featured Docker image (Node.js 24)..." + docker compose -f docker/docker-compose.full.yml build picoclaw-agent picoclaw-gateway + +## docker-test: Test MCP tools in Docker container +docker-test: + @echo "Testing MCP tools in Docker..." + @chmod +x scripts/test-docker-mcp.sh + @./scripts/test-docker-mcp.sh + +## docker-run: Run picoclaw gateway in Docker (Alpine-based) +docker-run: + docker compose -f docker/docker-compose.yml --profile gateway up + +## docker-run-full: Run picoclaw gateway in Docker (full-featured) +docker-run-full: + docker compose -f docker/docker-compose.full.yml --profile gateway up + +## docker-run-agent: Run picoclaw agent in Docker (interactive, Alpine-based) +docker-run-agent: + docker compose -f docker/docker-compose.yml run --rm picoclaw-agent + +## docker-run-agent-full: Run picoclaw agent in Docker (interactive, full-featured) +docker-run-agent-full: + docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent + +## docker-clean: Clean Docker images and volumes +docker-clean: + docker compose -f docker/docker-compose.yml down -v + docker compose -f docker/docker-compose.full.yml down -v + docker rmi picoclaw:latest picoclaw:full 2>/dev/null || true + ## help: Show this help message help: 
@echo "picoclaw Makefile" @@ -219,6 +257,8 @@ help: @echo " make install # Install to ~/.local/bin" @echo " make uninstall # Remove from /usr/local/bin" @echo " make install-skills # Install skills to workspace" + @echo " make docker-build # Build minimal Docker image" + @echo " make docker-test # Test MCP tools in Docker" @echo "" @echo "Environment Variables:" @echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)" diff --git a/README.md b/README.md index a06f2ea61..c5b38e222 100644 --- a/README.md +++ b/README.md @@ -721,6 +721,20 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa └── USER.md # User preferences ``` +### Skill Sources + +By default, skills are loaded from: + +1. `~/.picoclaw/workspace/skills` (workspace) +2. `~/.picoclaw/skills` (global) +3. `<current-working-directory>/skills` (builtin) + +For advanced/test setups, you can override the builtin skills root with: + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### 🔒 Security Sandbox PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. 
@@ -925,7 +939,7 @@ This design also enables **multi-agent support** with flexible provider selectio #### 📋 All Supported Vendors | Vendor | `model` Prefix | Default API Base | Protocol | API Key | -| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- | +| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- | | **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) | | **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) | | **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) | @@ -937,6 +951,7 @@ This design also enables **multi-agent support** with flexible provider selectio | **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) | | **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) | | **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) | +| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1` | OpenAI | Your LiteLLM proxy key | | **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local | | **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) | | **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) | @@ -1038,6 +1053,19 @@ This design also enables **multi-agent support** with flexible provider selectio } ``` +**LiteLLM Proxy** + +```json +{ + "model_name": "lite-gpt4", + "model": "litellm/lite-gpt4", + "api_base": 
"http://localhost:4000/v1", + "api_key": "sk-..." +} +``` + +PicoClaw strips only the outer `litellm/` prefix before sending the request, so proxy aliases like `litellm/lite-gpt4` send `lite-gpt4`, while `litellm/openai/gpt-4o` sends `openai/gpt-4o`. + #### Load Balancing Configure multiple endpoints for the same model nameโ€”PicoClaw will automatically round-robin between them: diff --git a/README.zh.md b/README.zh.md index d3a49ee8d..db96ba555 100644 --- a/README.zh.md +++ b/README.zh.md @@ -362,6 +362,20 @@ PicoClaw ๅฐ†ๆ•ฐๆฎๅญ˜ๅ‚จๅœจๆ‚จ้…็ฝฎ็š„ๅทฅไฝœๅŒบไธญ๏ผˆ้ป˜่ฎค๏ผš`~/.picoclaw/work ``` +### ๆŠ€่ƒฝๆฅๆบ (Skill Sources) + +้ป˜่ฎคๆƒ…ๅ†ตไธ‹๏ผŒๆŠ€่ƒฝไผšๆŒ‰ไปฅไธ‹้กบๅบๅŠ ่ฝฝ๏ผš + +1. `~/.picoclaw/workspace/skills`๏ผˆๅทฅไฝœๅŒบ๏ผ‰ +2. `~/.picoclaw/skills`๏ผˆๅ…จๅฑ€๏ผ‰ +3. `/skills`๏ผˆๅ†…็ฝฎ๏ผ‰ + +ๅœจ้ซ˜็บง/ๆต‹่ฏ•ๅœบๆ™ฏไธ‹๏ผŒๅฏ้€š่ฟ‡ไปฅไธ‹็Žฏๅขƒๅ˜้‡่ฆ†็›–ๅ†…็ฝฎๆŠ€่ƒฝ็›ฎๅฝ•๏ผš + +```bash +export PICOCLAW_BUILTIN_SKILLS=/path/to/skills +``` + ### ๅฟƒ่ทณ / ๅ‘จๆœŸๆ€งไปปๅŠก (Heartbeat) PicoClaw ๅฏไปฅ่‡ชๅŠจๆ‰ง่กŒๅ‘จๆœŸๆ€งไปปๅŠกใ€‚ๅœจๅทฅไฝœๅŒบๅˆ›ๅปบ `HEARTBEAT.md` ๆ–‡ไปถ๏ผš diff --git a/cmd/picoclaw-launcher-tui/internal/ui/style.go b/cmd/picoclaw-launcher-tui/internal/ui/style.go index ff4f8b1a8..68cdd60b9 100644 --- a/cmd/picoclaw-launcher-tui/internal/ui/style.go +++ b/cmd/picoclaw-launcher-tui/internal/ui/style.go @@ -5,6 +5,19 @@ import ( "github.com/rivo/tview" ) +const ( + colorBlue = "[#3e5db9]" + colorRed = "[#d54646]" + banner = "\r\n[::b]" + + colorBlue + "โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— " + colorRed + " โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•—\n" + + colorBlue + "โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•”โ•โ•โ•โ–ˆโ–ˆโ•—" + colorRed + "โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘" + colorRed + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ 
โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ•— โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ•”โ•โ•โ•โ• โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘" + colorRed + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•" + colorRed + "โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ•”โ–ˆโ–ˆโ–ˆโ•”โ•\n" + + colorBlue + "โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•โ•โ• โ•šโ•โ•โ•โ•โ•โ• " + colorRed + " โ•šโ•โ•โ•โ•โ•โ•โ•šโ•โ•โ•โ•โ•โ•โ•โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•šโ•โ•โ•\n " + + "[:]" +) + func applyStyles() { tview.Styles.PrimitiveBackgroundColor = tcell.NewRGBColor(12, 13, 22) tview.Styles.ContrastBackgroundColor = tcell.NewRGBColor(34, 19, 53) @@ -24,14 +37,7 @@ func bannerView() *tview.TextView { text.SetDynamicColors(true) text.SetTextAlign(tview.AlignCenter) text.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor) - text.SetText( - "[::b][#84aaff]โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•—\n" + - "[#84aaff]โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•”โ•โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘\n" + - "[#84aaff]โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ•— โ–ˆโ–ˆโ•‘\n" + - "[#84aaff]โ–ˆโ–ˆโ•”โ•โ•โ•โ• โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘\n" + - "[#84aaff]โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ•”โ–ˆโ–ˆโ–ˆโ•”โ•\n" + - "[#84aaff]โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•โ•โ• โ•šโ•โ•โ•โ•โ•โ• โ•šโ•โ•โ•โ•โ•โ•โ•šโ•โ•โ•โ•โ•โ•โ•โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•šโ•โ•โ•", - 
) + text.SetText(banner) text.SetBorder(false) return text } diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 6db69c990..d9263462e 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -48,7 +48,21 @@ func NewPicoclawCommand() *cobra.Command { return cmd } +const ( + colorBlue = "\033[1;38;2;62;93;185m" + colorRed = "\033[1;38;2;213;70;70m" + banner = "\r\n" + + colorBlue + "โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— " + colorRed + " โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•— โ–ˆโ–ˆโ•—\n" + + colorBlue + "โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•”โ•โ•โ•โ–ˆโ–ˆโ•—" + colorRed + "โ–ˆโ–ˆโ•”โ•โ•โ•โ•โ•โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘" + colorRed + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ•— โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ•”โ•โ•โ•โ• โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘" + colorRed + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•”โ•โ•โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ•‘โ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘\n" + + colorBlue + "โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•”โ•" + colorRed + "โ•šโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ•—โ–ˆโ–ˆโ•‘ โ–ˆโ–ˆโ•‘โ•šโ–ˆโ–ˆโ–ˆโ•”โ–ˆโ–ˆโ–ˆโ•”โ•\n" + + colorBlue + "โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•โ•โ• โ•šโ•โ•โ•โ•โ•โ• " + colorRed + " โ•šโ•โ•โ•โ•โ•โ•โ•šโ•โ•โ•โ•โ•โ•โ•โ•šโ•โ• โ•šโ•โ• โ•šโ•โ•โ•โ•šโ•โ•โ•\n " + + "\033[0m\r\n" +) + func main() { + fmt.Printf("%s", banner) cmd := NewPicoclawCommand() if err := cmd.Execute(); err != nil { os.Exit(1) diff --git a/config/config.example.json b/config/config.example.json index e292731b9..3c84cfa9f 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -49,6 +49,7 @@ "telegram": { "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", + "base_url": "", "proxy": "", "allow_from": [ "YOUR_USER_ID" @@ -243,6 +244,71 @@ "cron": { 
"exec_timeout_minutes": 5 }, + "mcp": { + "enabled": false, + "servers": { + "context7": { + "enabled": false, + "type": "http", + "url": "https://mcp.context7.com/mcp", + "headers": { + "CONTEXT7_API_KEY": "ctx7sk-xx" + } + }, + "filesystem": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + "/tmp" + ] + }, + "github": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN" + } + }, + "brave-search": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-brave-search" + ], + "env": { + "BRAVE_API_KEY": "YOUR_BRAVE_API_KEY" + } + }, + "postgres": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-postgres", + "postgresql://user:password@localhost/dbname" + ] + }, + "slack": { + "enabled": false, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-slack" + ], + "env": { + "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN", + "SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID" + } + } + } + }, "exec": { "enable_deny_patterns": false, "custom_deny_patterns": [] @@ -271,4 +337,4 @@ "host": "127.0.0.1", "port": 18790 } -} +} \ No newline at end of file diff --git a/docker/Dockerfile.full b/docker/Dockerfile.full new file mode 100644 index 000000000..30e1680d5 --- /dev/null +++ b/docker/Dockerfile.full @@ -0,0 +1,44 @@ +# ============================================================ +# Stage 1: Build the picoclaw binary +# ============================================================ +FROM golang:1.26.0-alpine AS builder + +RUN apk add --no-cache git make + +WORKDIR /src + +# Cache dependencies +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source and build +COPY . . 
+RUN make build + +# ============================================================ +# Stage 2: Node.js-based runtime with full MCP support +# ============================================================ +FROM node:24-alpine3.23 + +# Install runtime dependencies +RUN apk add --no-cache \ + ca-certificates \ + curl \ + git \ + python3 \ + py3-pip + +# Install uv and symlink to system path +RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \ + ln -s /root/.local/bin/uv /usr/local/bin/uv && \ + ln -s /root/.local/bin/uvx /usr/local/bin/uvx && \ + uv --version + +# Copy binary +COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw + +# Create picoclaw home directory +RUN /usr/local/bin/picoclaw onboard + +ENTRYPOINT ["picoclaw"] +CMD ["gateway"] diff --git a/docker/docker-compose.full.yml b/docker/docker-compose.full.yml new file mode 100644 index 000000000..6f34448c4 --- /dev/null +++ b/docker/docker-compose.full.yml @@ -0,0 +1,44 @@ +services: + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # PicoClaw Agent (one-shot query) - Full MCP Support + # docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent -m "Hello" + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + picoclaw-agent: + build: + context: .. 
+ dockerfile: docker/Dockerfile.full + container_name: picoclaw-agent-full + profiles: + - agent + volumes: + - ../config/config.json:/root/.picoclaw/config.json:ro + - picoclaw-workspace:/root/.picoclaw/workspace + - picoclaw-npm-cache:/root/.npm # npm cache for faster MCP server installs + entrypoint: ["picoclaw", "agent"] + stdin_open: true + tty: true + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # PicoClaw Gateway (Long-running Bot) - Full MCP Support + # docker compose -f docker/docker-compose.full.yml --profile gateway up + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + picoclaw-gateway: + build: + context: .. + dockerfile: docker/Dockerfile.full + container_name: picoclaw-gateway-full + restart: unless-stopped + profiles: + - gateway + volumes: + # Configuration file + - ../config/config.json:/root/.picoclaw/config.json:ro + # Persistent workspace (sessions, memory, logs) + - picoclaw-workspace:/root/.picoclaw/workspace + # NPM cache for faster MCP server installs + - picoclaw-npm-cache:/root/.npm + command: ["gateway"] + +volumes: + picoclaw-workspace: + picoclaw-npm-cache: # Cache npm packages to speed up MCP server installations diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md index 8aba1aa91..6204fb0c8 100644 --- a/docs/tools_configuration.md +++ b/docs/tools_configuration.md @@ -8,6 +8,7 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`. { "tools": { "web": { ... }, + "mcp": { ... }, "exec": { ... }, "cron": { ... }, "skills": { ... } @@ -21,35 +22,35 @@ Web tools are used for web search and fetching. 
### Brave -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | false | Enable Brave search | -| `api_key` | string | - | Brave Search API key | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ------ | ------- | ------------------------- | +| `enabled` | bool | false | Enable Brave search | +| `api_key` | string | - | Brave Search API key | +| `max_results` | int | 5 | Maximum number of results | ### DuckDuckGo -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | true | Enable DuckDuckGo search | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ---- | ------- | ------------------------- | +| `enabled` | bool | true | Enable DuckDuckGo search | +| `max_results` | int | 5 | Maximum number of results | ### Perplexity -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enabled` | bool | false | Enable Perplexity search | -| `api_key` | string | - | Perplexity API key | -| `max_results` | int | 5 | Maximum number of results | +| Config | Type | Default | Description | +| ------------- | ------ | ------- | ------------------------- | +| `enabled` | bool | false | Enable Perplexity search | +| `api_key` | string | - | Perplexity API key | +| `max_results` | int | 5 | Maximum number of results | ## Exec Tool The exec tool is used to execute shell commands. 
-| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | -| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | +| Config | Type | Default | Description | +| ---------------------- | ----- | ------- | ------------------------------------------ | +| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | +| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | ### Functionality @@ -80,10 +81,7 @@ By default, PicoClaw blocks the following dangerous commands: "tools": { "exec": { "enable_deny_patterns": true, - "custom_deny_patterns": [ - "\\brm\\s+-r\\b", - "\\bkillall\\s+python" - ] + "custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"] } } } @@ -93,9 +91,84 @@ By default, PicoClaw blocks the following dangerous commands: The cron tool is used for scheduling periodic tasks. -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | +| Config | Type | Default | Description | +| ---------------------- | ---- | ------- | ---------------------------------------------- | +| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | + +## MCP Tool + +The MCP tool enables integration with external Model Context Protocol servers. 
+ +### Global Config + +| Config | Type | Default | Description | +| --------- | ------ | ------- | ----------------------------------- | +| `enabled` | bool | false | Enable MCP integration globally | +| `servers` | object | `{}` | Map of server name to server config | + +### Per-Server Config + +| Config | Type | Required | Description | +| ---------- | ------ | -------- | ------------------------------------------ | +| `enabled` | bool | yes | Enable this MCP server | +| `type` | string | no | Transport type: `stdio`, `sse`, `http` | +| `command` | string | stdio | Executable command for stdio transport | +| `args` | array | no | Command arguments for stdio transport | +| `env` | object | no | Environment variables for stdio process | +| `env_file` | string | no | Path to environment file for stdio process | +| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport | +| `headers` | object | no | HTTP headers for `sse`/`http` transport | + +### Transport Behavior + +- If `type` is omitted, transport is auto-detected: + - `url` is set → `sse` + - `command` is set → `stdio` +- `http` and `sse` both use `url` + optional `headers`. +- `env` and `env_file` are only applied to `stdio` servers. 
+ +### Configuration Examples + +#### 1) Stdio MCP server + +```json +{ + "tools": { + "mcp": { + "enabled": true, + "servers": { + "filesystem": { + "enabled": true, + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } + } + } +} +``` + +#### 2) Remote SSE/HTTP MCP server + +```json +{ + "tools": { + "mcp": { + "enabled": true, + "servers": { + "remote-mcp": { + "enabled": true, + "type": "sse", + "url": "https://example.com/mcp", + "headers": { + "Authorization": "Bearer YOUR_TOKEN" + } + } + } + } + } +} +``` ## Skills Tool @@ -103,13 +176,13 @@ The skills tool configures skill discovery and installation via registries like ### Registries -| Config | Type | Default | Description | -|--------|------|---------|-------------| -| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | -| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | -| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | -| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | -| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | +| Config | Type | Default | Description | +| ---------------------------------- | ------ | -------------------- | ----------------------- | +| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry | +| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL | +| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path | +| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path | +| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path | ### Configuration Example @@ -136,8 +209,10 @@ The skills tool configures skill discovery and installation via registries like All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_<SECTION>
_<KEY>`: For example: + - `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true` - `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false` - `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10` +- `PICOCLAW_TOOLS_MCP_ENABLED=true` -Note: Array-type environment variables are not currently supported and must be set via the config file. +Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than environment variables. diff --git a/go.mod b/go.mod index 7892cade6..c1172937c 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,16 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/gdamore/tcell/v2 v2.13.8 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mdp/qrterminal/v3 v3.2.1 + github.com/modelcontextprotocol/go-sdk v1.3.0 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 + github.com/rivo/tview v0.42.0 github.com/slack-go/slack v0.17.3 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 @@ -35,6 +38,7 @@ require ( github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect github.com/gdamore/encoding v1.0.1 // indirect github.com/gdamore/tcell/v2 v2.13.8 // indirect + github.com/h2non/filetype v1.1.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -43,7 +47,6 @@ require ( github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/tview v0.42.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect github.com/spf13/pflag v1.0.10 // indirect @@ -81,6 +84,7 @@ require ( 
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect diff --git a/go.sum b/go.sum index d1ee1d629..060594d06 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncV github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -96,6 +98,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= +github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= +github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail 
v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -130,6 +134,8 @@ github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4= github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU= +github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs= +github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE= github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= @@ -212,6 +218,8 @@ github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTd github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 6fccbaf53..3aa903b3f 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -34,6 +34,11 @@ type ContextBuilder struct { // created (didn't exist at cache time, now exist) or deleted (existed at // cache time, now 
gone) โ€” both of which should trigger a cache rebuild. existedAtCache map[string]bool + + // skillFilesAtCache snapshots the skill tree file set and mtimes at cache + // build time. This catches nested file creations/deletions/mtime changes + // that may not update the top-level skill root directory mtime. + skillFilesAtCache map[string]time.Time } func getGlobalConfigDir() string { @@ -47,8 +52,11 @@ func getGlobalConfigDir() string { func NewContextBuilder(workspace string) *ContextBuilder { // builtin skills: skills directory in current project // Use the skills/ directory under the current working directory - wd, _ := os.Getwd() - builtinSkillsDir := filepath.Join(wd, "skills") + builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS")) + if builtinSkillsDir == "" { + wd, _ := os.Getwd() + builtinSkillsDir = filepath.Join(wd, "skills") + } globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills") return &ContextBuilder{ @@ -148,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string { cb.cachedSystemPrompt = prompt cb.cachedAt = baseline.maxMtime cb.existedAtCache = baseline.existed + cb.skillFilesAtCache = baseline.skillFiles logger.DebugCF("agent", "System prompt cached", map[string]any{ @@ -167,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() { cb.cachedSystemPrompt = "" cb.cachedAt = time.Time{} cb.existedAtCache = nil + cb.skillFilesAtCache = nil logger.DebugCF("agent", "System prompt cache invalidated", nil) } -// sourcePaths returns the workspace source file paths tracked for cache -// invalidation (bootstrap files + memory). The skills directory is handled -// separately in sourceFilesChangedLocked because it requires both directory- -// level and recursive file-level mtime checks. +// sourcePaths returns non-skill workspace source files tracked for cache +// invalidation (bootstrap files + memory). 
Skill roots are handled separately +// because they require both directory-level and recursive file-level checks. func (cb *ContextBuilder) sourcePaths() []string { return []string{ filepath.Join(cb.workspace, "AGENTS.md"), @@ -185,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string { } } +// skillRoots returns all skill root directories that can affect +// BuildSkillsSummary output (workspace/global/builtin). +func (cb *ContextBuilder) skillRoots() []string { + if cb.skillsLoader == nil { + return []string{filepath.Join(cb.workspace, "skills")} + } + + roots := cb.skillsLoader.SkillRoots() + if len(roots) == 0 { + return []string{filepath.Join(cb.workspace, "skills")} + } + return roots +} + // cacheBaseline holds the file existence snapshot and the latest observed // mtime across all tracked paths. Used as the cache reference point. type cacheBaseline struct { - existed map[string]bool - maxMtime time.Time + existed map[string]bool + skillFiles map[string]time.Time + maxMtime time.Time } // buildCacheBaseline records which tracked paths currently exist and computes // the latest mtime across all tracked files + skills directory contents. // Called under write lock when the cache is built. func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { - skillsDir := filepath.Join(cb.workspace, "skills") + skillRoots := cb.skillRoots() - // All paths whose existence we track: source files + skills dir. - allPaths := append(cb.sourcePaths(), skillsDir) + // All paths whose existence we track: source files + all skill roots. + allPaths := append(cb.sourcePaths(), skillRoots...) existed := make(map[string]bool, len(allPaths)) + skillFiles := make(map[string]time.Time) var maxMtime time.Time for _, p := range allPaths { @@ -212,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { } } - // Walk skills files to capture their mtimes too. 
- // Use os.Stat (not d.Info) to match the stat method used in - // fileChangedSince / skillFilesModifiedSince for consistency. - _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) { - maxMtime = info.ModTime() + // Walk all skill roots recursively to snapshot skill files and mtimes. + // Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks. + for _, root := range skillRoots { + _ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr == nil && !d.IsDir() { + if info, err := os.Stat(path); err == nil { + skillFiles[path] = info.ModTime() + if info.ModTime().After(maxMtime) { + maxMtime = info.ModTime() + } + } } - } - return nil - }) + return nil + }) + } // If no tracked files exist yet (empty workspace), maxMtime is zero. // Use a very old non-zero time so that: @@ -234,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline { maxMtime = time.Unix(1, 0) } - return cacheBaseline{existed: existed, maxMtime: maxMtime} + return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime} } // sourceFilesChangedLocked checks whether any workspace source file has been @@ -254,21 +283,17 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool { return true } - // --- Skills directory (handled separately from sourcePaths) --- + // --- Skill roots (workspace/global/builtin) --- // - // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files. - skillsDir := filepath.Join(cb.workspace, "skills") - if cb.fileChangedSince(skillsDir) { - return true + // For each root: + // 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince. + // 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot. 
+ for _, root := range cb.skillRoots() { + if cb.fileChangedSince(root) { + return true + } } - - // 2. Structural changes (add/remove entries inside the dir) are reflected - // in the directory's own mtime, which fileChangedSince already checks. - // - // 3. Content-only edits to files inside skills/ do NOT update the parent - // directory mtime on most filesystems, so we recursively walk to check - // individual file mtimes at any nesting depth. - if skillFilesModifiedSince(skillsDir, cb.cachedAt) { + if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) { return true } @@ -309,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool { // if the callback returned nil when its err parameter is non-nil. var errWalkStop = errors.New("walk stop") -// skillFilesModifiedSince recursively walks the skills directory and checks -// whether any file was modified after t. This catches content-only edits at -// any nesting depth (e.g. skills/name/docs/extra.md) that don't update -// parent directory mtimes. -func skillFilesModifiedSince(skillsDir string, t time.Time) bool { - changed := false - err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error { - if walkErr == nil && !d.IsDir() { - if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) { - changed = true - return errWalkStop // stop walking - } - } - return nil - }) - // errWalkStop is expected (early exit on first changed file). - // os.IsNotExist means the skills dir doesn't exist yet — not an error. - // Any other error is unexpected and worth logging. - if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { - logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) +// skillFilesChangedSince compares the current recursive skill file tree +// against the cache-time snapshot. Any create/delete/mtime drift invalidates +// the cache. 
+func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool { + // Defensive: if the snapshot was never initialized, force rebuild. + if filesAtCache == nil { + return true } - return changed + + // Check cached files still exist and keep the same mtime. + for path, cachedMtime := range filesAtCache { + info, err := os.Stat(path) + if err != nil { + // A previously tracked file disappeared (or became inaccessible): + // either way, cached skill summary may now be stale. + return true + } + if !info.ModTime().Equal(cachedMtime) { + return true + } + } + + // Check no new files appeared under any skill root. + changed := false + for _, root := range skillRoots { + if strings.TrimSpace(root) == "" { + continue + } + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + // Treat unexpected walk errors as changed to avoid stale cache. + if !os.IsNotExist(walkErr) { + changed = true + return errWalkStop + } + return nil + } + if d.IsDir() { + return nil + } + if _, ok := filesAtCache[path]; !ok { + changed = true + return errWalkStop + } + return nil + }) + + if changed { + return true + } + if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) { + logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()}) + return true + } + } + + return false } func (cb *ContextBuilder) LoadBootstrapFiles() string { @@ -466,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages( // Add current user message if strings.TrimSpace(currentMessage) != "" { - messages = append(messages, providers.Message{ + msg := providers.Message{ Role: "user", Content: currentMessage, - }) + } + if len(media) > 0 { + msg.Media = media + } + messages = append(messages, msg) } return messages diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go index 0905e8a46..707510820 100644 --- a/pkg/agent/context_cache_test.go +++ b/pkg/agent/context_cache_test.go @@ 
-383,6 +383,162 @@ Updated content.` } } +// TestGlobalSkillFileContentChange verifies that modifying a global skill +// (~/.picoclaw/skills) invalidates the cached system prompt. +func TestGlobalSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: global-skill +description: global-v1 +--- +# Global Skill v1` + if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "global-v1") { + t.Fatal("expected initial prompt to contain global skill description") + } + + v2 := `--- +name: global-skill +description: global-v2 +--- +# Global Skill v2` + if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(globalSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect global skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "global-v2") { + t.Error("rebuilt prompt should contain updated global skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when global skill file content changes") + } +} + +// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill +// invalidates the cached system prompt. 
+func TestBuiltinSkillFileContentChange(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + tmpDir := setupWorkspace(t, nil) + defer os.RemoveAll(tmpDir) + + builtinRoot := t.TempDir() + t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot) + + builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md") + if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil { + t.Fatal(err) + } + v1 := `--- +name: builtin-skill +description: builtin-v1 +--- +# Builtin Skill v1` + if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil { + t.Fatal(err) + } + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "builtin-v1") { + t.Fatal("expected initial prompt to contain builtin skill description") + } + + v2 := `--- +name: builtin-skill +description: builtin-v2 +--- +# Builtin Skill v2` + if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil { + t.Fatal(err) + } + future := time.Now().Add(2 * time.Second) + if err := os.Chtimes(builtinSkillPath, future, future); err != nil { + t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change") + } + + sp2 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp2, "builtin-v2") { + t.Error("rebuilt prompt should contain updated builtin skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when builtin skill file content changes") + } +} + +// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill +// file invalidates the cached system prompt. 
+func TestSkillFileDeletionInvalidatesCache(t *testing.T) { + tmpDir := setupWorkspace(t, map[string]string{ + "skills/delete-me/SKILL.md": `--- +name: delete-me +description: delete-me-v1 +--- +# Delete Me`, + }) + defer os.RemoveAll(tmpDir) + + cb := NewContextBuilder(tmpDir) + sp1 := cb.BuildSystemPromptWithCache() + if !strings.Contains(sp1, "delete-me-v1") { + t.Fatal("expected initial prompt to contain skill description") + } + + skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md") + if err := os.Remove(skillPath); err != nil { + t.Fatal(err) + } + + cb.systemPromptMutex.RLock() + changed := cb.sourceFilesChangedLocked() + cb.systemPromptMutex.RUnlock() + if !changed { + t.Fatal("sourceFilesChangedLocked() should detect deleted skill file") + } + + sp2 := cb.BuildSystemPromptWithCache() + if strings.Contains(sp2, "delete-me-v1") { + t.Error("rebuilt prompt should not contain deleted skill description") + } + if sp1 == sp2 { + t.Error("cache should be invalidated when skill file is deleted") + } +} + // TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines // can safely call BuildSystemPromptWithCache concurrently without producing // empty results, panics, or data races. 
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 00b0f096a..b803187b1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -23,6 +23,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/mcp" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" @@ -46,19 +47,24 @@ type AgentLoop struct { // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - UserMessage string // User message content (may include prefix) - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - NoHistory bool // If true, don't load session history (for heartbeat) + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + Media []string // media:// refs from inbound message + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json." 
-func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { +func NewAgentLoop( + cfg *config.Config, + msgBus *bus.MessageBus, + provider providers.LLMProvider, +) *AgentLoop { registry := NewAgentRegistry(cfg, provider) // Register shared tools to all agents @@ -170,6 +176,72 @@ func registerSharedTools( func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + // Initialize MCP servers for all agents + if al.cfg.Tools.MCP.Enabled { + mcpManager := mcp.NewManager() + // Ensure MCP connections are cleaned up on exit, regardless of initialization success + // This fixes resource leak when LoadFromMCPConfig partially succeeds then fails + defer func() { + if err := mcpManager.Close(); err != nil { + logger.ErrorCF("agent", "Failed to close MCP manager", + map[string]any{ + "error": err.Error(), + }) + } + }() + + defaultAgent := al.registry.GetDefaultAgent() + var workspacePath string + if defaultAgent != nil && defaultAgent.Workspace != "" { + workspacePath = defaultAgent.Workspace + } else { + workspacePath = al.cfg.WorkspacePath() + } + + if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil { + logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available", + map[string]any{ + "error": err.Error(), + }) + } else { + // Register MCP tools for all agents + servers := mcpManager.GetServers() + uniqueTools := 0 + totalRegistrations := 0 + agentIDs := al.registry.ListAgentIDs() + agentCount := len(agentIDs) + + for serverName, conn := range servers { + uniqueTools += len(conn.Tools) + for _, tool := range conn.Tools { + for _, agentID := range agentIDs { + agent, ok := al.registry.GetAgent(agentID) + if !ok { + continue + } + mcpTool := tools.NewMCPTool(mcpManager, serverName, tool) + agent.Tools.Register(mcpTool) + totalRegistrations++ + logger.DebugCF("agent", "Registered MCP tool", + map[string]any{ + "agent_id": agentID, + "server": 
serverName, + "tool": tool.Name, + "name": mcpTool.Name(), + }) + } + } + } + logger.InfoCF("agent", "MCP tools registered successfully", + map[string]any{ + "server_count": len(servers), + "unique_tools": uniqueTools, + "total_registrations": totalRegistrations, + "agent_count": agentCount, + }) + } + } + for al.running.Load() { select { case <-ctx.Done(): @@ -310,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error { return al.state.SetLastChatID(chatID) } -func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { +func (al *AgentLoop) ProcessDirect( + ctx context.Context, + content, sessionKey string, +) (string, error) { return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } @@ -331,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel( // ProcessHeartbeat processes a heartbeat request without session history. // Each heartbeat is independent and doesn't accumulate context. -func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { +func (al *AgentLoop) ProcessHeartbeat( + ctx context.Context, + content, channel, chatID string, +) (string, error) { agent := al.registry.GetDefaultAgent() if agent == nil { return "", fmt.Errorf("no default agent for heartbeat") @@ -356,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } else { logContent = utils.Truncate(msg.Content, 80) } - logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), + logger.InfoCF( + "agent", + fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), map[string]any{ "channel": msg.Channel, "chat_id": msg.ChatID, "sender_id": msg.SenderID, "session_key": msg.SessionKey, - }) + }, + ) // Route system messages to processSystemMessage if msg.Channel == "system" { @@ -417,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx 
context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, + Media: msg.Media, DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, }) } -func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { +func (al *AgentLoop) processSystemMessage( + ctx context.Context, + msg bus.InboundMessage, +) (string, error) { if msg.Channel != "system" { - return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) + return "", fmt.Errorf( + "processSystemMessage called with non-system message channel: %s", + msg.Channel, + ) } logger.InfoCF("agent", "Processing system message", @@ -483,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe } // runAgentLoop is the core message processing logic. -func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop( + ctx context.Context, + agent *AgentInstance, + opts processOptions, +) (string, error) { // 0. 
Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()}) + logger.WarnCF( + "agent", + "Failed to record last channel", + map[string]any{"error": err.Error()}, + ) } } } @@ -509,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt history, summary, opts.UserMessage, - nil, + opts.Media, opts.Channel, opts.ChatID, ) + // Resolve media:// refs to base64 data URLs (streaming) + maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize() + messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) + // 3. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) @@ -572,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string return "" } -func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) { +func (al *AgentLoop) handleReasoning( + ctx context.Context, + reasoningContent, channelName, channelID string, +) { if reasoningContent == "" || channelName == "" || channelID == "" { return } @@ -665,22 +768,33 @@ func (al *AgentLoop) runLLMIteration( callLLM := func() (*providers.LLMResponse, error) { if len(agent.Candidates) > 1 && al.fallback != nil { - fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + fbResult, fbErr := al.fallback.Execute( + ctx, + agent.Candidates, func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { - return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{ - "max_tokens": agent.MaxTokens, - "temperature": agent.Temperature, - 
"prompt_cache_key": agent.ID, - }) + return agent.Provider.Chat( + ctx, + messages, + providerToolDefs, + model, + map[string]any{ + "max_tokens": agent.MaxTokens, + "temperature": agent.Temperature, + "prompt_cache_key": agent.ID, + }, + ) }, ) if fbErr != nil { return nil, fbErr } if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { - logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", - fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), - map[string]any{"agent_id": agent.ID, "iteration": iteration}) + logger.InfoCF( + "agent", + fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]any{"agent_id": agent.ID, "iteration": iteration}, + ) } return fbResult.Response, nil } @@ -731,10 +845,14 @@ func (al *AgentLoop) runLLMIteration( } if isContextError && retry < maxRetries { - logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{ - "error": err.Error(), - "retry": retry, - }) + logger.WarnCF( + "agent", + "Context window error detected, attempting compression", + map[string]any{ + "error": err.Error(), + "retry": retry, + }, + ) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ @@ -766,7 +884,12 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } - go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) + go al.handleReasoning( + ctx, + response.Reasoning, + opts.Channel, + al.targetReasoningChannelID(opts.Channel), + ) logger.DebugCF("agent", "LLM response", map[string]any{ @@ -1068,7 +1191,11 @@ func formatMessagesForLog(messages []providers.Message) string { for _, tc := range msg.ToolCalls { fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) if tc.Function != nil { - fmt.Fprintf(&sb, " 
Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200)) + fmt.Fprintf( + &sb, + " Arguments: %s\n", + utils.Truncate(tc.Function.Arguments, 200), + ) } } } @@ -1097,7 +1224,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string { fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name) fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description) if len(tool.Function.Parameters) > 0 { - fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200)) + fmt.Fprintf( + &sb, + " Parameters: %s\n", + utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200), + ) } } sb.WriteString("]") @@ -1194,7 +1325,9 @@ func (al *AgentLoop) summarizeBatch( existingSummary string, ) (string, error) { var sb strings.Builder - sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n") + sb.WriteString( + "Provide a concise summary of this conversation segment, preserving core context and key points.\n", + ) if existingSummary != "" { sb.WriteString("Existing context: ") sb.WriteString(existingSummary) diff --git a/pkg/agent/loop_media.go b/pkg/agent/loop_media.go new file mode 100644 index 000000000..82547a008 --- /dev/null +++ b/pkg/agent/loop_media.go @@ -0,0 +1,122 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "bytes" + "encoding/base64" + "io" + "os" + "strings" + + "github.com/h2non/filetype" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs. 
+// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding +// both raw bytes and encoded string in memory simultaneously. +// Returns a new slice; original messages are not mutated. +func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message { + if store == nil { + return messages + } + + result := make([]providers.Message, len(messages)) + copy(result, messages) + + for i, m := range result { + if len(m.Media) == 0 { + continue + } + + resolved := make([]string, 0, len(m.Media)) + for _, ref := range m.Media { + if !strings.HasPrefix(ref, "media://") { + resolved = append(resolved, ref) + continue + } + + localPath, meta, err := store.ResolveWithMeta(ref) + if err != nil { + logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{ + "ref": ref, + "error": err.Error(), + }) + continue + } + + info, err := os.Stat(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to stat media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + if info.Size() > int64(maxSize) { + logger.WarnCF("agent", "Media file too large, skipping", map[string]any{ + "path": localPath, + "size": info.Size(), + "max_size": maxSize, + }) + continue + } + + // Determine MIME type: prefer metadata, fallback to magic-bytes detection + mime := meta.ContentType + if mime == "" { + kind, ftErr := filetype.MatchFile(localPath) + if ftErr != nil || kind == filetype.Unknown { + logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{ + "path": localPath, + }) + continue + } + mime = kind.MIME.Value + } + + // Streaming base64: open file → base64 encoder → buffer + // Peak memory: ~1.33x file size (buffer only, no raw bytes copy) + f, err := os.Open(localPath) + if err != nil { + logger.WarnCF("agent", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + prefix := "data:" + mime + 
";base64," + encodedLen := base64.StdEncoding.EncodedLen(int(info.Size())) + var buf bytes.Buffer + buf.Grow(len(prefix) + encodedLen) + buf.WriteString(prefix) + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + if _, err := io.Copy(encoder, f); err != nil { + f.Close() + logger.WarnCF("agent", "Failed to encode media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + encoder.Close() + f.Close() + + resolved = append(resolved, buf.String()) + } + + result[i].Media = resolved + } + + return result +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 3565314fe..023286f02 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -6,12 +6,14 @@ import ( "os" "path/filepath" "slices" + "strings" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -808,3 +810,142 @@ func TestHandleReasoning(t *testing.T) { } }) } + +func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + // Create a minimal valid PNG (8-byte header is enough for filetype detection) + pngPath := filepath.Join(dir, "test.png") + // PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB + 0x00, 0x00, 0x00, // no interlace + 0x90, 0x77, 0x53, 0xDE, // CRC + } + if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil { + t.Fatal(err) + } + ref, err := store.Store(pngPath, media.MediaMeta{}, "test") + if err != nil { + t.Fatal(err) + } + + messages := []providers.Message{ + {Role: "user", Content: "describe this", Media: 
[]string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") { + t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40]) + } +} + +func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + bigPath := filepath.Join(dir, "big.png") + // Write PNG header + padding to exceed limit + data := make([]byte, 1024+1) // 1KB + 1 byte + copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}) + if err := os.WriteFile(bigPath, data, 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(bigPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + // Use a tiny limit (1KB) so the file is oversized + result := resolveMediaRefs(messages, store, 1024) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + txtPath := filepath.Join(dir, "readme.txt") + if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil { + t.Fatal(err) + } + ref, _ := store.Store(txtPath, media.MediaMeta{}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 0 { + t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media)) + } +} + +func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) { + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}}, + } + result := resolveMediaRefs(messages, 
nil, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" { + t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media) + } +} + +func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + pngPath := filepath.Join(dir, "test.png") + pngHeader := []byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, + 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, + 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, + } + os.WriteFile(pngPath, pngHeader, 0o644) + ref, _ := store.Store(pngPath, media.MediaMeta{}, "test") + + original := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + originalRef := original[0].Media[0] + + resolveMediaRefs(original, store, config.DefaultMaxMediaSize) + + if original[0].Media[0] != originalRef { + t.Fatal("resolveMediaRefs mutated original message slice") + } +} + +func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) { + store := media.NewFileMediaStore() + dir := t.TempDir() + + // File with JPEG content but stored with explicit content type + jpegPath := filepath.Join(dir, "photo") + jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes + os.WriteFile(jpegPath, jpegHeader, 0o644) + ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test") + + messages := []providers.Message{ + {Role: "user", Content: "hi", Media: []string{ref}}, + } + result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize) + + if len(result[0].Media) != 1 { + t.Fatalf("expected 1 media, got %d", len(result[0].Media)) + } + if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") { + t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30]) + } +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go index e8a057741..fbe085b73 100644 --- 
a/pkg/channels/feishu/common.go +++ b/pkg/channels/feishu/common.go @@ -1,5 +1,16 @@ package feishu +import ( + "encoding/json" + "regexp" + "strings" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions. +var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`) + // stringValue safely dereferences a *string pointer. func stringValue(v *string) string { if v == nil { @@ -7,3 +18,69 @@ func stringValue(v *string) string { } return *v } + +// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content. +// JSON 2.0 cards support full CommonMark standard markdown syntax. +func buildMarkdownCard(content string) (string, error) { + card := map[string]any{ + "schema": "2.0", + "body": map[string]any{ + "elements": []map[string]any{ + { + "tag": "markdown", + "content": content, + }, + }, + }, + } + data, err := json.Marshal(card) + if err != nil { + return "", err + } + return string(data), nil +} + +// extractJSONStringField unmarshals content as JSON and returns the value of the given string field. +// Returns "" if the content is invalid JSON or the field is missing/empty. +func extractJSONStringField(content, field string) string { + var m map[string]json.RawMessage + if err := json.Unmarshal([]byte(content), &m); err != nil { + return "" + } + raw, ok := m[field] + if !ok { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "" + } + return s +} + +// extractImageKey extracts the image_key from a Feishu image message content JSON. +// Format: {"image_key": "img_xxx"} +func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") } + +// extractFileKey extracts the file_key from a Feishu file/audio message content JSON. 
+// Format: {"file_key": "file_xxx", "file_name": "...", ...} +func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") } + +// extractFileName extracts the file_name from a Feishu file message content JSON. +func extractFileName(content string) string { return extractJSONStringField(content, "file_name") } + +// stripMentionPlaceholders removes @_user_N placeholders from the text content. +// These are inserted by Feishu when users @mention someone in a message. +func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string { + if len(mentions) == 0 { + return content + } + for _, m := range mentions { + if m.Key != nil && *m.Key != "" { + content = strings.ReplaceAll(content, *m.Key, "") + } + } + // Also clean up any remaining @_user_N patterns + content = mentionPlaceholderRegex.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/feishu/common_test.go b/pkg/channels/feishu/common_test.go new file mode 100644 index 000000000..fefc9f7c1 --- /dev/null +++ b/pkg/channels/feishu/common_test.go @@ -0,0 +1,292 @@ +package feishu + +import ( + "encoding/json" + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractJSONStringField(t *testing.T) { + tests := []struct { + name string + content string + field string + want string + }{ + { + name: "valid field", + content: `{"image_key": "img_v2_xxx"}`, + field: "image_key", + want: "img_v2_xxx", + }, + { + name: "missing field", + content: `{"image_key": "img_v2_xxx"}`, + field: "file_key", + want: "", + }, + { + name: "invalid JSON", + content: `not json at all`, + field: "image_key", + want: "", + }, + { + name: "empty content", + content: "", + field: "image_key", + want: "", + }, + { + name: "non-string field value", + content: `{"count": 42}`, + field: "count", + want: "", + }, + { + name: "empty string value", + content: `{"image_key": ""}`, + field: "image_key", + want: 
"", + }, + { + name: "multiple fields", + content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`, + field: "file_name", + want: "test.pdf", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractJSONStringField(tt.content, tt.field) + if got != tt.want { + t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want) + } + }) + } +} + +func TestExtractImageKey(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"image_key": "img_v2_abc123"}`, + want: "img_v2_abc123", + }, + { + name: "missing key", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{broken`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractImageKey(tt.content) + if got != tt.want { + t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileKey(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`, + want: "file_v2_abc123", + }, + { + name: "missing key", + content: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `not json`, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileKey(tt.content) + if got != tt.want { + t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestExtractFileName(t *testing.T) { + tests := []struct { + name string + content string + want string + }{ + { + name: "normal", + content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + }, + { + name: "missing name", + content: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "malformed JSON", + content: `{bad`, + want: "", + }, + } + + 
for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractFileName(tt.content) + if got != tt.want { + t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} + +func TestBuildMarkdownCard(t *testing.T) { + tests := []struct { + name string + content string + }{ + { + name: "normal content", + content: "Hello **world**", + }, + { + name: "empty content", + content: "", + }, + { + name: "special characters", + content: `Code: "foo" & 'baz'`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildMarkdownCard(tt.content) + if err != nil { + t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err) + } + + // Verify valid JSON + var parsed map[string]any + if err := json.Unmarshal([]byte(result), &parsed); err != nil { + t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err) + } + + // Verify schema + if parsed["schema"] != "2.0" { + t.Errorf("schema = %v, want %q", parsed["schema"], "2.0") + } + + // Verify body.elements[0].content == input + body, ok := parsed["body"].(map[string]any) + if !ok { + t.Fatal("missing body in card JSON") + } + elements, ok := body["elements"].([]any) + if !ok || len(elements) == 0 { + t.Fatal("missing or empty elements in card JSON") + } + elem, ok := elements[0].(map[string]any) + if !ok { + t.Fatal("first element is not an object") + } + if elem["tag"] != "markdown" { + t.Errorf("tag = %v, want %q", elem["tag"], "markdown") + } + if elem["content"] != tt.content { + t.Errorf("content = %v, want %q", elem["content"], tt.content) + } + }) + } +} + +func TestStripMentionPlaceholders(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + content string + mentions []*larkim.MentionEvent + want string + }{ + { + name: "no mentions", + content: "Hello world", + mentions: nil, + want: "Hello world", + }, + { + name: "single mention", + content: "@_user_1 
hello", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + }, + want: "hello", + }, + { + name: "multiple mentions", + content: "@_user_1 @_user_2 hey", + mentions: []*larkim.MentionEvent{ + {Key: strPtr("@_user_1")}, + {Key: strPtr("@_user_2")}, + }, + want: "hey", + }, + { + name: "empty content", + content: "", + mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}}, + want: "", + }, + { + name: "empty mentions slice", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{}, + want: "@_user_1 test", + }, + { + name: "mention with nil key", + content: "@_user_1 test", + mentions: []*larkim.MentionEvent{ + {Key: nil}, + }, + want: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripMentionPlaceholders(tt.content, tt.mentions) + if got != tt.want { + t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go index d0ec758c6..f5e3aa224 100644 --- a/pkg/channels/feishu/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -16,6 +16,8 @@ type FeishuChannel struct { *channels.BaseChannel } +var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures") + // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { return nil, errors.New( @@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan // Start is a stub method to satisfy the Channel interface func (c *FeishuChannel) Start(ctx context.Context) error { - return nil + return errUnsupported } // Stop is a stub method to satisfy the Channel interface func (c *FeishuChannel) Stop(ctx context.Context) error { - return nil + return errUnsupported } // Send is a stub method to satisfy the Channel interface func (c 
*FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - return errors.New("feishu channel is not supported on 32-bit architectures") + return errUnsupported +} + +// EditMessage is a stub method to satisfy MessageEditor +func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return errUnsupported +} + +// SendPlaceholder is a stub method to satisfy PlaceholderCapable +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + return "", errUnsupported +} + +// ReactToMessage is a stub method to satisfy ReactionCapable +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + return func() {}, errUnsupported +} + +// SendMedia is a stub method to satisfy MediaSender +func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + return errUnsupported } diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 1db1bf669..00f73064d 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -6,10 +6,15 @@ import ( "context" "encoding/json" "fmt" + "io" + "net/http" + "os" + "path/filepath" "sync" - "time" + "sync/atomic" lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" larkws "github.com/larksuite/oapi-sdk-go/v3/ws" @@ -19,6 +24,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -28,6 +34,8 @@ type FeishuChannel struct { client *lark.Client wsClient *larkws.Client + botOpenID atomic.Value // stores string; populated lazily for @mention detection + mu sync.Mutex cancel context.CancelFunc } 
@@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) - return &FeishuChannel{ + ch := &FeishuChannel{ BaseChannel: base, config: cfg, client: lark.NewClient(cfg.AppID, cfg.AppSecret), - }, nil + } + ch.SetOwner(ch) + return ch, nil } func (c *FeishuChannel) Start(ctx context.Context) error { @@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error { return fmt.Errorf("feishu app_id or app_secret is empty") } + // Fetch bot open_id via API for reliable @mention detection. + if err := c.fetchBotOpenID(ctx); err != nil { + logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{ + "error": err.Error(), + }) + } + dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). OnP2MessageReceiveV1(c.handleMessageReceive) @@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { return nil } +// Send sends a message using Interactive Card format for markdown rendering. func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning } if msg.ChatID == "" { - return fmt.Errorf("chat ID is empty") + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) } - payload, err := json.Marshal(map[string]string{"text": msg.Content}) + // Build interactive card with markdown content + cardContent, err := buildMarkdownCard(msg.Content) if err != nil { - return fmt.Errorf("failed to marshal feishu content: %w", err) + return fmt.Errorf("feishu send: card build failed: %w", err) + } + return c.sendCard(ctx, msg.ChatID, cardContent) +} + +// EditMessage implements channels.MessageEditor. +// Uses Message.Patch to update an interactive card message. 
+func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error { + cardContent, err := buildMarkdownCard(content) + if err != nil { + return fmt.Errorf("feishu edit: card build failed: %w", err) + } + + req := larkim.NewPatchMessageReqBuilder(). + MessageId(messageID). + Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()). + Build() + + resp, err := c.client.Im.V1.Message.Patch(ctx, req) + if err != nil { + return fmt.Errorf("feishu edit: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// SendPlaceholder implements channels.PlaceholderCapable. +// Sends an interactive card with placeholder text and returns its message ID. +func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) { + if !c.config.Placeholder.Enabled { + logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{ + "chat_id": chatID, + }) + return "", nil + } + + text := c.config.Placeholder.Text + if text == "" { + text = "Thinking..." + } + + cardContent, err := buildMarkdownCard(text) + if err != nil { + return "", fmt.Errorf("feishu placeholder: card build failed: %w", err) } req := larkim.NewCreateMessageReqBuilder(). ReceiveIdType(larkim.ReceiveIdTypeChatId). Body(larkim.NewCreateMessageReqBodyBuilder(). - ReceiveId(msg.ChatID). - MsgType(larkim.MsgTypeText). - Content(string(payload)). - Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). + ReceiveId(chatID). + MsgType(larkim.MsgTypeInteractive). + Content(cardContent). Build()). 
Build() resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("feishu send: %w", channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder send: %w", err) } - if !resp.Success() { - return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg) } - logger.DebugCF("feishu", "Feishu message sent", map[string]any{ - "chat_id": msg.ChatID, - }) + if resp.Data != nil && resp.Data.MessageId != nil { + return *resp.Data.MessageId, nil + } + return "", nil +} + +// ReactToMessage implements channels.ReactionCapable. +// Adds an "Pin" reaction and returns an undo function to remove it. +func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) { + req := larkim.NewCreateMessageReactionReqBuilder(). + MessageId(messageID). + Body(larkim.NewCreateMessageReactionReqBodyBuilder(). + ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()). + Build()). + Build() + + resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{ + "message_id": messageID, + "error": err.Error(), + }) + return func() {}, fmt.Errorf("feishu react: %w", err) + } + if !resp.Success() { + logger.ErrorCF("feishu", "Reaction API error", map[string]any{ + "message_id": messageID, + "code": resp.Code, + "msg": resp.Msg, + }) + return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + + var reactionID string + if resp.Data != nil && resp.Data.ReactionId != nil { + reactionID = *resp.Data.ReactionId + } + if reactionID == "" { + return func() {}, nil + } + + var undone atomic.Bool + undo := func() { + if !undone.CompareAndSwap(false, true) { + return + } + delReq := larkim.NewDeleteMessageReactionReqBuilder(). + MessageId(messageID). + ReactionId(reactionID). 
+ Build() + _, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq) + } + return undo, nil +} + +// SendMedia implements channels.MediaSender. +// Uploads images/files via Feishu API then sends as messages. +func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + if msg.ChatID == "" { + return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil { + return err + } + } return nil } +// sendMediaPart resolves and sends a single media part. +func (c *FeishuChannel) sendMediaPart( + ctx context.Context, + chatID string, + part bus.MediaPart, + store media.MediaStore, +) error { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + return nil // skip this part + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + return nil // skip this part + } + defer file.Close() + + switch part.Type { + case "image": + err = c.sendImage(ctx, chatID, file) + default: + filename := part.Filename + if filename == "" { + filename = "file" + } + err = c.sendFile(ctx, chatID, file, filename, part.Type) + } + + if err != nil { + logger.ErrorCF("feishu", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("feishu send media: %w", channels.ErrTemporary) + } + return nil +} + +// --- Inbound message handling --- + func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { 
if event == nil || event.Event == nil || event.Event.Message == nil { return nil @@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. senderID = "unknown" } - content := extractFeishuMessageContent(message) + messageType := stringValue(message.MessageType) + messageID := stringValue(message.MessageId) + rawContent := stringValue(message.Content) + + // Check allowlist early to avoid downloading media for rejected senders. + // BaseChannel.HandleMessage will check again, but this avoids wasted network I/O. + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + if !c.IsAllowedSender(senderInfo) { + return nil + } + + // Extract content based on message type + content := extractContent(messageType, rawContent) + + // Handle media messages (download and store) + var mediaRefs []string + if store := c.GetMediaStore(); store != nil && messageID != "" { + mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store) + } + + // Append media tags to content (like Telegram does) + content = appendMediaTags(content, messageType, mediaRefs) + if content == "" { content = "[empty message]" } metadata := map[string]string{} - messageID := "" - if mid := stringValue(message.MessageId); mid != "" { - messageID = mid + if messageID != "" { + metadata["message_id"] = messageID } - if messageType := stringValue(message.MessageType); messageType != "" { + if messageType != "" { metadata["message_type"] = messageType } - if chatType := stringValue(message.ChatType); chatType != "" { + chatType := stringValue(message.ChatType) + if chatType != "" { metadata["chat_type"] = chatType } if sender != nil && sender.TenantKey != nil { metadata["tenant_key"] = *sender.TenantKey } - chatType := stringValue(message.ChatType) var peer bus.Peer if chatType == "p2p" { peer = bus.Peer{Kind: "direct", ID: senderID} } else { peer = 
bus.Peer{Kind: "group", ID: chatID} + + // Check if bot was mentioned + isMentioned := c.isBotMentioned(message) + + // Strip mention placeholders from content before group trigger check + if len(message.Mentions) > 0 { + content = stripMentionPlaceholders(content, message.Mentions) + } + // In group chats, apply unified group trigger filtering - respond, cleaned := c.ShouldRespondInGroup(false, content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) if !respond { return nil } @@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. } logger.InfoCF("feishu", "Feishu message received", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 80), + "sender_id": senderID, + "chat_id": chatID, + "message_id": messageID, + "preview": utils.Truncate(content, 80), }) - senderInfo := bus.SenderInfo{ - Platform: "feishu", - PlatformID: senderID, - CanonicalID: identity.BuildCanonicalID("feishu", senderID), + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo) + return nil +} + +// --- Internal helpers --- + +// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id. 
+func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error { + resp, err := c.client.Do(ctx, &larkcore.ApiReq{ + HttpMethod: http.MethodGet, + ApiPath: "/open-apis/bot/v3/info", + SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant}, + }) + if err != nil { + return fmt.Errorf("bot info request: %w", err) } - if !c.IsAllowedSender(senderInfo) { - return nil + var result struct { + Code int `json:"code"` + Bot struct { + OpenID string `json:"open_id"` + } `json:"bot"` + } + if err := json.Unmarshal(resp.RawBody, &result); err != nil { + return fmt.Errorf("bot info parse: %w", err) + } + if result.Code != 0 { + return fmt.Errorf("bot info api error (code=%d)", result.Code) + } + if result.Bot.OpenID == "" { + return fmt.Errorf("bot info: empty open_id") } - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) + c.botOpenID.Store(result.Bot.OpenID) + logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{ + "open_id": result.Bot.OpenID, + }) + return nil +} + +// isBotMentioned checks if the bot was @mentioned in the message. +func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool { + if message.Mentions == nil { + return false + } + + knownID, _ := c.botOpenID.Load().(string) + if knownID == "" { + logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil) + return false + } + + for _, m := range message.Mentions { + if m.Id == nil { + continue + } + if m.Id.OpenId != nil && *m.Id.OpenId == knownID { + return true + } + } + return false +} + +// extractContent extracts text content from different message types. 
+func extractContent(messageType, rawContent string) string { + if rawContent == "" { + return "" + } + + switch messageType { + case larkim.MsgTypeText: + var textPayload struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil { + return textPayload.Text + } + return rawContent + + case larkim.MsgTypePost: + // Pass raw JSON to LLM — structured rich text is more informative than flattened plain text + return rawContent + + case larkim.MsgTypeImage: + // Image messages don't have text content + return "" + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + // File/audio/video messages may have a filename + name := extractFileName(rawContent) + if name != "" { + return name + } + return "" + + default: + return rawContent + } +} + +// downloadInboundMedia downloads media from inbound messages and stores in MediaStore. +func (c *FeishuChannel) downloadInboundMedia( + ctx context.Context, + chatID, messageID, messageType, rawContent string, + store media.MediaStore, +) []string { + var refs []string + scope := channels.BuildMediaScope("feishu", chatID, messageID) + + switch messageType { + case larkim.MsgTypeImage: + imageKey := extractImageKey(rawContent) + if imageKey == "" { + return nil + } + ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope) + if ref != "" { + refs = append(refs, ref) + } + + case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia: + fileKey := extractFileKey(rawContent) + if fileKey == "" { + return nil + } + // Derive a fallback extension from the message type. 
+ var ext string + switch messageType { + case larkim.MsgTypeAudio: + ext = ".ogg" + case larkim.MsgTypeMedia: + ext = ".mp4" + default: + ext = "" // generic file — rely on resp.FileName + } + ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope) + if ref != "" { + refs = append(refs, ref) + } + } + + return refs +} + +// downloadResource downloads a message resource (image/file) from Feishu, +// writes it to the project media directory, and stores the reference in MediaStore. +// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension. +func (c *FeishuChannel) downloadResource( + ctx context.Context, + messageID, fileKey, resourceType, fallbackExt string, + store media.MediaStore, + scope string, +) string { + req := larkim.NewGetMessageResourceReqBuilder(). + MessageId(messageID). + FileKey(fileKey). + Type(resourceType). + Build() + + resp, err := c.client.Im.V1.MessageResource.Get(ctx, req) + if err != nil { + logger.ErrorCF("feishu", "Failed to download resource", map[string]any{ + "message_id": messageID, + "file_key": fileKey, + "error": err.Error(), + }) + return "" + } + if !resp.Success() { + logger.ErrorCF("feishu", "Resource download api error", map[string]any{ + "code": resp.Code, + "msg": resp.Msg, + }) + return "" + } + + if resp.File == nil { + return "" + } + // Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body). + if closer, ok := resp.File.(io.Closer); ok { + defer closer.Close() + } + + filename := resp.FileName + if filename == "" { + filename = fileKey + } + // If filename still has no extension, append the fallback (like Telegram's ext parameter). + if filepath.Ext(filename) == "" && fallbackExt != "" { + filename += fallbackExt + } + + // Write to the shared picoclaw_media directory using a unique name to avoid collisions. 
+ mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil { + logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{ + "error": mkdirErr.Error(), + }) + return "" + } + ext := filepath.Ext(filename) + localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext)) + + out, err := os.Create(localPath) + if err != nil { + logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{ + "error": err.Error(), + }) + return "" + } + + if _, copyErr := io.Copy(out, resp.File); copyErr != nil { + out.Close() + os.Remove(localPath) + logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{ + "error": copyErr.Error(), + }) + return "" + } + out.Close() + + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "feishu", + }, scope) + if err != nil { + logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{ + "file_key": fileKey, + "error": err.Error(), + }) + os.Remove(localPath) + return "" + } + + return ref +} + +// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]"). +func appendMediaTags(content, messageType string, mediaRefs []string) string { + if len(mediaRefs) == 0 { + return content + } + + var tag string + switch messageType { + case larkim.MsgTypeImage: + tag = "[image: photo]" + case larkim.MsgTypeAudio: + tag = "[audio]" + case larkim.MsgTypeMedia: + tag = "[video]" + case larkim.MsgTypeFile: + tag = "[file]" + default: + tag = "[attachment]" + } + + if content == "" { + return tag + } + return content + " " + tag +} + +// sendCard sends an interactive card message to a chat. +func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error { + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). 
+ MsgType(larkim.MsgTypeInteractive). + Content(cardContent). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu send card: %w", channels.ErrTemporary) + } + + if !resp.Success() { + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) + } + + logger.DebugCF("feishu", "Feishu card message sent", map[string]any{ + "chat_id": chatID, + }) + + return nil +} + +// sendImage uploads an image and sends it as a message. +func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error { + // Upload image to get image_key + uploadReq := larkim.NewCreateImageReqBuilder(). + Body(larkim.NewCreateImageReqBodyBuilder(). + ImageType("message"). + Image(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu image upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil { + return fmt.Errorf("feishu image upload: no image_key returned") + } + + imageKey := *uploadResp.Data.ImageKey + + // Send image message + content, _ := json.Marshal(map[string]string{"image_key": imageKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeImage). + Content(string(content)). + Build()). + Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu image send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } + return nil +} + +// sendFile uploads a file and sends it as a message. 
+func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error { + // Map part type to Feishu file type + feishuFileType := "stream" + switch fileType { + case "audio": + feishuFileType = "opus" + case "video": + feishuFileType = "mp4" + } + + // Upload file to get file_key + uploadReq := larkim.NewCreateFileReqBuilder(). + Body(larkim.NewCreateFileReqBodyBuilder(). + FileType(feishuFileType). + FileName(filename). + File(file). + Build()). + Build() + + uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq) + if err != nil { + return fmt.Errorf("feishu file upload: %w", err) + } + if !uploadResp.Success() { + return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg) + } + if uploadResp.Data == nil || uploadResp.Data.FileKey == nil { + return fmt.Errorf("feishu file upload: no file_key returned") + } + + fileKey := *uploadResp.Data.FileKey + + // Send file message + content, _ := json.Marshal(map[string]string{"file_key": fileKey}) + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(chatID). + MsgType(larkim.MsgTypeFile). + Content(string(content)). + Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("feishu file send: %w", err) + } + if !resp.Success() { + return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg) + } return nil } @@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string { return "" } - -func extractFeishuMessageContent(message *larkim.EventMessage) string { - if message == nil || message.Content == nil || *message.Content == "" { - return "" - } - - if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { - var textPayload struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { - return textPayload.Text - } - } - - return *message.Content -} diff --git a/pkg/channels/feishu/feishu_64_test.go b/pkg/channels/feishu/feishu_64_test.go new file mode 100644 index 000000000..dc3eab2e7 --- /dev/null +++ b/pkg/channels/feishu/feishu_64_test.go @@ -0,0 +1,256 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + +package feishu + +import ( + "testing" + + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" +) + +func TestExtractContent(t *testing.T) { + tests := []struct { + name string + messageType string + rawContent string + want string + }{ + { + name: "text message", + messageType: "text", + rawContent: `{"text": "hello world"}`, + want: "hello world", + }, + { + name: "text message invalid JSON", + messageType: "text", + rawContent: `not json`, + want: "not json", + }, + { + name: "post message returns raw JSON", + messageType: "post", + rawContent: `{"title": "test post"}`, + want: `{"title": "test post"}`, + }, + { + name: "image message returns empty", + messageType: "image", + rawContent: `{"image_key": "img_xxx"}`, + want: "", + }, + { + name: "file message with filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`, + want: "report.pdf", + 
}, + { + name: "file message without filename", + messageType: "file", + rawContent: `{"file_key": "file_xxx"}`, + want: "", + }, + { + name: "audio message with filename", + messageType: "audio", + rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`, + want: "recording.ogg", + }, + { + name: "media message with filename", + messageType: "media", + rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`, + want: "video.mp4", + }, + { + name: "unknown message type returns raw", + messageType: "sticker", + rawContent: `{"sticker_id": "sticker_xxx"}`, + want: `{"sticker_id": "sticker_xxx"}`, + }, + { + name: "empty raw content", + messageType: "text", + rawContent: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractContent(tt.messageType, tt.rawContent) + if got != tt.want { + t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want) + } + }) + } +} + +func TestAppendMediaTags(t *testing.T) { + tests := []struct { + name string + content string + messageType string + mediaRefs []string + want string + }{ + { + name: "no refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: nil, + want: "hello", + }, + { + name: "empty refs returns content unchanged", + content: "hello", + messageType: "image", + mediaRefs: []string{}, + want: "hello", + }, + { + name: "image with content", + content: "check this", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "check this [image: photo]", + }, + { + name: "image empty content", + content: "", + messageType: "image", + mediaRefs: []string{"ref1"}, + want: "[image: photo]", + }, + { + name: "audio", + content: "listen", + messageType: "audio", + mediaRefs: []string{"ref1"}, + want: "listen [audio]", + }, + { + name: "media/video", + content: "watch", + messageType: "media", + mediaRefs: []string{"ref1"}, + want: "watch [video]", + }, + { + name: "file", + content: 
"report.pdf", + messageType: "file", + mediaRefs: []string{"ref1"}, + want: "report.pdf [file]", + }, + { + name: "unknown type", + content: "something", + messageType: "sticker", + mediaRefs: []string{"ref1"}, + want: "something [attachment]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs) + if got != tt.want { + t.Errorf( + "appendMediaTags(%q, %q, %v) = %q, want %q", + tt.content, + tt.messageType, + tt.mediaRefs, + got, + tt.want, + ) + } + }) + } +} + +func TestExtractFeishuSenderID(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + sender *larkim.EventSender + want string + }{ + { + name: "nil sender", + sender: nil, + want: "", + }, + { + name: "nil sender ID", + sender: &larkim.EventSender{SenderId: nil}, + want: "", + }, + { + name: "userId preferred", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr("u_abc123"), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "u_abc123", + }, + { + name: "openId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr("ou_def456"), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "ou_def456", + }, + { + name: "unionId fallback", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr("on_ghi789"), + }, + }, + want: "on_ghi789", + }, + { + name: "all empty strings", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: strPtr(""), + OpenId: strPtr(""), + UnionId: strPtr(""), + }, + }, + want: "", + }, + { + name: "nil userId pointer falls through", + sender: &larkim.EventSender{ + SenderId: &larkim.UserId{ + UserId: nil, + OpenId: strPtr("ou_def456"), + UnionId: nil, + }, + }, + want: "ou_def456", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := 
extractFeishuSenderID(tt.sender) + if got != tt.want { + t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index a11cf53b8..f328f32b8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,12 +7,12 @@ import ( "net/url" "os" "regexp" + "slices" "strconv" "strings" "time" "github.com/mymmrac/telego" - "github.com/mymmrac/telego/telegohandler" th "github.com/mymmrac/telego/telegohandler" tu "github.com/mymmrac/telego/telegoutil" @@ -41,7 +41,7 @@ var ( type TelegramChannel struct { *channels.BaseChannel bot *telego.Bot - bh *telegohandler.BotHandler + bh *th.BotHandler commands TelegramCommander config *config.Config chatIDs map[string]int64 @@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann })) } + if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" { + opts = append(opts, telego.WithAPIServer(baseURL)) + } + bot, err := telego.NewBot(telegramCfg.Token, opts...) 
if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) @@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error { c.ctx, c.cancel = context.WithCancel(ctx) + if err := c.initBotCommands(c.ctx); err != nil { + logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{ + "error": err.Error(), + }) + } + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) @@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start long polling: %w", err) } - bh, err := telegohandler.NewBotHandler(c.bot, updates) + bh, err := th.NewBotHandler(c.bot, updates) if err != nil { c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } c.bh = bh - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - c.commands.Help(ctx, message) - return nil - }, th.CommandEqual("help")) bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.commands.Start(ctx, message) }, th.CommandEqual("start")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Help(ctx, message) + }, th.CommandEqual("help")) bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { return c.commands.Show(ctx, message) @@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error { "username": c.bot.Username(), }) - go bh.Start() + go func() { + if err = bh.Start(); err != nil { + logger.ErrorCF("telegram", "Bot handler failed", map[string]any{ + "error": err.Error(), + }) + } + }() return nil } @@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { // Stop the bot handler if c.bh != nil { - c.bh.Stop() + _ = c.bh.StopWithContext(ctx) } // Cancel our context (stops long polling) @@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { return nil } +func (c *TelegramChannel) initBotCommands(ctx 
context.Context) error { + currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{ + Scope: tu.ScopeDefault(), + }) + if err != nil { + return fmt.Errorf("get commands: %w", err) + } + + commands := []telego.BotCommand{ + { + Command: "start", + Description: "Start the bot", + }, + { + Command: "help", + Description: "Show a help message", + }, + { + Command: "show", + Description: "Show current configuration", + }, + { + Command: "list", + Description: "List available options", + }, + } + + // Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed + if !slices.Equal(currentCommands, commands) { + logger.InfoC("telegram", "Updating bot commands") + + err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{ + Commands: commands, + Scope: tu.ScopeDefault(), + }) + if err != nil { + return fmt.Errorf("set commands: %w", err) + } + } else { + logger.DebugC("telegram", "Bot commands are up to date") + } + + return nil +} + func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return channels.ErrNotRunning diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index b79340315..717815b9f 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -38,8 +38,7 @@ type WeComAppChannel struct { tokenMu sync.RWMutex ctx context.Context cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex + processedMsgs *MessageDeduplicator } // WeComXMLMessage represents the XML message structure from WeCom @@ -144,7 +143,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, - processedMsgs: make(map[string]bool), + processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), }, nil } @@ -607,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx 
context.Context, msg WeComXMLMessag // Message deduplication: Use msg_id to prevent duplicate processing // As per WeCom documentation, use msg_id for deduplication msgID := fmt.Sprintf("%d", msg.MsgId) - c.msgMu.Lock() - if c.processedMsgs[msgID] { - c.msgMu.Unlock() + if !c.processedMsgs.MarkMessageProcessed(msgID) { logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return } - c.processedMsgs[msgID] = true - // Clean up old messages while still holding the lock to avoid a data race - // on len(). Reset the map but re-insert the current msgID so it remains - // deduplicated. - if len(c.processedMsgs) > 1000 { - c.processedMsgs = make(map[string]bool) - c.processedMsgs[msgID] = true - } - c.msgMu.Unlock() senderID := msg.FromUserName chatID := senderID // WeCom App uses user ID as chat ID for direct messages diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 0d0426c0d..9126a847d 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -9,7 +9,6 @@ import ( "io" "net/http" "strings" - "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" @@ -28,8 +27,7 @@ type WeComBotChannel struct { client *http.Client ctx context.Context cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex + processedMsgs *MessageDeduplicator } // WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) @@ -108,7 +106,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, - processedMsgs: make(map[string]bool), + processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages), }, nil } @@ -330,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag // Message deduplication: Use msg_id to prevent duplicate processing msgID := msg.MsgID - c.msgMu.Lock() - if c.processedMsgs[msgID] { - 
c.msgMu.Unlock() + if !c.processedMsgs.MarkMessageProcessed(msgID) { logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return } - c.processedMsgs[msgID] = true - // Clean up old messages while still holding the lock to avoid a data race - // on len(). Reset the map but re-insert the current msgID so it remains - // deduplicated. - if len(c.processedMsgs) > 1000 { - c.processedMsgs = make(map[string]bool) - c.processedMsgs[msgID] = true - } - c.msgMu.Unlock() senderID := msg.From.UserID diff --git a/pkg/channels/wecom/dedupe.go b/pkg/channels/wecom/dedupe.go new file mode 100644 index 000000000..865be668e --- /dev/null +++ b/pkg/channels/wecom/dedupe.go @@ -0,0 +1,54 @@ +package wecom + +import "sync" + +const wecomMaxProcessedMessages = 1000 + +// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer) +// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest +// messages without causing "amnesia cliffs" when the limit is reached. +type MessageDeduplicator struct { + mu sync.Mutex + msgs map[string]bool + ring []string + idx int + max int +} + +// NewMessageDeduplicator creates a new deduplicator with the specified capacity. +func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator { + if maxEntries <= 0 { + maxEntries = wecomMaxProcessedMessages + } + return &MessageDeduplicator{ + msgs: make(map[string]bool, maxEntries), + ring: make([]string, maxEntries), + max: maxEntries, + } +} + +// MarkMessageProcessed marks msgID as processed and returns false for duplicates. +func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + // 1. Check for duplicate + if d.msgs[msgID] { + return false + } + + // 2. Evict the oldest message at our current ring position (if any) + oldestID := d.ring[d.idx] + if oldestID != "" { + delete(d.msgs, oldestID) + } + + // 3. 
Store the new message + d.msgs[msgID] = true + d.ring[d.idx] = msgID + + // 4. Advance the circle queue index + d.idx = (d.idx + 1) % d.max + + return true +} diff --git a/pkg/channels/wecom/dedupe_test.go b/pkg/channels/wecom/dedupe_test.go new file mode 100644 index 000000000..10dff4cfe --- /dev/null +++ b/pkg/channels/wecom/dedupe_test.go @@ -0,0 +1,83 @@ +package wecom + +import ( + "sync" + "testing" +) + +func TestMessageDeduplicator_DuplicateDetection(t *testing.T) { + d := NewMessageDeduplicator(wecomMaxProcessedMessages) + + if ok := d.MarkMessageProcessed("msg-1"); !ok { + t.Fatalf("first message should be accepted") + } + + if ok := d.MarkMessageProcessed("msg-1"); ok { + t.Fatalf("duplicate message should be rejected") + } +} + +func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) { + d := NewMessageDeduplicator(wecomMaxProcessedMessages) + + const goroutines = 64 + var wg sync.WaitGroup + wg.Add(goroutines) + + results := make(chan bool, goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + results <- d.MarkMessageProcessed("msg-concurrent") + }() + } + + wg.Wait() + close(results) + + successes := 0 + for ok := range results { + if ok { + successes++ + } + } + + if successes != 1 { + t.Fatalf("expected exactly 1 successful mark, got %d", successes) + } +} + +func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) { + // Create a deduplicator with a very small capacity to test eviction easily. + capacity := 3 + d := NewMessageDeduplicator(capacity) + + // Fill the queue. + d.MarkMessageProcessed("msg-1") + d.MarkMessageProcessed("msg-2") + d.MarkMessageProcessed("msg-3") + + // At this point, the queue is full. msg-1 is the oldest. + if len(d.msgs) != 3 { + t.Fatalf("expected map size to be 3, got %d", len(d.msgs)) + } + + // This should evict msg-1 and add msg-4. 
+ if ok := d.MarkMessageProcessed("msg-4"); !ok { + t.Fatalf("msg-4 should be accepted") + } + + if len(d.msgs) != 3 { + t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs)) + } + + // msg-1 should now be forgotten (evicted). + if ok := d.MarkMessageProcessed("msg-1"); !ok { + t.Fatalf("msg-1 should be accepted again because it was evicted") + } + + // msg-2 should have been evicted when we added msg-1 back. + if ok := d.MarkMessageProcessed("msg-2"); !ok { + t.Fatalf("msg-2 should be accepted again because it was evicted") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index e7e14323c..e801b44c9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -180,6 +180,16 @@ type AgentDefaults struct { MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"` +} + +const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB + +func (d *AgentDefaults) GetMaxMediaSize() int { + if d.MaxMediaSize > 0 { + return d.MaxMediaSize + } + return DefaultMaxMediaSize } // GetModelName returns the effective model name for the agent defaults. 
@@ -237,6 +247,7 @@ type WhatsAppConfig struct { type TelegramConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"` Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` @@ -253,6 +264,7 @@ type FeishuConfig struct { VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } @@ -399,6 +411,7 @@ type DevicesConfig struct { type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` OpenAI OpenAIProviderConfig `json:"openai"` + LiteLLM ProviderConfig `json:"litellm"` OpenRouter ProviderConfig `json:"openrouter"` Groq ProviderConfig `json:"groq"` Zhipu ProviderConfig `json:"zhipu"` @@ -423,6 +436,7 @@ type ProvidersConfig struct { func (p ProvidersConfig) IsEmpty() bool { return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" && p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" && + p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" && p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" && p.Groq.APIKey == "" && p.Groq.APIBase == "" && p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" && @@ -567,6 +581,7 @@ type ToolsConfig struct { Exec ExecConfig `json:"exec"` Skills SkillsToolsConfig `json:"skills"` MediaCleanup MediaCleanupConfig `json:"media_cleanup"` + MCP MCPConfig `json:"mcp"` } type SkillsToolsConfig struct { @@ -596,6 +611,34 @@ 
type ClawHubRegistryConfig struct { MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"` } +// MCPServerConfig defines configuration for a single MCP server +type MCPServerConfig struct { + // Enabled indicates whether this MCP server is active + Enabled bool `json:"enabled"` + // Command is the executable to run (e.g., "npx", "python", "/path/to/server") + Command string `json:"command"` + // Args are the arguments to pass to the command + Args []string `json:"args,omitempty"` + // Env are environment variables to set for the server process (stdio only) + Env map[string]string `json:"env,omitempty"` + // EnvFile is the path to a file containing environment variables (stdio only) + EnvFile string `json:"env_file,omitempty"` + // Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set) + Type string `json:"type,omitempty"` + // URL is used for SSE/HTTP transport + URL string `json:"url,omitempty"` + // Headers are HTTP headers to send with requests (sse/http only) + Headers map[string]string `json:"headers,omitempty"` +} + +// MCPConfig defines configuration for all MCP servers +type MCPConfig struct { + // Enabled globally enables/disables MCP integration + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"` + // Servers is a map of server name to server configuration + Servers map[string]MCPServerConfig `json:"servers,omitempty"` +} + func LoadConfig(path string) (*Config, error) { cfg := DefaultConfig() diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index fb0fd4451..9fc09c5f1 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -361,6 +361,10 @@ func DefaultConfig() *Config { TTLSeconds: 300, }, }, + MCP: MCPConfig{ + Enabled: false, + Servers: map[string]MCPServerConfig{}, + }, }, Heartbeat: HeartbeatConfig{ Enabled: true, diff --git a/pkg/config/migration.go b/pkg/config/migration.go index 2475f5aa9..685916950 100644 --- 
a/pkg/config/migration.go +++ b/pkg/config/migration.go @@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig { }, true }, }, + { + providerNames: []string{"litellm"}, + protocol: "litellm", + buildConfig: func(p ProvidersConfig) (ModelConfig, bool) { + if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" { + return ModelConfig{}, false + } + return ModelConfig{ + ModelName: "litellm", + Model: "litellm/auto", + APIKey: p.LiteLLM.APIKey, + APIBase: p.LiteLLM.APIBase, + Proxy: p.LiteLLM.Proxy, + RequestTimeout: p.LiteLLM.RequestTimeout, + }, true + }, + }, { providerNames: []string{"openrouter"}, protocol: "openrouter", diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go index 7fda3a1fc..8f7b19801 100644 --- a/pkg/config/migration_test.go +++ b/pkg/config/migration_test.go @@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) { } } +func TestConvertProvidersToModelList_LiteLLM(t *testing.T) { + cfg := &Config{ + Providers: ProvidersConfig{ + LiteLLM: ProviderConfig{ + APIKey: "litellm-key", + APIBase: "http://localhost:4000/v1", + }, + }, + } + + result := ConvertProvidersToModelList(cfg) + + if len(result) != 1 { + t.Fatalf("len(result) = %d, want 1", len(result)) + } + + if result[0].ModelName != "litellm" { + t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm") + } + if result[0].Model != "litellm/auto" { + t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto") + } + if result[0].APIBase != "http://localhost:4000/v1" { + t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1") + } +} + func TestConvertProvidersToModelList_Multiple(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ @@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) { cfg := &Config{ Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}}, + LiteLLM: ProviderConfig{APIKey: "key-litellm", 
APIBase: "http://localhost:4000/v1"}, Anthropic: ProviderConfig{APIKey: "key2"}, OpenRouter: ProviderConfig{APIKey: "key3"}, Groq: ProviderConfig{APIKey: "key4"}, diff --git a/pkg/mcp/manager.go b/pkg/mcp/manager.go new file mode 100644 index 000000000..7b63cc979 --- /dev/null +++ b/pkg/mcp/manager.go @@ -0,0 +1,532 @@ +package mcp + +import ( + "bufio" + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "sync/atomic" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// headerTransport is an http.RoundTripper that adds custom headers to requests +type headerTransport struct { + base http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone the request to avoid modifying the original + req = req.Clone(req.Context()) + + // Add custom headers + for key, value := range t.headers { + req.Header.Set(key, value) + } + + // Use the base transport + base := t.base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +// loadEnvFile loads environment variables from a file in .env format +// Each line should be in the format: KEY=value +// Lines starting with # are comments +// Empty lines are ignored +func loadEnvFile(path string) (map[string]string, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open env file: %w", err) + } + defer file.Close() + + envVars := make(map[string]string) + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Parse KEY=value + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line) + 
} + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum) + } + + // Remove surrounding quotes if present + if len(value) >= 2 { + if (value[0] == '"' && value[len(value)-1] == '"') || + (value[0] == '\'' && value[len(value)-1] == '\'') { + value = value[1 : len(value)-1] + } + } + + envVars[key] = value + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading env file: %w", err) + } + + return envVars, nil +} + +// ServerConnection represents a connection to an MCP server +type ServerConnection struct { + Name string + Client *mcp.Client + Session *mcp.ClientSession + Tools []*mcp.Tool +} + +// Manager manages multiple MCP server connections +type Manager struct { + servers map[string]*ServerConnection + mu sync.RWMutex + closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race + wg sync.WaitGroup // tracks in-flight CallTool calls +} + +// NewManager creates a new MCP manager +func NewManager() *Manager { + return &Manager{ + servers: make(map[string]*ServerConnection), + } +} + +// LoadFromConfig loads MCP servers from configuration +func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error { + return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath()) +} + +// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path. +// This is the minimal dependency version that doesn't require the full Config object. 
+func (m *Manager) LoadFromMCPConfig( + ctx context.Context, + mcpCfg config.MCPConfig, + workspacePath string, +) error { + if !mcpCfg.Enabled { + logger.InfoCF("mcp", "MCP integration is disabled", nil) + return nil + } + + if len(mcpCfg.Servers) == 0 { + logger.InfoCF("mcp", "No MCP servers configured", nil) + return nil + } + + logger.InfoCF("mcp", "Initializing MCP servers", + map[string]any{ + "count": len(mcpCfg.Servers), + }) + + var wg sync.WaitGroup + errs := make(chan error, len(mcpCfg.Servers)) + enabledCount := 0 + + for name, serverCfg := range mcpCfg.Servers { + if !serverCfg.Enabled { + logger.DebugCF("mcp", "Skipping disabled server", + map[string]any{ + "server": name, + }) + continue + } + + enabledCount++ + wg.Add(1) + go func(name string, serverCfg config.MCPServerConfig, workspace string) { + defer wg.Done() + + // Resolve relative envFile paths relative to workspace + if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) { + if workspace == "" { + err := fmt.Errorf( + "workspace path is empty while resolving relative envFile %q for server %s", + serverCfg.EnvFile, + name, + ) + logger.ErrorCF("mcp", "Invalid MCP server configuration", + map[string]any{ + "server": name, + "env_file": serverCfg.EnvFile, + "error": err.Error(), + }) + errs <- err + return + } + serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile) + } + + if err := m.ConnectServer(ctx, name, serverCfg); err != nil { + logger.ErrorCF("mcp", "Failed to connect to MCP server", + map[string]any{ + "server": name, + "error": err.Error(), + }) + errs <- fmt.Errorf("failed to connect to server %s: %w", name, err) + } + }(name, serverCfg, workspacePath) + } + + wg.Wait() + close(errs) + + // Collect errors + var allErrors []error + for err := range errs { + allErrors = append(allErrors, err) + } + + connectedCount := len(m.GetServers()) + + // If all enabled servers failed to connect, return aggregated error + if enabledCount > 0 && connectedCount == 0 { + 
logger.ErrorCF("mcp", "All MCP servers failed to connect", + map[string]any{ + "failed": len(allErrors), + "total": enabledCount, + }) + return errors.Join(allErrors...) + } + + if len(allErrors) > 0 { + logger.WarnCF("mcp", "Some MCP servers failed to connect", + map[string]any{ + "failed": len(allErrors), + "connected": connectedCount, + "total": enabledCount, + }) + // Don't fail completely if some servers successfully connected + } + + logger.InfoCF("mcp", "MCP server initialization complete", + map[string]any{ + "connected": connectedCount, + "total": enabledCount, + }) + + return nil +} + +// ConnectServer connects to a single MCP server +func (m *Manager) ConnectServer( + ctx context.Context, + name string, + cfg config.MCPServerConfig, +) error { + logger.InfoCF("mcp", "Connecting to MCP server", + map[string]any{ + "server": name, + "command": cfg.Command, + "args_count": len(cfg.Args), + }) + + // Create client + client := mcp.NewClient(&mcp.Implementation{ + Name: "picoclaw", + Version: "1.0.0", + }, nil) + + // Create transport based on configuration + // Auto-detect transport type if not explicitly specified + var transport mcp.Transport + transportType := cfg.Type + + // Auto-detect: if URL is provided, use SSE; if command is provided, use stdio + if transportType == "" { + if cfg.URL != "" { + transportType = "sse" + } else if cfg.Command != "" { + transportType = "stdio" + } else { + return fmt.Errorf("either URL or command must be provided") + } + } + + switch transportType { + case "sse", "http": + if cfg.URL == "" { + return fmt.Errorf("URL is required for SSE/HTTP transport") + } + logger.DebugCF("mcp", "Using SSE/HTTP transport", + map[string]any{ + "server": name, + "url": cfg.URL, + }) + + sseTransport := &mcp.StreamableClientTransport{ + Endpoint: cfg.URL, + } + + // Add custom headers if provided + if len(cfg.Headers) > 0 { + // Create a custom HTTP client with header-injecting transport + sseTransport.HTTPClient = &http.Client{ + 
Transport: &headerTransport{ + base: http.DefaultTransport, + headers: cfg.Headers, + }, + } + logger.DebugCF("mcp", "Added custom HTTP headers", + map[string]any{ + "server": name, + "header_count": len(cfg.Headers), + }) + } + + transport = sseTransport + case "stdio": + if cfg.Command == "" { + return fmt.Errorf("command is required for stdio transport") + } + logger.DebugCF("mcp", "Using stdio transport", + map[string]any{ + "server": name, + "command": cfg.Command, + }) + // Create command with context + cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...) + + // Build environment variables with proper override semantics + // Use a map to ensure config variables override file variables + envMap := make(map[string]string) + + // Start with parent process environment + for _, e := range cmd.Environ() { + if idx := strings.Index(e, "="); idx > 0 { + envMap[e[:idx]] = e[idx+1:] + } + } + + // Load environment variables from file if specified + if cfg.EnvFile != "" { + envVars, err := loadEnvFile(cfg.EnvFile) + if err != nil { + return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err) + } + for k, v := range envVars { + envMap[k] = v + } + logger.DebugCF("mcp", "Loaded environment variables from file", + map[string]any{ + "server": name, + "envFile": cfg.EnvFile, + "var_count": len(envVars), + }) + } + + // Environment variables from config override those from file + for k, v := range cfg.Env { + envMap[k] = v + } + + // Convert map to slice + env := make([]string, 0, len(envMap)) + for k, v := range envMap { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + cmd.Env = env + + transport = &mcp.CommandTransport{Command: cmd} + default: + return fmt.Errorf( + "unsupported transport type: %s (supported: stdio, sse, http)", + transportType, + ) + } + + // Connect to server + session, err := client.Connect(ctx, transport, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + // Get server info + initResult := 
session.InitializeResult() + logger.InfoCF("mcp", "Connected to MCP server", + map[string]any{ + "server": name, + "serverName": initResult.ServerInfo.Name, + "serverVersion": initResult.ServerInfo.Version, + "protocol": initResult.ProtocolVersion, + }) + + // List available tools if supported + var tools []*mcp.Tool + if initResult.Capabilities.Tools != nil { + for tool, err := range session.Tools(ctx, nil) { + if err != nil { + logger.WarnCF("mcp", "Error listing tool", + map[string]any{ + "server": name, + "error": err.Error(), + }) + continue + } + tools = append(tools, tool) + } + + logger.InfoCF("mcp", "Listed tools from MCP server", + map[string]any{ + "server": name, + "toolCount": len(tools), + }) + } + + // Store connection + m.mu.Lock() + m.servers[name] = &ServerConnection{ + Name: name, + Client: client, + Session: session, + Tools: tools, + } + m.mu.Unlock() + + return nil +} + +// GetServers returns all connected servers +func (m *Manager) GetServers() map[string]*ServerConnection { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string]*ServerConnection, len(m.servers)) + for k, v := range m.servers { + result[k] = v + } + return result +} + +// GetServer returns a specific server connection +func (m *Manager) GetServer(name string) (*ServerConnection, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + conn, ok := m.servers[name] + return conn, ok +} + +// CallTool calls a tool on a specific server +func (m *Manager) CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, +) (*mcp.CallToolResult, error) { + // Check if closed before acquiring lock (fast path) + if m.closed.Load() { + return nil, fmt.Errorf("manager is closed") + } + + m.mu.RLock() + // Double-check after acquiring lock to prevent TOCTOU race + if m.closed.Load() { + m.mu.RUnlock() + return nil, fmt.Errorf("manager is closed") + } + conn, ok := m.servers[serverName] + if ok { + m.wg.Add(1) // Add to WaitGroup while holding the lock + } 
+ m.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("server %s not found", serverName) + } + defer m.wg.Done() + + params := &mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + } + + result, err := conn.Session.CallTool(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to call tool: %w", err) + } + + return result, nil +} + +// Close closes all server connections +func (m *Manager) Close() error { + // Use Swap to atomically set closed=true and get the previous value + // This prevents TOCTOU race with CallTool's closed check + if m.closed.Swap(true) { + return nil // already closed + } + + // Wait for all in-flight CallTool calls to finish before closing sessions + // After closed=true is set, no new CallTool can start (they check closed first) + m.wg.Wait() + + m.mu.Lock() + defer m.mu.Unlock() + + logger.InfoCF("mcp", "Closing all MCP server connections", + map[string]any{ + "count": len(m.servers), + }) + + var errs []error + for name, conn := range m.servers { + if err := conn.Session.Close(); err != nil { + logger.ErrorCF("mcp", "Failed to close server connection", + map[string]any{ + "server": name, + "error": err.Error(), + }) + errs = append(errs, fmt.Errorf("server %s: %w", name, err)) + } + } + + m.servers = make(map[string]*ServerConnection) + + if len(errs) > 0 { + return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...)) + } + + return nil +} + +// GetAllTools returns all tools from all connected servers +func (m *Manager) GetAllTools() map[string][]*mcp.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + result := make(map[string][]*mcp.Tool) + for name, conn := range m.servers { + if len(conn.Tools) > 0 { + result[name] = conn.Tools + } + } + return result +} diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go new file mode 100644 index 000000000..8ce81d09e --- /dev/null +++ b/pkg/mcp/manager_test.go @@ -0,0 +1,298 @@ +package mcp + +import ( + "context" + "os" + "path/filepath" + 
"strings" + "testing" + + sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestLoadEnvFile(t *testing.T) { + tests := []struct { + name string + content string + expected map[string]string + expectErr bool + }{ + { + name: "basic env file", + content: `API_KEY=secret123 +DATABASE_URL=postgres://localhost/db +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with comments and empty lines", + content: `# This is a comment +API_KEY=secret123 + +# Another comment +DATABASE_URL=postgres://localhost/db + +PORT=8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "with quoted values", + content: `API_KEY="secret with spaces" +NAME='single quoted' +PLAIN=no-quotes`, + expected: map[string]string{ + "API_KEY": "secret with spaces", + "NAME": "single quoted", + "PLAIN": "no-quotes", + }, + expectErr: false, + }, + { + name: "with spaces around equals", + content: `API_KEY = secret123 +DATABASE_URL= postgres://localhost/db +PORT =8080`, + expected: map[string]string{ + "API_KEY": "secret123", + "DATABASE_URL": "postgres://localhost/db", + "PORT": "8080", + }, + expectErr: false, + }, + { + name: "invalid format - no equals", + content: `INVALID_LINE`, + expectErr: true, + }, + { + name: "empty file", + content: ``, + expected: map[string]string{}, + expectErr: false, + }, + { + name: "only comments", + content: `# Comment 1 +# Comment 2`, + expected: map[string]string{}, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + result, err := 
loadEnvFile(envFile) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if len(result) != len(tt.expected) { + t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result)) + } + + for key, expectedValue := range tt.expected { + if actualValue, ok := result[key]; !ok { + t.Errorf("Expected key %s not found", key) + } else if actualValue != expectedValue { + t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue) + } + } + }) + } +} + +func TestLoadEnvFileNotFound(t *testing.T) { + _, err := loadEnvFile("/nonexistent/file.env") + if err == nil { + t.Error("Expected error for nonexistent file") + } +} + +func TestEnvFilePriority(t *testing.T) { + // Create a temporary .env file + tmpDir := t.TempDir() + envFile := filepath.Join(tmpDir, ".env") + + envContent := `API_KEY=from_file +DATABASE_URL=from_file +SHARED_VAR=from_file` + + if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil { + t.Fatalf("Failed to create .env file: %v", err) + } + + // Load envFile + envVars, err := loadEnvFile(envFile) + if err != nil { + t.Fatalf("Failed to load env file: %v", err) + } + + // Verify envFile variables + if envVars["API_KEY"] != "from_file" { + t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"]) + } + + // Simulate config.Env overriding envFile + configEnv := map[string]string{ + "SHARED_VAR": "from_config", + "NEW_VAR": "from_config", + } + + // Merge: envFile first, then config overrides + merged := make(map[string]string) + for k, v := range envVars { + merged[k] = v + } + for k, v := range configEnv { + merged[k] = v + } + + // Verify priority: config.Env should override envFile + if merged["SHARED_VAR"] != "from_config" { + t.Errorf( + "Expected SHARED_VAR=from_config (config should override file), got %s", + merged["SHARED_VAR"], + ) + } + if merged["API_KEY"] != "from_file" { + 
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"]) + } + if merged["NEW_VAR"] != "from_config" { + t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"]) + } +} + +func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) { + mgr := NewManager() + + mcpCfg := config.MCPConfig{ + Enabled: true, + Servers: map[string]config.MCPServerConfig{ + "test-server": { + Enabled: true, + Command: "echo", + Args: []string{"ok"}, + EnvFile: ".env", + }, + }, + } + + err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "") + if err == nil { + t.Fatal("expected error for relative env_file with empty workspace path, got nil") + } + + if !strings.Contains(err.Error(), "workspace path is empty") { + t.Fatalf("expected workspace path validation error, got: %v", err) + } +} + +func TestNewManager_InitialState(t *testing.T) { + mgr := NewManager() + if mgr == nil { + t.Fatal("expected manager instance, got nil") + } + if len(mgr.GetServers()) != 0 { + t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers())) + } +} + +func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) { + mgr := NewManager() + + err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp") + if err != nil { + t.Fatalf("expected nil error when MCP disabled, got: %v", err) + } + + err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp") + if err != nil { + t.Fatalf("expected nil error when no servers configured, got: %v", err) + } +} + +func TestGetServers_ReturnsCopy(t *testing.T) { + mgr := NewManager() + mgr.servers["s1"] = &ServerConnection{Name: "s1"} + + servers := mgr.GetServers() + delete(servers, "s1") + + if _, ok := mgr.GetServer("s1"); !ok { + t.Fatal("expected internal manager state to remain unchanged") + } +} + +func TestGetAllTools_FiltersEmptyTools(t *testing.T) { + mgr := NewManager() + mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil} + 
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}} + + all := mgr.GetAllTools() + if _, ok := all["empty"]; ok { + t.Fatal("expected server without tools to be excluded") + } + if _, ok := all["with-tools"]; !ok { + t.Fatal("expected server with tools to be included") + } +} + +func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) { + t.Run("manager closed", func(t *testing.T) { + mgr := NewManager() + mgr.closed.Store(true) + + _, err := mgr.CallTool(context.Background(), "s1", "tool", nil) + if err == nil || !strings.Contains(err.Error(), "manager is closed") { + t.Fatalf("expected manager closed error, got: %v", err) + } + }) + + t.Run("server missing", func(t *testing.T) { + mgr := NewManager() + + _, err := mgr.CallTool(context.Background(), "missing", "tool", nil) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected server not found error, got: %v", err) + } + }) +} + +func TestClose_IdempotentOnEmptyManager(t *testing.T) { + mgr := NewManager() + + if err := mgr.Close(); err != nil { + t.Fatalf("first close should succeed, got: %v", err) + } + if err := mgr.Close(); err != nil { + t.Fatalf("second close should be idempotent, got: %v", err) + } +} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index 3f46d0f3d..9d53dca1c 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = "https://openrouter.ai/api/v1" } } + case "litellm": + if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" { + sel.apiKey = cfg.Providers.LiteLLM.APIKey + sel.apiBase = cfg.Providers.LiteLLM.APIBase + sel.proxy = cfg.Providers.LiteLLM.Proxy + if sel.apiBase == "" { + sel.apiBase = "http://localhost:4000/v1" + } + } case "zhipu", "glm": if cfg.Providers.Zhipu.APIKey != "" { sel.apiKey = cfg.Providers.Zhipu.APIKey diff --git 
a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go index 1ddd056a4..4d2949c91 100644 --- a/pkg/providers/factory_provider.go +++ b/pkg/providers/factory_provider.go @@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) { // CreateProviderFromConfig creates a provider based on the ModelConfig. // It uses the protocol prefix in the Model field to determine which provider to create. -// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot +// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot // Returns the provider, the model ID (without protocol prefix), and any error. func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) { if cfg == nil { @@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err cfg.RequestTimeout, ), modelID, nil - case "openrouter", "groq", "zhipu", "gemini", "nvidia", + case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia", "ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras", "volcengine", "vllm", "qwen", "mistral", "opencode": // All other OpenAI-compatible HTTP providers @@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string { return "https://api.openai.com/v1" case "openrouter": return "https://openrouter.ai/api/v1" + case "litellm": + return "http://localhost:4000/v1" case "groq": return "https://api.groq.com/openai/v1" case "zhipu": diff --git a/pkg/providers/factory_provider_test.go b/pkg/providers/factory_provider_test.go index eccb8cd40..7d0ea1e32 100644 --- a/pkg/providers/factory_provider_test.go +++ b/pkg/providers/factory_provider_test.go @@ -136,6 +136,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) { } } +func TestGetDefaultAPIBase_LiteLLM(t *testing.T) { + if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" { + t.Fatalf("getDefaultAPIBase(%q) = %q, want 
%q", "litellm", got, "http://localhost:4000/v1") + } +} + +func TestCreateProviderFromConfig_LiteLLM(t *testing.T) { + cfg := &config.ModelConfig{ + ModelName: "test-litellm", + Model: "litellm/my-proxy-alias", + APIKey: "test-key", + APIBase: "http://localhost:4000/v1", + } + + provider, modelID, err := CreateProviderFromConfig(cfg) + if err != nil { + t.Fatalf("CreateProviderFromConfig() error = %v", err) + } + if provider == nil { + t.Fatal("CreateProviderFromConfig() returned nil provider") + } + if modelID != "my-proxy-alias" { + t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias") + } +} + func TestCreateProviderFromConfig_Anthropic(t *testing.T) { cfg := &config.ModelConfig{ ModelName: "test-anthropic", diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go index 5680f23b3..f7a916d9e 100644 --- a/pkg/providers/factory_test.go +++ b/pkg/providers/factory_test.go @@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) { wantProxy string wantErrSubstr string }{ + { + name: "explicit litellm provider uses configured base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "litellm" + cfg.Providers.LiteLLM.APIKey = "litellm-key" + cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1" + cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:4000/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit litellm provider defaults base when only key is configured", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "litellm" + cfg.Providers.LiteLLM.APIKey = "litellm-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:4000/v1", + }, { name: "explicit claude-cli provider routes to cli provider type", setup: func(cfg *config.Config) { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index b0718384f..6bed72456 100644 --- 
a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -124,7 +124,7 @@ func (p *Provider) Chat( requestBody := map[string]any{ "model": model, - "messages": stripSystemParts(messages), + "messages": serializeMessages(messages), } if len(tools) > 0 { @@ -310,19 +310,55 @@ type openaiMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } -// stripSystemParts converts []Message to []openaiMessage, dropping the -// SystemParts field so it doesn't leak into the JSON payload sent to -// OpenAI-compatible APIs (some strict endpoints reject unknown fields). -func stripSystemParts(messages []Message) []openaiMessage { - out := make([]openaiMessage, len(messages)) - for i, m := range messages { - out[i] = openaiMessage{ - Role: m.Role, - Content: m.Content, - ReasoningContent: m.ReasoningContent, - ToolCalls: m.ToolCalls, - ToolCallID: m.ToolCallID, +// serializeMessages converts internal Message structs to the OpenAI wire format. +// - Strips SystemParts (unknown to third-party endpoints) +// - Converts messages with Media to multipart content format (text + image_url parts) +// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages +func serializeMessages(messages []Message) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + if len(m.Media) == 0 { + out = append(out, openaiMessage{ + Role: m.Role, + Content: m.Content, + ReasoningContent: m.ReasoningContent, + ToolCalls: m.ToolCalls, + ToolCallID: m.ToolCallID, + }) + continue } + + // Multipart content format for messages with media + parts := make([]map[string]any, 0, 1+len(m.Media)) + if m.Content != "" { + parts = append(parts, map[string]any{ + "type": "text", + "text": m.Content, + }) + } + for _, mediaURL := range m.Media { + parts = append(parts, map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": mediaURL, + }, + }) + } + + msg := map[string]any{ + "role": m.Role, + "content": parts, + } + 
if m.ToolCallID != "" { + msg["tool_call_id"] = m.ToolCallID + } + if len(m.ToolCalls) > 0 { + msg["tool_calls"] = m.ToolCalls + } + if m.ReasoningContent != "" { + msg["reasoning_content"] = m.ReasoningContent + } + out = append(out, msg) } return out } @@ -339,7 +375,7 @@ func normalizeModel(model, apiBase string) string { prefix := strings.ToLower(before) switch prefix { - case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral": + case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral": return after default: return model diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go index 014451144..f08b24f17 100644 --- a/pkg/providers/openai_compat/provider_test.go +++ b/pkg/providers/openai_compat/provider_test.go @@ -9,6 +9,8 @@ import ( "strings" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" ) func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) { @@ -258,6 +260,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { input string wantModel string }{ + { + name: "strips litellm prefix and preserves proxy model name", + input: "litellm/my-proxy-alias", + wantModel: "my-proxy-alias", + }, { name: "strips groq prefix and keeps nested model", input: "groq/openai/gpt-oss-120b", @@ -489,3 +496,97 @@ func TestProviderChat_KimiCodeUserAgent(t *testing.T) { }) } } + +func TestSerializeMessages_PlainText(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi", ReasoningContent: "thinking..."}, + } + result := serializeMessages(messages) + + data, err := json.Marshal(result) + if err != nil { + t.Fatal(err) + } + + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["content"] != "hello" { + t.Fatalf("expected plain string content, got %v", msgs[0]["content"]) + } + if 
msgs[1]["reasoning_content"] != "thinking..." { + t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"]) + } +} + +func TestSerializeMessages_WithMedia(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + content, ok := msgs[0]["content"].([]any) + if !ok { + t.Fatalf("expected array content for media message, got %T", msgs[0]["content"]) + } + if len(content) != 2 { + t.Fatalf("expected 2 content parts, got %d", len(content)) + } + + textPart := content[0].(map[string]any) + if textPart["type"] != "text" || textPart["text"] != "describe this" { + t.Fatalf("text part mismatch: %v", textPart) + } + + imgPart := content[1].(map[string]any) + if imgPart["type"] != "image_url" { + t.Fatalf("expected image_url type, got %v", imgPart["type"]) + } + imgURL := imgPart["image_url"].(map[string]any) + if imgURL["url"] != "data:image/png;base64,abc123" { + t.Fatalf("image url mismatch: %v", imgURL["url"]) + } +} + +func TestSerializeMessages_MediaWithToolCallID(t *testing.T) { + messages := []protocoltypes.Message{ + {Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"}, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + var msgs []map[string]any + json.Unmarshal(data, &msgs) + + if msgs[0]["tool_call_id"] != "call_1" { + t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"]) + } + // Content should be multipart array + if _, ok := msgs[0]["content"].([]any); !ok { + t.Fatalf("expected array content, got %T", msgs[0]["content"]) + } +} + +func TestSerializeMessages_StripsSystemParts(t *testing.T) { + messages := []protocoltypes.Message{ + { + Role: "system", + Content: "you are helpful", + SystemParts: 
[]protocoltypes.ContentBlock{ + {Type: "text", Text: "you are helpful"}, + }, + }, + } + result := serializeMessages(messages) + + data, _ := json.Marshal(result) + raw := string(data) + if strings.Contains(raw, "system_parts") { + t.Fatal("system_parts should not appear in serialized output") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 99f13334e..194c1aa6f 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -65,6 +65,7 @@ type ContentBlock struct { type Message struct { Role string `json:"role"` Content string `json:"content"` + Media []string `json:"media,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"` SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters ToolCalls []ToolCall `json:"tool_calls,omitempty"` diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index fcbcf934b..30d84635a 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -64,6 +64,29 @@ type SkillsLoader struct { builtinSkills string // builtin skills } +// SkillRoots returns all unique skill root directories used by this loader. +// The order follows resolution priority: workspace > global > builtin. 
+func (sl *SkillsLoader) SkillRoots() []string { + roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills} + seen := make(map[string]struct{}, len(roots)) + out := make([]string, 0, len(roots)) + + for _, root := range roots { + trimmed := strings.TrimSpace(root) + if trimmed == "" { + continue + } + clean := filepath.Clean(trimmed) + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} + out = append(out, clean) + } + + return out +} + func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader { return &SkillsLoader{ workspace: workspace, diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go index 9428bea62..31619f9c2 100644 --- a/pkg/skills/loader_test.go +++ b/pkg/skills/loader_test.go @@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) { }) } } + +func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) { + tmp := t.TempDir() + workspace := filepath.Join(tmp, "workspace") + global := filepath.Join(tmp, "global") + builtin := filepath.Join(tmp, "builtin") + + sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n") + roots := sl.SkillRoots() + + assert.Equal(t, []string{ + filepath.Join(workspace, "skills"), + global, + builtin, + }, roots) +} diff --git a/pkg/tools/mcp_tool.go b/pkg/tools/mcp_tool.go new file mode 100644 index 000000000..6e53cf354 --- /dev/null +++ b/pkg/tools/mcp_tool.go @@ -0,0 +1,246 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPManager defines the interface for MCP manager operations +// This allows for easier testing with mock implementations +type MCPManager interface { + CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, + ) (*mcp.CallToolResult, error) +} + +// MCPTool wraps an MCP tool to implement the Tool interface +type MCPTool struct { + manager MCPManager + serverName string + 
// sanitizeIdentifierComponent normalizes s so it can be used as one component
// of a tool/function identifier.
//
// The result contains only [a-z0-9_-]. The function:
//   - lowercases the input
//   - turns every run of '_' and/or disallowed characters into a single '_'
//   - trims leading and trailing '_'
//   - truncates to at most 64 bytes, then trims any '_' the cut left at the end
//   - returns "unnamed" when nothing usable remains
func sanitizeIdentifierComponent(s string) string {
	const maxLen = 64

	s = strings.ToLower(s)
	var b strings.Builder
	b.Grow(len(s))

	prevUnderscore := false
	for _, r := range s {
		keep := (r >= 'a' && r <= 'z') ||
			(r >= '0' && r <= '9') ||
			r == '-'
		if keep {
			b.WriteRune(r)
			prevUnderscore = false
			continue
		}

		// A literal '_' and any disallowed rune are handled alike: emit one
		// '_' per run so consecutive separators collapse.
		if !prevUnderscore {
			b.WriteByte('_')
			prevUnderscore = true
		}
	}

	result := strings.Trim(b.String(), "_")
	if len(result) > maxLen {
		// Everything written above is ASCII, so slicing by bytes cannot split
		// a character. The cut may land right after a separator; trim again so
		// the "no trailing '_'" guarantee also holds for truncated names.
		result = strings.TrimRight(result[:maxLen], "_")
	}
	if result == "" {
		return "unnamed"
	}

	return result
}
+func (t *MCPTool) Name() string { + // Prefix with server name to avoid conflicts, and sanitize components + sanitizedServer := sanitizeIdentifierComponent(t.serverName) + sanitizedTool := sanitizeIdentifierComponent(t.tool.Name) + full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool) + + // Check if sanitization was lossless (only lowercasing, no char replacement/truncation) + lossless := strings.ToLower(t.serverName) == sanitizedServer && + strings.ToLower(t.tool.Name) == sanitizedTool + + const maxTotal = 64 + if lossless && len(full) <= maxTotal { + return full + } + + // Sanitization was lossy or name too long: append hash of the ORIGINAL names + // (not the sanitized names) so different originals always yield different hashes. + h := fnv.New32a() + _, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name)) + suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars + + base := full + if len(base) > maxTotal-9 { + base = strings.TrimRight(full[:maxTotal-9], "_") + } + return base + "_" + suffix +} + +// Description returns the tool description +func (t *MCPTool) Description() string { + desc := t.tool.Description + if desc == "" { + desc = fmt.Sprintf("MCP tool from %s server", t.serverName) + } + // Add server info to description + return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc) +} + +// Parameters returns the tool parameters schema +func (t *MCPTool) Parameters() map[string]any { + // The InputSchema is already a JSON Schema object + schema := t.tool.InputSchema + + // Handle nil schema + if schema == nil { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + // Try direct conversion first (fast path) + if schemaMap, ok := schema.(map[string]any); ok { + return schemaMap + } + + // Handle json.RawMessage and []byte - unmarshal directly + var jsonData []byte + if rawMsg, ok := schema.(json.RawMessage); ok { + jsonData = rawMsg + } else if bytes, ok := schema.([]byte); ok { + jsonData 
= bytes + } + + if jsonData != nil { + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err == nil { + return result + } + // Fallback on error + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + // For other types (structs, etc.), convert via JSON marshal/unmarshal + var err error + jsonData, err = json.Marshal(schema) + if err != nil { + // Fallback to empty schema if marshaling fails + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + var result map[string]any + if err := json.Unmarshal(jsonData, &result); err != nil { + // Fallback to empty schema if unmarshaling fails + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + "required": []string{}, + } + } + + return result +} + +// Execute executes the MCP tool +func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult { + result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args) + if err != nil { + return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err) + } + + if result == nil { + nilErr := fmt.Errorf("MCP tool returned nil result without error") + return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr) + } + + // Handle error result from server + if result.IsError { + errMsg := extractContentText(result.Content) + return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)). 
+ WithError(fmt.Errorf("MCP tool error: %s", errMsg)) + } + + // Extract text content from result + output := extractContentText(result.Content) + + return &ToolResult{ + ForLLM: output, + IsError: false, + } +} + +// extractContentText extracts text from MCP content array +func extractContentText(content []mcp.Content) string { + var parts []string + for _, c := range content { + switch v := c.(type) { + case *mcp.TextContent: + parts = append(parts, v.Text) + case *mcp.ImageContent: + // For images, just indicate that an image was returned + parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType)) + default: + // For other content types, use string representation + parts = append(parts, fmt.Sprintf("[Content: %T]", v)) + } + } + return strings.Join(parts, "\n") +} diff --git a/pkg/tools/mcp_tool_test.go b/pkg/tools/mcp_tool_test.go new file mode 100644 index 000000000..95bb0f992 --- /dev/null +++ b/pkg/tools/mcp_tool_test.go @@ -0,0 +1,492 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MockMCPManager is a mock implementation of MCPManager interface for testing +type MockMCPManager struct { + callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) +} + +func (m *MockMCPManager) CallTool( + ctx context.Context, + serverName, toolName string, + arguments map[string]any, +) (*mcp.CallToolResult, error) { + if m.callToolFunc != nil { + return m.callToolFunc(ctx, serverName, toolName, arguments) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "mock result"}, + }, + IsError: false, + }, nil +} + +// TestNewMCPTool verifies MCP tool creation +func TestNewMCPTool(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": 
map[string]any{ + "type": "string", + "description": "Test input", + }, + }, + }, + } + + mcpTool := NewMCPTool(manager, "test_server", tool) + + if mcpTool == nil { + t.Fatal("NewMCPTool should not return nil") + } + // Verify tool properties we can access + if mcpTool.Name() != "mcp_test_server_test_tool" { + t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name()) + } +} + +// TestMCPTool_Name verifies tool name with server prefix +func TestMCPTool_Name(t *testing.T) { + tests := []struct { + name string + serverName string + toolName string + expected string + }{ + { + name: "simple name", + serverName: "github", + toolName: "create_issue", + expected: "mcp_github_create_issue", + }, + { + name: "filesystem server", + serverName: "filesystem", + toolName: "read_file", + expected: "mcp_filesystem_read_file", + }, + { + name: "remote server", + serverName: "remote-api", + toolName: "fetch_data", + expected: "mcp_remote-api_fetch_data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: tt.toolName} + mcpTool := NewMCPTool(manager, tt.serverName, tool) + + result := mcpTool.Name() + if result != tt.expected { + t.Errorf("Expected name '%s', got '%s'", tt.expected, result) + } + }) + } +} + +// TestMCPTool_Description verifies tool description generation +func TestMCPTool_Description(t *testing.T) { + tests := []struct { + name string + serverName string + toolDescription string + expectContains []string + }{ + { + name: "with description", + serverName: "github", + toolDescription: "Create a GitHub issue", + expectContains: []string{"[MCP:github]", "Create a GitHub issue"}, + }, + { + name: "empty description", + serverName: "filesystem", + toolDescription: "", + expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ 
+ Name: "test_tool", + Description: tt.toolDescription, + } + mcpTool := NewMCPTool(manager, tt.serverName, tool) + + result := mcpTool.Description() + + for _, expected := range tt.expectContains { + if !strings.Contains(result, expected) { + t.Errorf("Description should contain '%s', got: %s", expected, result) + } + } + }) + } +} + +// TestMCPTool_Parameters verifies parameter schema conversion +func TestMCPTool_Parameters(t *testing.T) { + tests := []struct { + name string + inputSchema any + expectType string + checkProperty string + expectProperty bool + }{ + { + name: "map schema", + inputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "Search query", + }, + }, + "required": []string{"query"}, + }, + expectType: "object", + checkProperty: "query", + expectProperty: true, + }, + { + name: "nil schema", + inputSchema: nil, + expectType: "object", + expectProperty: false, + }, + { + name: "json.RawMessage schema", + inputSchema: []byte(`{ + "type": "object", + "properties": { + "repo": { + "type": "string", + "description": "Repository name" + }, + "stars": { + "type": "integer", + "description": "Minimum stars" + } + }, + "required": ["repo"] + }`), + expectType: "object", + checkProperty: "repo", + expectProperty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: tt.inputSchema, + } + mcpTool := NewMCPTool(manager, "test_server", tool) + + params := mcpTool.Parameters() + + if params == nil { + t.Fatal("Parameters should not be nil") + } + + if params["type"] != tt.expectType { + t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"]) + } + + // Check if property exists when expected + if tt.checkProperty != "" { + properties, ok := params["properties"].(map[string]any) + if !ok && tt.expectProperty { + t.Errorf("Expected properties 
to be a map") + return + } + if ok { + _, hasProperty := properties[tt.checkProperty] + if hasProperty != tt.expectProperty { + t.Errorf("Expected property '%s' existence: %v, got: %v", + tt.checkProperty, tt.expectProperty, hasProperty) + } + } + } + }) + } +} + +// TestMCPTool_Execute_Success tests successful tool execution +func TestMCPTool_Execute_Success(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + // Verify correct parameters passed + if serverName != "github" { + t.Errorf("Expected serverName 'github', got '%s'", serverName) + } + if toolName != "search_repos" { + t.Errorf("Expected toolName 'search_repos', got '%s'", toolName) + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Found 3 repositories"}, + }, + IsError: false, + }, nil + }, + } + + tool := &mcp.Tool{ + Name: "search_repos", + Description: "Search GitHub repositories", + } + mcpTool := NewMCPTool(manager, "github", tool) + + ctx := context.Background() + args := map[string]any{ + "query": "golang mcp", + } + + result := mcpTool.Execute(ctx, args) + + if result == nil { + t.Fatal("Result should not be nil") + } + if result.IsError { + t.Errorf("Expected no error, got error: %s", result.ForLLM) + } + if result.ForLLM != "Found 3 repositories" { + t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM) + } +} + +// TestMCPTool_Execute_ManagerError tests execution when manager returns error +func TestMCPTool_Execute_ManagerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return nil, fmt.Errorf("connection failed") + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, 
map[string]any{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool execution failed") { + t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "connection failed") { + t.Errorf("Error message should include original error, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_ServerError tests execution when server returns error +func TestMCPTool_Execute_ServerError(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "Invalid API key"}, + }, + IsError: true, + }, nil + }, + } + + tool := &mcp.Tool{Name: "test_tool"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]any{}) + + if result == nil { + t.Fatal("Result should not be nil") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if !strings.Contains(result.ForLLM, "MCP tool returned error") { + t.Errorf("Error message should mention server error, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Invalid API key") { + t.Errorf("Error message should include server message, got: %s", result.ForLLM) + } +} + +// TestMCPTool_Execute_MultipleContent tests execution with multiple content items +func TestMCPTool_Execute_MultipleContent(t *testing.T) { + manager := &MockMCPManager{ + callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: "First line"}, + &mcp.TextContent{Text: "Second line"}, + &mcp.TextContent{Text: "Third line"}, + }, + IsError: 
false, + }, nil + }, + } + + tool := &mcp.Tool{Name: "multi_output"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + ctx := context.Background() + result := mcpTool.Execute(ctx, map[string]any{}) + + if result.IsError { + t.Errorf("Expected no error, got: %s", result.ForLLM) + } + + expected := "First line\nSecond line\nThird line" + if result.ForLLM != expected { + t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM) + } +} + +// TestExtractContentText_TextContent tests text content extraction +func TestExtractContentText_TextContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Hello World"}, + &mcp.TextContent{Text: "Second message"}, + } + + result := extractContentText(content) + expected := "Hello World\nSecond message" + + if result != expected { + t.Errorf("Expected '%s', got '%s'", expected, result) + } +} + +// TestExtractContentText_ImageContent tests image content extraction +func TestExtractContentText_ImageContent(t *testing.T) { + content := []mcp.Content{ + &mcp.ImageContent{ + Data: []byte("base64data"), + MIMEType: "image/png", + }, + } + + result := extractContentText(content) + + if !strings.Contains(result, "[Image:") { + t.Errorf("Expected image indicator, got: %s", result) + } + if !strings.Contains(result, "image/png") { + t.Errorf("Expected MIME type in output, got: %s", result) + } +} + +// TestExtractContentText_MixedContent tests mixed content types +func TestExtractContentText_MixedContent(t *testing.T) { + content := []mcp.Content{ + &mcp.TextContent{Text: "Description"}, + &mcp.ImageContent{ + Data: []byte("data"), + MIMEType: "image/jpeg", + }, + &mcp.TextContent{Text: "More text"}, + } + + result := extractContentText(content) + + if !strings.Contains(result, "Description") { + t.Errorf("Should contain text content, got: %s", result) + } + if !strings.Contains(result, "[Image:") { + t.Errorf("Should contain image indicator, got: %s", result) + } + if !strings.Contains(result, "More text") { 
+ t.Errorf("Should contain second text, got: %s", result) + } +} + +// TestExtractContentText_EmptyContent tests empty content array +func TestExtractContentText_EmptyContent(t *testing.T) { + content := []mcp.Content{} + + result := extractContentText(content) + + if result != "" { + t.Errorf("Expected empty string for empty content, got: %s", result) + } +} + +// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface +func TestMCPTool_InterfaceCompliance(t *testing.T) { + manager := &MockMCPManager{} + tool := &mcp.Tool{Name: "test"} + mcpTool := NewMCPTool(manager, "test_server", tool) + + // Verify it implements Tool interface + var _ Tool = mcpTool +} + +// TestMCPTool_Parameters_MapSchema tests schema that's already a map +func TestMCPTool_Parameters_MapSchema(t *testing.T) { + manager := &MockMCPManager{} + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "The name parameter", + }, + }, + "required": []string{"name"}, + } + + tool := &mcp.Tool{ + Name: "test_tool", + InputSchema: schema, + } + mcpTool := NewMCPTool(manager, "test_server", tool) + + params := mcpTool.Parameters() + + // Should return the schema as-is when it's already a map + if params["type"] != "object" { + t.Errorf("Expected type 'object', got '%v'", params["type"]) + } + + props, ok := params["properties"].(map[string]any) + if !ok { + t.Error("Properties should be a map") + } + + nameParam, ok := props["name"].(map[string]any) + if !ok { + t.Error("Name parameter should exist") + } + + if nameParam["type"] != "string" { + t.Errorf("Name type should be 'string', got '%v'", nameParam["type"]) + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d37a093a8..0ba983e02 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry { func (r *ToolRegistry) Register(tool Tool) { r.mu.Lock() defer r.mu.Unlock() - 
r.tools[tool.Name()] = tool + name := tool.Name() + if _, exists := r.tools[name]; exists { + logger.WarnCF("tools", "Tool registration overwrites existing tool", + map[string]any{"name": name}) + } + r.tools[name] = tool } func (r *ToolRegistry) Get(name string) (Tool, bool) { diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 10498126b..15d2330ff 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -109,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in return "", fmt.Errorf("failed to read response: %w", err) } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body)) + } + var searchResp struct { Web struct { Results []struct { diff --git a/scripts/test-docker-mcp.sh b/scripts/test-docker-mcp.sh new file mode 100755 index 000000000..9d582ffa0 --- /dev/null +++ b/scripts/test-docker-mcp.sh @@ -0,0 +1,49 @@ +#!/bin/sh +# Test script for MCP tools in Docker (full-featured image) + +set -e + +COMPOSE_FILE="docker/docker-compose.full.yml" +SERVICE="picoclaw-agent" + +echo "๐Ÿงช Testing MCP tools in Docker container (full-featured image)..." +echo "" + +# Build the image +echo "๐Ÿ“ฆ Building Docker image..." +docker compose -f "$COMPOSE_FILE" build "$SERVICE" + +# Test npx +echo "โœ… Testing npx..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npx --version' + +# Test npm +echo "โœ… Testing npm..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npm --version' + +# Test node +echo "โœ… Testing Node.js..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'node --version' + +# Test git +echo "โœ… Testing git..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'git --version' + +# Test python +echo "โœ… Testing Python..." 
+docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'python3 --version' + +# Test uv +echo "โœ… Testing uv..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'uv --version' + +# Test MCP server installation (quick) +echo "โœ… Testing @modelcontextprotocol/server-filesystem MCP server install with npx..." +docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c '