mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: resolve conflict with upstream/main in provider_test.go
Keep both KimiCodeUserAgent test (from this branch) and serializeMessages tests (from upstream) — no overlap. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -168,11 +168,11 @@ clean:
|
||||
@echo "Clean complete"
|
||||
|
||||
## vet: Run go vet for static analysis
|
||||
vet:
|
||||
vet: generate
|
||||
@$(GO) vet ./...
|
||||
|
||||
## test: Test Go code
|
||||
test:
|
||||
test: generate
|
||||
@$(GO) test ./...
|
||||
|
||||
## fmt: Format Go code
|
||||
@@ -204,6 +204,44 @@ check: deps fmt vet test
|
||||
run: build
|
||||
@$(BUILD_DIR)/$(BINARY_NAME) $(ARGS)
|
||||
|
||||
## docker-build: Build Docker image (minimal Alpine-based)
|
||||
docker-build:
|
||||
@echo "Building minimal Docker image (Alpine-based)..."
|
||||
docker compose -f docker/docker-compose.yml build picoclaw-agent picoclaw-gateway
|
||||
|
||||
## docker-build-full: Build Docker image with full MCP support (Node.js 24)
|
||||
docker-build-full:
|
||||
@echo "Building full-featured Docker image (Node.js 24)..."
|
||||
docker compose -f docker/docker-compose.full.yml build picoclaw-agent picoclaw-gateway
|
||||
|
||||
## docker-test: Test MCP tools in Docker container
|
||||
docker-test:
|
||||
@echo "Testing MCP tools in Docker..."
|
||||
@chmod +x scripts/test-docker-mcp.sh
|
||||
@./scripts/test-docker-mcp.sh
|
||||
|
||||
## docker-run: Run picoclaw gateway in Docker (Alpine-based)
|
||||
docker-run:
|
||||
docker compose -f docker/docker-compose.yml --profile gateway up
|
||||
|
||||
## docker-run-full: Run picoclaw gateway in Docker (full-featured)
|
||||
docker-run-full:
|
||||
docker compose -f docker/docker-compose.full.yml --profile gateway up
|
||||
|
||||
## docker-run-agent: Run picoclaw agent in Docker (interactive, Alpine-based)
|
||||
docker-run-agent:
|
||||
docker compose -f docker/docker-compose.yml run --rm picoclaw-agent
|
||||
|
||||
## docker-run-agent-full: Run picoclaw agent in Docker (interactive, full-featured)
|
||||
docker-run-agent-full:
|
||||
docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent
|
||||
|
||||
## docker-clean: Clean Docker images and volumes
|
||||
docker-clean:
|
||||
docker compose -f docker/docker-compose.yml down -v
|
||||
docker compose -f docker/docker-compose.full.yml down -v
|
||||
docker rmi picoclaw:latest picoclaw:full 2>/dev/null || true
|
||||
|
||||
## help: Show this help message
|
||||
help:
|
||||
@echo "picoclaw Makefile"
|
||||
@@ -219,6 +257,8 @@ help:
|
||||
@echo " make install # Install to ~/.local/bin"
|
||||
@echo " make uninstall # Remove from /usr/local/bin"
|
||||
@echo " make install-skills # Install skills to workspace"
|
||||
@echo " make docker-build # Build minimal Docker image"
|
||||
@echo " make docker-test # Test MCP tools in Docker"
|
||||
@echo ""
|
||||
@echo "Environment Variables:"
|
||||
@echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)"
|
||||
|
||||
@@ -721,6 +721,20 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
||||
└── USER.md # User preferences
|
||||
```
|
||||
|
||||
### Skill Sources
|
||||
|
||||
By default, skills are loaded from:
|
||||
|
||||
1. `~/.picoclaw/workspace/skills` (workspace)
|
||||
2. `~/.picoclaw/skills` (global)
|
||||
3. `<current-working-directory>/skills` (builtin)
|
||||
|
||||
For advanced/test setups, you can override the builtin skills root with:
|
||||
|
||||
```bash
|
||||
export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
|
||||
```
|
||||
|
||||
### 🔒 Security Sandbox
|
||||
|
||||
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
|
||||
@@ -925,7 +939,7 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
#### 📋 All Supported Vendors
|
||||
|
||||
| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
|
||||
| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- |
|
||||
| ------------------- | ----------------- |-----------------------------------------------------| --------- | ---------------------------------------------------------------- |
|
||||
| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
|
||||
| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
|
||||
| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
|
||||
@@ -937,6 +951,7 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) |
|
||||
| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) |
|
||||
| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) |
|
||||
| **LiteLLM Proxy** | `litellm/` | `http://localhost:4000/v1 | OpenAI | Your LiteLLM proxy key |
|
||||
| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local |
|
||||
| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) |
|
||||
| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
|
||||
@@ -1038,6 +1053,19 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
}
|
||||
```
|
||||
|
||||
**LiteLLM Proxy**
|
||||
|
||||
```json
|
||||
{
|
||||
"model_name": "lite-gpt4",
|
||||
"model": "litellm/lite-gpt4",
|
||||
"api_base": "http://localhost:4000/v1",
|
||||
"api_key": "sk-..."
|
||||
}
|
||||
```
|
||||
|
||||
PicoClaw strips only the outer `litellm/` prefix before sending the request, so proxy aliases like `litellm/lite-gpt4` send `lite-gpt4`, while `litellm/openai/gpt-4o` sends `openai/gpt-4o`.
|
||||
|
||||
#### Load Balancing
|
||||
|
||||
Configure multiple endpoints for the same model name—PicoClaw will automatically round-robin between them:
|
||||
|
||||
@@ -362,6 +362,20 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work
|
||||
|
||||
```
|
||||
|
||||
### 技能来源 (Skill Sources)
|
||||
|
||||
默认情况下,技能会按以下顺序加载:
|
||||
|
||||
1. `~/.picoclaw/workspace/skills`(工作区)
|
||||
2. `~/.picoclaw/skills`(全局)
|
||||
3. `<current-working-directory>/skills`(内置)
|
||||
|
||||
在高级/测试场景下,可通过以下环境变量覆盖内置技能目录:
|
||||
|
||||
```bash
|
||||
export PICOCLAW_BUILTIN_SKILLS=/path/to/skills
|
||||
```
|
||||
|
||||
### 心跳 / 周期性任务 (Heartbeat)
|
||||
|
||||
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
|
||||
|
||||
@@ -5,6 +5,19 @@ import (
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
const (
|
||||
colorBlue = "[#3e5db9]"
|
||||
colorRed = "[#d54646]"
|
||||
banner = "\r\n[::b]" +
|
||||
colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
|
||||
colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
|
||||
colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
|
||||
colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
|
||||
colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
|
||||
colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
|
||||
"[:]"
|
||||
)
|
||||
|
||||
func applyStyles() {
|
||||
tview.Styles.PrimitiveBackgroundColor = tcell.NewRGBColor(12, 13, 22)
|
||||
tview.Styles.ContrastBackgroundColor = tcell.NewRGBColor(34, 19, 53)
|
||||
@@ -24,14 +37,7 @@ func bannerView() *tview.TextView {
|
||||
text.SetDynamicColors(true)
|
||||
text.SetTextAlign(tview.AlignCenter)
|
||||
text.SetBackgroundColor(tview.Styles.PrimitiveBackgroundColor)
|
||||
text.SetText(
|
||||
"[::b][#84aaff]██████╗ ██╗ ██████╗ ██████╗ ██████╗██╗ █████╗ ██╗ ██╗\n" +
|
||||
"[#84aaff]██╔══██╗██║██╔════╝██╔═══██╗██╔════╝██║ ██╔══██╗██║ ██║\n" +
|
||||
"[#84aaff]██████╔╝██║██║ ██║ ██║██║ ██║ ███████║██║ █╗ ██║\n" +
|
||||
"[#84aaff]██╔═══╝ ██║██║ ██║ ██║██║ ██║ ██╔══██║██║███╗██║\n" +
|
||||
"[#84aaff]██║ ██║╚██████╗╚██████╔╝╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
|
||||
"[#84aaff]╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝",
|
||||
)
|
||||
text.SetText(banner)
|
||||
text.SetBorder(false)
|
||||
return text
|
||||
}
|
||||
|
||||
@@ -48,7 +48,21 @@ func NewPicoclawCommand() *cobra.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
const (
|
||||
colorBlue = "\033[1;38;2;62;93;185m"
|
||||
colorRed = "\033[1;38;2;213;70;70m"
|
||||
banner = "\r\n" +
|
||||
colorBlue + "██████╗ ██╗ ██████╗ ██████╗ " + colorRed + " ██████╗██╗ █████╗ ██╗ ██╗\n" +
|
||||
colorBlue + "██╔══██╗██║██╔════╝██╔═══██╗" + colorRed + "██╔════╝██║ ██╔══██╗██║ ██║\n" +
|
||||
colorBlue + "██████╔╝██║██║ ██║ ██║" + colorRed + "██║ ██║ ███████║██║ █╗ ██║\n" +
|
||||
colorBlue + "██╔═══╝ ██║██║ ██║ ██║" + colorRed + "██║ ██║ ██╔══██║██║███╗██║\n" +
|
||||
colorBlue + "██║ ██║╚██████╗╚██████╔╝" + colorRed + "╚██████╗███████╗██║ ██║╚███╔███╔╝\n" +
|
||||
colorBlue + "╚═╝ ╚═╝ ╚═════╝ ╚═════╝ " + colorRed + " ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\n " +
|
||||
"\033[0m\r\n"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Printf("%s", banner)
|
||||
cmd := NewPicoclawCommand()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
|
||||
@@ -49,6 +49,7 @@
|
||||
"telegram": {
|
||||
"enabled": false,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||
"base_url": "",
|
||||
"proxy": "",
|
||||
"allow_from": [
|
||||
"YOUR_USER_ID"
|
||||
@@ -243,6 +244,71 @@
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"context7": {
|
||||
"enabled": false,
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp",
|
||||
"headers": {
|
||||
"CONTEXT7_API_KEY": "ctx7sk-xx"
|
||||
}
|
||||
},
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
]
|
||||
},
|
||||
"github": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
|
||||
}
|
||||
},
|
||||
"brave-search": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-brave-search"
|
||||
],
|
||||
"env": {
|
||||
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
|
||||
}
|
||||
},
|
||||
"postgres": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-postgres",
|
||||
"postgresql://user:password@localhost/dbname"
|
||||
]
|
||||
},
|
||||
"slack": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-slack"
|
||||
],
|
||||
"env": {
|
||||
"SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN",
|
||||
"SLACK_TEAM_ID": "YOUR_SLACK_TEAM_ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"exec": {
|
||||
"enable_deny_patterns": false,
|
||||
"custom_deny_patterns": []
|
||||
@@ -271,4 +337,4 @@
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
# ============================================================
|
||||
# Stage 1: Build the picoclaw binary
|
||||
# ============================================================
|
||||
FROM golang:1.26.0-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git make
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
# Cache dependencies
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Copy source and build
|
||||
COPY . .
|
||||
RUN make build
|
||||
|
||||
# ============================================================
|
||||
# Stage 2: Node.js-based runtime with full MCP support
|
||||
# ============================================================
|
||||
FROM node:24-alpine3.23
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apk add --no-cache \
|
||||
ca-certificates \
|
||||
curl \
|
||||
git \
|
||||
python3 \
|
||||
py3-pip
|
||||
|
||||
# Install uv and symlink to system path
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
ln -s /root/.local/bin/uv /usr/local/bin/uv && \
|
||||
ln -s /root/.local/bin/uvx /usr/local/bin/uvx && \
|
||||
uv --version
|
||||
|
||||
# Copy binary
|
||||
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
|
||||
|
||||
# Create picoclaw home directory
|
||||
RUN /usr/local/bin/picoclaw onboard
|
||||
|
||||
ENTRYPOINT ["picoclaw"]
|
||||
CMD ["gateway"]
|
||||
@@ -0,0 +1,44 @@
|
||||
services:
|
||||
# ─────────────────────────────────────────────
|
||||
# PicoClaw Agent (one-shot query) - Full MCP Support
|
||||
# docker compose -f docker/docker-compose.full.yml run --rm picoclaw-agent -m "Hello"
|
||||
# ─────────────────────────────────────────────
|
||||
picoclaw-agent:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile.full
|
||||
container_name: picoclaw-agent-full
|
||||
profiles:
|
||||
- agent
|
||||
volumes:
|
||||
- ../config/config.json:/root/.picoclaw/config.json:ro
|
||||
- picoclaw-workspace:/root/.picoclaw/workspace
|
||||
- picoclaw-npm-cache:/root/.npm # npm cache for faster MCP server installs
|
||||
entrypoint: ["picoclaw", "agent"]
|
||||
stdin_open: true
|
||||
tty: true
|
||||
|
||||
# ─────────────────────────────────────────────
|
||||
# PicoClaw Gateway (Long-running Bot) - Full MCP Support
|
||||
# docker compose -f docker/docker-compose.full.yml --profile gateway up
|
||||
# ─────────────────────────────────────────────
|
||||
picoclaw-gateway:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/Dockerfile.full
|
||||
container_name: picoclaw-gateway-full
|
||||
restart: unless-stopped
|
||||
profiles:
|
||||
- gateway
|
||||
volumes:
|
||||
# Configuration file
|
||||
- ../config/config.json:/root/.picoclaw/config.json:ro
|
||||
# Persistent workspace (sessions, memory, logs)
|
||||
- picoclaw-workspace:/root/.picoclaw/workspace
|
||||
# NPM cache for faster MCP server installs
|
||||
- picoclaw-npm-cache:/root/.npm
|
||||
command: ["gateway"]
|
||||
|
||||
volumes:
|
||||
picoclaw-workspace:
|
||||
picoclaw-npm-cache: # Cache npm packages to speed up MCP server installations
|
||||
+108
-33
@@ -8,6 +8,7 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`.
|
||||
{
|
||||
"tools": {
|
||||
"web": { ... },
|
||||
"mcp": { ... },
|
||||
"exec": { ... },
|
||||
"cron": { ... },
|
||||
"skills": { ... }
|
||||
@@ -21,35 +22,35 @@ Web tools are used for web search and fetching.
|
||||
|
||||
### Brave
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | false | Enable Brave search |
|
||||
| `api_key` | string | - | Brave Search API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ------ | ------- | ------------------------- |
|
||||
| `enabled` | bool | false | Enable Brave search |
|
||||
| `api_key` | string | - | Brave Search API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
### DuckDuckGo
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | true | Enable DuckDuckGo search |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ---- | ------- | ------------------------- |
|
||||
| `enabled` | bool | true | Enable DuckDuckGo search |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
### Perplexity
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | false | Enable Perplexity search |
|
||||
| `api_key` | string | - | Perplexity API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
| Config | Type | Default | Description |
|
||||
| ------------- | ------ | ------- | ------------------------- |
|
||||
| `enabled` | bool | false | Enable Perplexity search |
|
||||
| `api_key` | string | - | Perplexity API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
## Exec Tool
|
||||
|
||||
The exec tool is used to execute shell commands.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
|
||||
| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------- | ----- | ------- | ------------------------------------------ |
|
||||
| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
|
||||
| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
|
||||
|
||||
### Functionality
|
||||
|
||||
@@ -80,10 +81,7 @@ By default, PicoClaw blocks the following dangerous commands:
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enable_deny_patterns": true,
|
||||
"custom_deny_patterns": [
|
||||
"\\brm\\s+-r\\b",
|
||||
"\\bkillall\\s+python"
|
||||
]
|
||||
"custom_deny_patterns": ["\\brm\\s+-r\\b", "\\bkillall\\s+python"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,9 +91,84 @@ By default, PicoClaw blocks the following dangerous commands:
|
||||
|
||||
The cron tool is used for scheduling periodic tasks.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------- | ---- | ------- | ---------------------------------------------- |
|
||||
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
|
||||
|
||||
## MCP Tool
|
||||
|
||||
The MCP tool enables integration with external Model Context Protocol servers.
|
||||
|
||||
### Global Config
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
| --------- | ------ | ------- | ----------------------------------- |
|
||||
| `enabled` | bool | false | Enable MCP integration globally |
|
||||
| `servers` | object | `{}` | Map of server name to server config |
|
||||
|
||||
### Per-Server Config
|
||||
|
||||
| Config | Type | Required | Description |
|
||||
| ---------- | ------ | -------- | ------------------------------------------ |
|
||||
| `enabled` | bool | yes | Enable this MCP server |
|
||||
| `type` | string | no | Transport type: `stdio`, `sse`, `http` |
|
||||
| `command` | string | stdio | Executable command for stdio transport |
|
||||
| `args` | array | no | Command arguments for stdio transport |
|
||||
| `env` | object | no | Environment variables for stdio process |
|
||||
| `env_file` | string | no | Path to environment file for stdio process |
|
||||
| `url` | string | sse/http | Endpoint URL for `sse`/`http` transport |
|
||||
| `headers` | object | no | HTTP headers for `sse`/`http` transport |
|
||||
|
||||
### Transport Behavior
|
||||
|
||||
- If `type` is omitted, transport is auto-detected:
|
||||
- `url` is set → `sse`
|
||||
- `command` is set → `stdio`
|
||||
- `http` and `sse` both use `url` + optional `headers`.
|
||||
- `env` and `env_file` are only applied to `stdio` servers.
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### 1) Stdio MCP server
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": true,
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 2) Remote SSE/HTTP MCP server
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"servers": {
|
||||
"remote-mcp": {
|
||||
"enabled": true,
|
||||
"type": "sse",
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer YOUR_TOKEN"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Skills Tool
|
||||
|
||||
@@ -103,13 +176,13 @@ The skills tool configures skill discovery and installation via registries like
|
||||
|
||||
### Registries
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
|
||||
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
|
||||
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
|
||||
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
|
||||
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
|
||||
| Config | Type | Default | Description |
|
||||
| ---------------------------------- | ------ | -------------------- | ----------------------- |
|
||||
| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
|
||||
| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
|
||||
| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
|
||||
| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
|
||||
| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
|
||||
|
||||
### Configuration Example
|
||||
|
||||
@@ -136,8 +209,10 @@ The skills tool configures skill discovery and installation via registries like
|
||||
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_<SECTION>_<KEY>`:
|
||||
|
||||
For example:
|
||||
|
||||
- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true`
|
||||
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
|
||||
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
|
||||
- `PICOCLAW_TOOLS_MCP_ENABLED=true`
|
||||
|
||||
Note: Array-type environment variables are not currently supported and must be set via the config file.
|
||||
Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than environment variables.
|
||||
|
||||
@@ -8,13 +8,16 @@ require (
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/caarlos0/env/v11 v11.3.1
|
||||
github.com/chzyer/readline v1.5.1
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0
|
||||
github.com/mymmrac/telego v1.6.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/slack-go/slack v0.17.3
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@@ -35,6 +38,7 @@ require (
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/gdamore/tcell/v2 v2.13.8 // indirect
|
||||
github.com/h2non/filetype v1.1.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
@@ -43,7 +47,6 @@ require (
|
||||
github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/tview v0.42.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
@@ -81,6 +84,7 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.69.0 // indirect
|
||||
github.com/valyala/fastjson v1.6.7 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
golang.org/x/arch v0.24.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
|
||||
@@ -66,6 +66,8 @@ github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncV
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
|
||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
|
||||
@@ -96,6 +98,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
|
||||
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
|
||||
github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg=
|
||||
github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
@@ -130,6 +134,8 @@ github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp
|
||||
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0 h1:gMfZkv3DzQF5q/DcQePo5rahEY+sguyPfXDfNBcT0Zs=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0/go.mod h1:AnQ//Qc6+4nIyyrB4cxBU7UW9VibK4iOZBeyP/rF1IE=
|
||||
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
|
||||
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
@@ -212,6 +218,8 @@ github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTd
|
||||
github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
|
||||
+123
-58
@@ -34,6 +34,11 @@ type ContextBuilder struct {
|
||||
// created (didn't exist at cache time, now exist) or deleted (existed at
|
||||
// cache time, now gone) — both of which should trigger a cache rebuild.
|
||||
existedAtCache map[string]bool
|
||||
|
||||
// skillFilesAtCache snapshots the skill tree file set and mtimes at cache
|
||||
// build time. This catches nested file creations/deletions/mtime changes
|
||||
// that may not update the top-level skill root directory mtime.
|
||||
skillFilesAtCache map[string]time.Time
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
@@ -47,8 +52,11 @@ func getGlobalConfigDir() string {
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
// builtin skills: skills directory in current project
|
||||
// Use the skills/ directory under the current working directory
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir := filepath.Join(wd, "skills")
|
||||
builtinSkillsDir := strings.TrimSpace(os.Getenv("PICOCLAW_BUILTIN_SKILLS"))
|
||||
if builtinSkillsDir == "" {
|
||||
wd, _ := os.Getwd()
|
||||
builtinSkillsDir = filepath.Join(wd, "skills")
|
||||
}
|
||||
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
|
||||
|
||||
return &ContextBuilder{
|
||||
@@ -148,6 +156,7 @@ func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
|
||||
cb.cachedSystemPrompt = prompt
|
||||
cb.cachedAt = baseline.maxMtime
|
||||
cb.existedAtCache = baseline.existed
|
||||
cb.skillFilesAtCache = baseline.skillFiles
|
||||
|
||||
logger.DebugCF("agent", "System prompt cached",
|
||||
map[string]any{
|
||||
@@ -167,14 +176,14 @@ func (cb *ContextBuilder) InvalidateCache() {
|
||||
cb.cachedSystemPrompt = ""
|
||||
cb.cachedAt = time.Time{}
|
||||
cb.existedAtCache = nil
|
||||
cb.skillFilesAtCache = nil
|
||||
|
||||
logger.DebugCF("agent", "System prompt cache invalidated", nil)
|
||||
}
|
||||
|
||||
// sourcePaths returns the workspace source file paths tracked for cache
|
||||
// invalidation (bootstrap files + memory). The skills directory is handled
|
||||
// separately in sourceFilesChangedLocked because it requires both directory-
|
||||
// level and recursive file-level mtime checks.
|
||||
// sourcePaths returns non-skill workspace source files tracked for cache
|
||||
// invalidation (bootstrap files + memory). Skill roots are handled separately
|
||||
// because they require both directory-level and recursive file-level checks.
|
||||
func (cb *ContextBuilder) sourcePaths() []string {
|
||||
return []string{
|
||||
filepath.Join(cb.workspace, "AGENTS.md"),
|
||||
@@ -185,23 +194,39 @@ func (cb *ContextBuilder) sourcePaths() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// skillRoots returns all skill root directories that can affect
|
||||
// BuildSkillsSummary output (workspace/global/builtin).
|
||||
func (cb *ContextBuilder) skillRoots() []string {
|
||||
if cb.skillsLoader == nil {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
|
||||
roots := cb.skillsLoader.SkillRoots()
|
||||
if len(roots) == 0 {
|
||||
return []string{filepath.Join(cb.workspace, "skills")}
|
||||
}
|
||||
return roots
|
||||
}
|
||||
|
||||
// cacheBaseline holds the file existence snapshot and the latest observed
|
||||
// mtime across all tracked paths. Used as the cache reference point.
|
||||
type cacheBaseline struct {
|
||||
existed map[string]bool
|
||||
maxMtime time.Time
|
||||
existed map[string]bool
|
||||
skillFiles map[string]time.Time
|
||||
maxMtime time.Time
|
||||
}
|
||||
|
||||
// buildCacheBaseline records which tracked paths currently exist and computes
|
||||
// the latest mtime across all tracked files + skills directory contents.
|
||||
// Called under write lock when the cache is built.
|
||||
func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
skillRoots := cb.skillRoots()
|
||||
|
||||
// All paths whose existence we track: source files + skills dir.
|
||||
allPaths := append(cb.sourcePaths(), skillsDir)
|
||||
// All paths whose existence we track: source files + all skill roots.
|
||||
allPaths := append(cb.sourcePaths(), skillRoots...)
|
||||
|
||||
existed := make(map[string]bool, len(allPaths))
|
||||
skillFiles := make(map[string]time.Time)
|
||||
var maxMtime time.Time
|
||||
|
||||
for _, p := range allPaths {
|
||||
@@ -212,17 +237,21 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
}
|
||||
}
|
||||
|
||||
// Walk skills files to capture their mtimes too.
|
||||
// Use os.Stat (not d.Info) to match the stat method used in
|
||||
// fileChangedSince / skillFilesModifiedSince for consistency.
|
||||
_ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
// Walk all skill roots recursively to snapshot skill files and mtimes.
|
||||
// Use os.Stat (not d.Info) for consistency with sourceFilesChanged checks.
|
||||
for _, root := range skillRoots {
|
||||
_ = filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, err := os.Stat(path); err == nil {
|
||||
skillFiles[path] = info.ModTime()
|
||||
if info.ModTime().After(maxMtime) {
|
||||
maxMtime = info.ModTime()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// If no tracked files exist yet (empty workspace), maxMtime is zero.
|
||||
// Use a very old non-zero time so that:
|
||||
@@ -234,7 +263,7 @@ func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
|
||||
maxMtime = time.Unix(1, 0)
|
||||
}
|
||||
|
||||
return cacheBaseline{existed: existed, maxMtime: maxMtime}
|
||||
return cacheBaseline{existed: existed, skillFiles: skillFiles, maxMtime: maxMtime}
|
||||
}
|
||||
|
||||
// sourceFilesChangedLocked checks whether any workspace source file has been
|
||||
@@ -254,21 +283,17 @@ func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// --- Skills directory (handled separately from sourcePaths) ---
|
||||
// --- Skill roots (workspace/global/builtin) ---
|
||||
//
|
||||
// 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
|
||||
skillsDir := filepath.Join(cb.workspace, "skills")
|
||||
if cb.fileChangedSince(skillsDir) {
|
||||
return true
|
||||
// For each root:
|
||||
// 1. Creation/deletion and root directory mtime changes are tracked by fileChangedSince.
|
||||
// 2. Nested file create/delete/mtime changes are tracked by the skill file snapshot.
|
||||
for _, root := range cb.skillRoots() {
|
||||
if cb.fileChangedSince(root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Structural changes (add/remove entries inside the dir) are reflected
|
||||
// in the directory's own mtime, which fileChangedSince already checks.
|
||||
//
|
||||
// 3. Content-only edits to files inside skills/ do NOT update the parent
|
||||
// directory mtime on most filesystems, so we recursively walk to check
|
||||
// individual file mtimes at any nesting depth.
|
||||
if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
|
||||
if skillFilesChangedSince(cb.skillRoots(), cb.skillFilesAtCache) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -309,28 +334,64 @@ func (cb *ContextBuilder) fileChangedSince(path string) bool {
|
||||
// if the callback returned nil when its err parameter is non-nil.
|
||||
var errWalkStop = errors.New("walk stop")
|
||||
|
||||
// skillFilesModifiedSince recursively walks the skills directory and checks
|
||||
// whether any file was modified after t. This catches content-only edits at
|
||||
// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
|
||||
// parent directory mtimes.
|
||||
func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
|
||||
changed := false
|
||||
err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr == nil && !d.IsDir() {
|
||||
if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
|
||||
changed = true
|
||||
return errWalkStop // stop walking
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// errWalkStop is expected (early exit on first changed file).
|
||||
// os.IsNotExist means the skills dir doesn't exist yet — not an error.
|
||||
// Any other error is unexpected and worth logging.
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
// skillFilesChangedSince compares the current recursive skill file tree
|
||||
// against the cache-time snapshot. Any create/delete/mtime drift invalidates
|
||||
// the cache.
|
||||
func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Time) bool {
|
||||
// Defensive: if the snapshot was never initialized, force rebuild.
|
||||
if filesAtCache == nil {
|
||||
return true
|
||||
}
|
||||
return changed
|
||||
|
||||
// Check cached files still exist and keep the same mtime.
|
||||
for path, cachedMtime := range filesAtCache {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
// A previously tracked file disappeared (or became inaccessible):
|
||||
// either way, cached skill summary may now be stale.
|
||||
return true
|
||||
}
|
||||
if !info.ModTime().Equal(cachedMtime) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check no new files appeared under any skill root.
|
||||
changed := false
|
||||
for _, root := range skillRoots {
|
||||
if strings.TrimSpace(root) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
// Treat unexpected walk errors as changed to avoid stale cache.
|
||||
if !os.IsNotExist(walkErr) {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if _, ok := filesAtCache[path]; !ok {
|
||||
changed = true
|
||||
return errWalkStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if changed {
|
||||
return true
|
||||
}
|
||||
if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
|
||||
logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (cb *ContextBuilder) LoadBootstrapFiles() string {
|
||||
@@ -466,10 +527,14 @@ func (cb *ContextBuilder) BuildMessages(
|
||||
|
||||
// Add current user message
|
||||
if strings.TrimSpace(currentMessage) != "" {
|
||||
messages = append(messages, providers.Message{
|
||||
msg := providers.Message{
|
||||
Role: "user",
|
||||
Content: currentMessage,
|
||||
})
|
||||
}
|
||||
if len(media) > 0 {
|
||||
msg.Media = media
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages
|
||||
|
||||
@@ -383,6 +383,162 @@ Updated content.`
|
||||
}
|
||||
}
|
||||
|
||||
// TestGlobalSkillFileContentChange verifies that modifying a global skill
|
||||
// (~/.picoclaw/skills) invalidates the cached system prompt.
|
||||
func TestGlobalSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
globalSkillPath := filepath.Join(tmpHome, ".picoclaw", "skills", "global-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(globalSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: global-skill
|
||||
description: global-v1
|
||||
---
|
||||
# Global Skill v1`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "global-v1") {
|
||||
t.Fatal("expected initial prompt to contain global skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: global-skill
|
||||
description: global-v2
|
||||
---
|
||||
# Global Skill v2`
|
||||
if err := os.WriteFile(globalSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(globalSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", globalSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect global skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "global-v2") {
|
||||
t.Error("rebuilt prompt should contain updated global skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when global skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuiltinSkillFileContentChange verifies that modifying a builtin skill
|
||||
// invalidates the cached system prompt.
|
||||
func TestBuiltinSkillFileContentChange(t *testing.T) {
|
||||
tmpHome := t.TempDir()
|
||||
t.Setenv("HOME", tmpHome)
|
||||
|
||||
tmpDir := setupWorkspace(t, nil)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
builtinRoot := t.TempDir()
|
||||
t.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot)
|
||||
|
||||
builtinSkillPath := filepath.Join(builtinRoot, "builtin-skill", "SKILL.md")
|
||||
if err := os.MkdirAll(filepath.Dir(builtinSkillPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v1 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v1
|
||||
---
|
||||
# Builtin Skill v1`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "builtin-v1") {
|
||||
t.Fatal("expected initial prompt to contain builtin skill description")
|
||||
}
|
||||
|
||||
v2 := `---
|
||||
name: builtin-skill
|
||||
description: builtin-v2
|
||||
---
|
||||
# Builtin Skill v2`
|
||||
if err := os.WriteFile(builtinSkillPath, []byte(v2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
future := time.Now().Add(2 * time.Second)
|
||||
if err := os.Chtimes(builtinSkillPath, future, future); err != nil {
|
||||
t.Fatalf("failed to update mtime for %s: %v", builtinSkillPath, err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect builtin skill file content change")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp2, "builtin-v2") {
|
||||
t.Error("rebuilt prompt should contain updated builtin skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when builtin skill file content changes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillFileDeletionInvalidatesCache verifies that deleting a nested skill
|
||||
// file invalidates the cached system prompt.
|
||||
func TestSkillFileDeletionInvalidatesCache(t *testing.T) {
|
||||
tmpDir := setupWorkspace(t, map[string]string{
|
||||
"skills/delete-me/SKILL.md": `---
|
||||
name: delete-me
|
||||
description: delete-me-v1
|
||||
---
|
||||
# Delete Me`,
|
||||
})
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cb := NewContextBuilder(tmpDir)
|
||||
sp1 := cb.BuildSystemPromptWithCache()
|
||||
if !strings.Contains(sp1, "delete-me-v1") {
|
||||
t.Fatal("expected initial prompt to contain skill description")
|
||||
}
|
||||
|
||||
skillPath := filepath.Join(tmpDir, "skills", "delete-me", "SKILL.md")
|
||||
if err := os.Remove(skillPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cb.systemPromptMutex.RLock()
|
||||
changed := cb.sourceFilesChangedLocked()
|
||||
cb.systemPromptMutex.RUnlock()
|
||||
if !changed {
|
||||
t.Fatal("sourceFilesChangedLocked() should detect deleted skill file")
|
||||
}
|
||||
|
||||
sp2 := cb.BuildSystemPromptWithCache()
|
||||
if strings.Contains(sp2, "delete-me-v1") {
|
||||
t.Error("rebuilt prompt should not contain deleted skill description")
|
||||
}
|
||||
if sp1 == sp2 {
|
||||
t.Error("cache should be invalidated when skill file is deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
|
||||
// can safely call BuildSystemPromptWithCache concurrently without producing
|
||||
// empty results, panics, or data races.
|
||||
|
||||
+169
-36
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -46,19 +47,24 @@ type AgentLoop struct {
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SessionKey string // Session identifier for history/context
|
||||
Channel string // Target channel for tool execution
|
||||
ChatID string // Target chat ID for tool execution
|
||||
UserMessage string // User message content (may include prefix)
|
||||
Media []string // media:// refs from inbound message
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
const defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
func NewAgentLoop(
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentLoop {
|
||||
registry := NewAgentRegistry(cfg, provider)
|
||||
|
||||
// Register shared tools to all agents
|
||||
@@ -170,6 +176,72 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
// Initialize MCP servers for all agents
|
||||
if al.cfg.Tools.MCP.Enabled {
|
||||
mcpManager := mcp.NewManager()
|
||||
// Ensure MCP connections are cleaned up on exit, regardless of initialization success
|
||||
// This fixes resource leak when LoadFromMCPConfig partially succeeds then fails
|
||||
defer func() {
|
||||
if err := mcpManager.Close(); err != nil {
|
||||
logger.ErrorCF("agent", "Failed to close MCP manager",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
var workspacePath string
|
||||
if defaultAgent != nil && defaultAgent.Workspace != "" {
|
||||
workspacePath = defaultAgent.Workspace
|
||||
} else {
|
||||
workspacePath = al.cfg.WorkspacePath()
|
||||
}
|
||||
|
||||
if err := mcpManager.LoadFromMCPConfig(ctx, al.cfg.Tools.MCP, workspacePath); err != nil {
|
||||
logger.WarnCF("agent", "Failed to load MCP servers, MCP tools will not be available",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// Register MCP tools for all agents
|
||||
servers := mcpManager.GetServers()
|
||||
uniqueTools := 0
|
||||
totalRegistrations := 0
|
||||
agentIDs := al.registry.ListAgentIDs()
|
||||
agentCount := len(agentIDs)
|
||||
|
||||
for serverName, conn := range servers {
|
||||
uniqueTools += len(conn.Tools)
|
||||
for _, tool := range conn.Tools {
|
||||
for _, agentID := range agentIDs {
|
||||
agent, ok := al.registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
agent.Tools.Register(mcpTool)
|
||||
totalRegistrations++
|
||||
logger.DebugCF("agent", "Registered MCP tool",
|
||||
map[string]any{
|
||||
"agent_id": agentID,
|
||||
"server": serverName,
|
||||
"tool": tool.Name,
|
||||
"name": mcpTool.Name(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.InfoCF("agent", "MCP tools registered successfully",
|
||||
map[string]any{
|
||||
"server_count": len(servers),
|
||||
"unique_tools": uniqueTools,
|
||||
"total_registrations": totalRegistrations,
|
||||
"agent_count": agentCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -310,7 +382,10 @@ func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
func (al *AgentLoop) ProcessDirect(
|
||||
ctx context.Context,
|
||||
content, sessionKey string,
|
||||
) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
|
||||
@@ -331,7 +406,10 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -356,13 +434,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]any{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
"sender_id": msg.SenderID,
|
||||
"session_key": msg.SessionKey,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
// Route system messages to processSystemMessage
|
||||
if msg.Channel == "system" {
|
||||
@@ -417,15 +498,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
UserMessage: msg.Content,
|
||||
Media: msg.Media,
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: true,
|
||||
SendResponse: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
func (al *AgentLoop) processSystemMessage(
|
||||
ctx context.Context,
|
||||
msg bus.InboundMessage,
|
||||
) (string, error) {
|
||||
if msg.Channel != "system" {
|
||||
return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel)
|
||||
return "", fmt.Errorf(
|
||||
"processSystemMessage called with non-system message channel: %s",
|
||||
msg.Channel,
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Processing system message",
|
||||
@@ -483,14 +571,22 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) {
|
||||
func (al *AgentLoop) runAgentLoop(
|
||||
ctx context.Context,
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Failed to record last channel",
|
||||
map[string]any{"error": err.Error()},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -509,11 +605,15 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
history,
|
||||
summary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Media,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
// Resolve media:// refs to base64 data URLs (streaming)
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 3. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
|
||||
@@ -572,7 +672,10 @@ func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string
|
||||
return ""
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) {
|
||||
func (al *AgentLoop) handleReasoning(
|
||||
ctx context.Context,
|
||||
reasoningContent, channelName, channelID string,
|
||||
) {
|
||||
if reasoningContent == "" || channelName == "" || channelID == "" {
|
||||
return
|
||||
}
|
||||
@@ -665,22 +768,33 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
if len(agent.Candidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
agent.Candidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
})
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
messages,
|
||||
providerToolDefs,
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": agent.Temperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
return nil, fbErr
|
||||
}
|
||||
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
|
||||
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration})
|
||||
logger.InfoCF(
|
||||
"agent",
|
||||
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
|
||||
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
|
||||
map[string]any{"agent_id": agent.ID, "iteration": iteration},
|
||||
)
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
@@ -731,10 +845,14 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
logger.WarnCF(
|
||||
"agent",
|
||||
"Context window error detected, attempting compression",
|
||||
map[string]any{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
},
|
||||
)
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
@@ -766,7 +884,12 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel))
|
||||
go al.handleReasoning(
|
||||
ctx,
|
||||
response.Reasoning,
|
||||
opts.Channel,
|
||||
al.targetReasoningChannelID(opts.Channel),
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM response",
|
||||
map[string]any{
|
||||
@@ -1068,7 +1191,11 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
if tc.Function != nil {
|
||||
fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Arguments: %s\n",
|
||||
utils.Truncate(tc.Function.Arguments, 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1097,7 +1224,11 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
|
||||
fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
||||
fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
||||
fmt.Fprintf(
|
||||
&sb,
|
||||
" Parameters: %s\n",
|
||||
utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200),
|
||||
)
|
||||
}
|
||||
}
|
||||
sb.WriteString("]")
|
||||
@@ -1194,7 +1325,9 @@ func (al *AgentLoop) summarizeBatch(
|
||||
existingSummary string,
|
||||
) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
|
||||
sb.WriteString(
|
||||
"Provide a concise summary of this conversation segment, preserving core context and key points.\n",
|
||||
)
|
||||
if existingSummary != "" {
|
||||
sb.WriteString("Existing context: ")
|
||||
sb.WriteString(existingSummary)
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// resolveMediaRefs replaces media:// refs in message Media fields with base64 data URLs.
|
||||
// Uses streaming base64 encoding (file handle → encoder → buffer) to avoid holding
|
||||
// both raw bytes and encoded string in memory simultaneously.
|
||||
// Returns a new slice; original messages are not mutated.
|
||||
func resolveMediaRefs(messages []providers.Message, store media.MediaStore, maxSize int) []providers.Message {
|
||||
if store == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
result := make([]providers.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
for i, m := range result {
|
||||
if len(m.Media) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
resolved := make([]string, 0, len(m.Media))
|
||||
for _, ref := range m.Media {
|
||||
if !strings.HasPrefix(ref, "media://") {
|
||||
resolved = append(resolved, ref)
|
||||
continue
|
||||
}
|
||||
|
||||
localPath, meta, err := store.ResolveWithMeta(ref)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to resolve media ref", map[string]any{
|
||||
"ref": ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to stat media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if info.Size() > int64(maxSize) {
|
||||
logger.WarnCF("agent", "Media file too large, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
"size": info.Size(),
|
||||
"max_size": maxSize,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine MIME type: prefer metadata, fallback to magic-bytes detection
|
||||
mime := meta.ContentType
|
||||
if mime == "" {
|
||||
kind, ftErr := filetype.MatchFile(localPath)
|
||||
if ftErr != nil || kind == filetype.Unknown {
|
||||
logger.WarnCF("agent", "Unknown media type, skipping", map[string]any{
|
||||
"path": localPath,
|
||||
})
|
||||
continue
|
||||
}
|
||||
mime = kind.MIME.Value
|
||||
}
|
||||
|
||||
// Streaming base64: open file → base64 encoder → buffer
|
||||
// Peak memory: ~1.33x file size (buffer only, no raw bytes copy)
|
||||
f, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := "data:" + mime + ";base64,"
|
||||
encodedLen := base64.StdEncoding.EncodedLen(int(info.Size()))
|
||||
var buf bytes.Buffer
|
||||
buf.Grow(len(prefix) + encodedLen)
|
||||
buf.WriteString(prefix)
|
||||
|
||||
encoder := base64.NewEncoder(base64.StdEncoding, &buf)
|
||||
if _, err := io.Copy(encoder, f); err != nil {
|
||||
f.Close()
|
||||
logger.WarnCF("agent", "Failed to encode media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
encoder.Close()
|
||||
f.Close()
|
||||
|
||||
resolved = append(resolved, buf.String())
|
||||
}
|
||||
|
||||
result[i].Media = resolved
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
@@ -808,3 +810,142 @@ func TestHandleReasoning(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_ResolvesToBase64(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// Create a minimal valid PNG (8-byte header is enough for filetype detection)
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
// PNG magic: 0x89 P N G \r \n 0x1A \n + minimal IHDR
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, // IHDR length
|
||||
0x49, 0x48, 0x44, 0x52, // "IHDR"
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, // 1x1 RGB
|
||||
0x00, 0x00, 0x00, // no interlace
|
||||
0x90, 0x77, 0x53, 0xDE, // CRC
|
||||
}
|
||||
if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, err := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 resolved media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/png;base64,") {
|
||||
t.Fatalf("expected data:image/png;base64, prefix, got %q", result[0].Media[0][:40])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsOversizedFile(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
bigPath := filepath.Join(dir, "big.png")
|
||||
// Write PNG header + padding to exceed limit
|
||||
data := make([]byte, 1024+1) // 1KB + 1 byte
|
||||
copy(data, []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})
|
||||
if err := os.WriteFile(bigPath, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(bigPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
// Use a tiny limit (1KB) so the file is oversized
|
||||
result := resolveMediaRefs(messages, store, 1024)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (oversized), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_SkipsUnknownType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
txtPath := filepath.Join(dir, "readme.txt")
|
||||
if err := os.WriteFile(txtPath, []byte("hello world"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ref, _ := store.Store(txtPath, media.MediaMeta{}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 0 {
|
||||
t.Fatalf("expected 0 media (unknown type), got %d", len(result[0].Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_PassesThroughNonMediaRefs(t *testing.T) {
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{"https://example.com/img.png"}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, nil, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 || result[0].Media[0] != "https://example.com/img.png" {
|
||||
t.Fatalf("expected passthrough of non-media:// URL, got %v", result[0].Media)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_DoesNotMutateOriginal(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
pngPath := filepath.Join(dir, "test.png")
|
||||
pngHeader := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
|
||||
0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE,
|
||||
}
|
||||
os.WriteFile(pngPath, pngHeader, 0o644)
|
||||
ref, _ := store.Store(pngPath, media.MediaMeta{}, "test")
|
||||
|
||||
original := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
originalRef := original[0].Media[0]
|
||||
|
||||
resolveMediaRefs(original, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if original[0].Media[0] != originalRef {
|
||||
t.Fatal("resolveMediaRefs mutated original message slice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMediaRefs_UsesMetaContentType(t *testing.T) {
|
||||
store := media.NewFileMediaStore()
|
||||
dir := t.TempDir()
|
||||
|
||||
// File with JPEG content but stored with explicit content type
|
||||
jpegPath := filepath.Join(dir, "photo")
|
||||
jpegHeader := []byte{0xFF, 0xD8, 0xFF, 0xE0} // JPEG magic bytes
|
||||
os.WriteFile(jpegPath, jpegHeader, 0o644)
|
||||
ref, _ := store.Store(jpegPath, media.MediaMeta{ContentType: "image/jpeg"}, "test")
|
||||
|
||||
messages := []providers.Message{
|
||||
{Role: "user", Content: "hi", Media: []string{ref}},
|
||||
}
|
||||
result := resolveMediaRefs(messages, store, config.DefaultMaxMediaSize)
|
||||
|
||||
if len(result[0].Media) != 1 {
|
||||
t.Fatalf("expected 1 media, got %d", len(result[0].Media))
|
||||
}
|
||||
if !strings.HasPrefix(result[0].Media[0], "data:image/jpeg;base64,") {
|
||||
t.Fatalf("expected jpeg prefix, got %q", result[0].Media[0][:30])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
// mentionPlaceholderRegex matches @_user_N placeholders inserted by Feishu for mentions.
|
||||
var mentionPlaceholderRegex = regexp.MustCompile(`@_user_\d+`)
|
||||
|
||||
// stringValue safely dereferences a *string pointer.
|
||||
func stringValue(v *string) string {
|
||||
if v == nil {
|
||||
@@ -7,3 +18,69 @@ func stringValue(v *string) string {
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// buildMarkdownCard builds a Feishu Interactive Card JSON 2.0 string with markdown content.
|
||||
// JSON 2.0 cards support full CommonMark standard markdown syntax.
|
||||
func buildMarkdownCard(content string) (string, error) {
|
||||
card := map[string]any{
|
||||
"schema": "2.0",
|
||||
"body": map[string]any{
|
||||
"elements": []map[string]any{
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(card)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// extractJSONStringField unmarshals content as JSON and returns the value of the given string field.
|
||||
// Returns "" if the content is invalid JSON or the field is missing/empty.
|
||||
func extractJSONStringField(content, field string) string {
|
||||
var m map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(content), &m); err != nil {
|
||||
return ""
|
||||
}
|
||||
raw, ok := m[field]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err != nil {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// extractImageKey extracts the image_key from a Feishu image message content JSON.
|
||||
// Format: {"image_key": "img_xxx"}
|
||||
func extractImageKey(content string) string { return extractJSONStringField(content, "image_key") }
|
||||
|
||||
// extractFileKey extracts the file_key from a Feishu file/audio message content JSON.
|
||||
// Format: {"file_key": "file_xxx", "file_name": "...", ...}
|
||||
func extractFileKey(content string) string { return extractJSONStringField(content, "file_key") }
|
||||
|
||||
// extractFileName extracts the file_name from a Feishu file message content JSON.
|
||||
func extractFileName(content string) string { return extractJSONStringField(content, "file_name") }
|
||||
|
||||
// stripMentionPlaceholders removes @_user_N placeholders from the text content.
|
||||
// These are inserted by Feishu when users @mention someone in a message.
|
||||
func stripMentionPlaceholders(content string, mentions []*larkim.MentionEvent) string {
|
||||
if len(mentions) == 0 {
|
||||
return content
|
||||
}
|
||||
for _, m := range mentions {
|
||||
if m.Key != nil && *m.Key != "" {
|
||||
content = strings.ReplaceAll(content, *m.Key, "")
|
||||
}
|
||||
}
|
||||
// Also clean up any remaining @_user_N patterns
|
||||
content = mentionPlaceholderRegex.ReplaceAllString(content, "")
|
||||
return strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestExtractJSONStringField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
field string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid field",
|
||||
content: `{"image_key": "img_v2_xxx"}`,
|
||||
field: "image_key",
|
||||
want: "img_v2_xxx",
|
||||
},
|
||||
{
|
||||
name: "missing field",
|
||||
content: `{"image_key": "img_v2_xxx"}`,
|
||||
field: "file_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
content: `not json at all`,
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "non-string field value",
|
||||
content: `{"count": 42}`,
|
||||
field: "count",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty string value",
|
||||
content: `{"image_key": ""}`,
|
||||
field: "image_key",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "multiple fields",
|
||||
content: `{"file_key": "file_xxx", "file_name": "test.pdf"}`,
|
||||
field: "file_name",
|
||||
want: "test.pdf",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractJSONStringField(tt.content, tt.field)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractJSONStringField(%q, %q) = %q, want %q", tt.content, tt.field, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractImageKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"image_key": "img_v2_abc123"}`,
|
||||
want: "img_v2_abc123",
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
content: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `{broken`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractImageKey(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractImageKey(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFileKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"file_key": "file_v2_abc123", "file_name": "test.doc"}`,
|
||||
want: "file_v2_abc123",
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
content: `{"image_key": "img_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `not json`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFileKey(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFileKey(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
content: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
|
||||
want: "report.pdf",
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
content: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
content: `{bad`,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFileName(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFileName(%q) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMarkdownCard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "normal content",
|
||||
content: "Hello **world**",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
content: `Code: "foo" & <bar> 'baz'`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := buildMarkdownCard(tt.content)
|
||||
if err != nil {
|
||||
t.Fatalf("buildMarkdownCard(%q) unexpected error: %v", tt.content, err)
|
||||
}
|
||||
|
||||
// Verify valid JSON
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(result), &parsed); err != nil {
|
||||
t.Fatalf("buildMarkdownCard(%q) produced invalid JSON: %v", tt.content, err)
|
||||
}
|
||||
|
||||
// Verify schema
|
||||
if parsed["schema"] != "2.0" {
|
||||
t.Errorf("schema = %v, want %q", parsed["schema"], "2.0")
|
||||
}
|
||||
|
||||
// Verify body.elements[0].content == input
|
||||
body, ok := parsed["body"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("missing body in card JSON")
|
||||
}
|
||||
elements, ok := body["elements"].([]any)
|
||||
if !ok || len(elements) == 0 {
|
||||
t.Fatal("missing or empty elements in card JSON")
|
||||
}
|
||||
elem, ok := elements[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first element is not an object")
|
||||
}
|
||||
if elem["tag"] != "markdown" {
|
||||
t.Errorf("tag = %v, want %q", elem["tag"], "markdown")
|
||||
}
|
||||
if elem["content"] != tt.content {
|
||||
t.Errorf("content = %v, want %q", elem["content"], tt.content)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripMentionPlaceholders(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
mentions []*larkim.MentionEvent
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no mentions",
|
||||
content: "Hello world",
|
||||
mentions: nil,
|
||||
want: "Hello world",
|
||||
},
|
||||
{
|
||||
name: "single mention",
|
||||
content: "@_user_1 hello",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: strPtr("@_user_1")},
|
||||
},
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "multiple mentions",
|
||||
content: "@_user_1 @_user_2 hey",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: strPtr("@_user_1")},
|
||||
{Key: strPtr("@_user_2")},
|
||||
},
|
||||
want: "hey",
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
mentions: []*larkim.MentionEvent{{Key: strPtr("@_user_1")}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty mentions slice",
|
||||
content: "@_user_1 test",
|
||||
mentions: []*larkim.MentionEvent{},
|
||||
want: "@_user_1 test",
|
||||
},
|
||||
{
|
||||
name: "mention with nil key",
|
||||
content: "@_user_1 test",
|
||||
mentions: []*larkim.MentionEvent{
|
||||
{Key: nil},
|
||||
},
|
||||
want: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := stripMentionPlaceholders(tt.content, tt.mentions)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripMentionPlaceholders(%q, ...) = %q, want %q", tt.content, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,8 @@ type FeishuChannel struct {
|
||||
*channels.BaseChannel
|
||||
}
|
||||
|
||||
var errUnsupported = errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New(
|
||||
@@ -25,15 +27,35 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
|
||||
// Start is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return nil
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// Stop is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// Send is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// EditMessage is a stub method to satisfy MessageEditor
|
||||
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
// SendPlaceholder is a stub method to satisfy PlaceholderCapable
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
return "", errUnsupported
|
||||
}
|
||||
|
||||
// ReactToMessage is a stub method to satisfy ReactionCapable
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
return func() {}, errUnsupported
|
||||
}
|
||||
|
||||
// SendMedia is a stub method to satisfy MediaSender
|
||||
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
return errUnsupported
|
||||
}
|
||||
|
||||
@@ -6,10 +6,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
"sync/atomic"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
@@ -19,6 +24,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/identity"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -28,6 +34,8 @@ type FeishuChannel struct {
|
||||
client *lark.Client
|
||||
wsClient *larkws.Client
|
||||
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -38,11 +46,13 @@ func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChan
|
||||
channels.WithReasoningChannelID(cfg.ReasoningChannelID),
|
||||
)
|
||||
|
||||
return &FeishuChannel{
|
||||
ch := &FeishuChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: lark.NewClient(cfg.AppID, cfg.AppSecret),
|
||||
}, nil
|
||||
}
|
||||
ch.SetOwner(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
@@ -50,6 +60,13 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("feishu app_id or app_secret is empty")
|
||||
}
|
||||
|
||||
// Fetch bot open_id via API for reliable @mention detection.
|
||||
if err := c.fetchBotOpenID(ctx); err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to fetch bot open_id, @mention detection may not work", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey).
|
||||
OnP2MessageReceiveV1(c.handleMessageReceive)
|
||||
|
||||
@@ -93,46 +110,213 @@ func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send sends a message using Interactive Card format for markdown rendering.
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return fmt.Errorf("chat ID is empty")
|
||||
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(map[string]string{"text": msg.Content})
|
||||
// Build interactive card with markdown content
|
||||
cardContent, err := buildMarkdownCard(msg.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal feishu content: %w", err)
|
||||
return fmt.Errorf("feishu send: card build failed: %w", err)
|
||||
}
|
||||
return c.sendCard(ctx, msg.ChatID, cardContent)
|
||||
}
|
||||
|
||||
// EditMessage implements channels.MessageEditor.
|
||||
// Uses Message.Patch to update an interactive card message.
|
||||
func (c *FeishuChannel) EditMessage(ctx context.Context, chatID, messageID, content string) error {
|
||||
cardContent, err := buildMarkdownCard(content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu edit: card build failed: %w", err)
|
||||
}
|
||||
|
||||
req := larkim.NewPatchMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewPatchMessageReqBodyBuilder().Content(cardContent).Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Patch(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu edit: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu edit api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPlaceholder implements channels.PlaceholderCapable.
|
||||
// Sends an interactive card with placeholder text and returns its message ID.
|
||||
func (c *FeishuChannel) SendPlaceholder(ctx context.Context, chatID string) (string, error) {
|
||||
if !c.config.Placeholder.Enabled {
|
||||
logger.DebugCF("feishu", "Placeholder disabled, skipping", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return "", nil
|
||||
}
|
||||
|
||||
text := c.config.Placeholder.Text
|
||||
if text == "" {
|
||||
text = "Thinking..."
|
||||
}
|
||||
|
||||
cardContent, err := buildMarkdownCard(text)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("feishu placeholder: card build failed: %w", err)
|
||||
}
|
||||
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(msg.ChatID).
|
||||
MsgType(larkim.MsgTypeText).
|
||||
Content(string(payload)).
|
||||
Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())).
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeInteractive).
|
||||
Content(cardContent).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu send: %w", channels.ErrTemporary)
|
||||
return "", fmt.Errorf("feishu placeholder send: %w", err)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
return "", fmt.Errorf("feishu placeholder api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu message sent", map[string]any{
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
if resp.Data != nil && resp.Data.MessageId != nil {
|
||||
return *resp.Data.MessageId, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// ReactToMessage implements channels.ReactionCapable.
|
||||
// Adds an "Pin" reaction and returns an undo function to remove it.
|
||||
func (c *FeishuChannel) ReactToMessage(ctx context.Context, chatID, messageID string) (func(), error) {
|
||||
req := larkim.NewCreateMessageReactionReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewCreateMessageReactionReqBodyBuilder().
|
||||
ReactionType(larkim.NewEmojiBuilder().EmojiType("Pin").Build()).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageReaction.Create(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to add reaction", map[string]any{
|
||||
"message_id": messageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return func() {}, fmt.Errorf("feishu react: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
logger.ErrorCF("feishu", "Reaction API error", map[string]any{
|
||||
"message_id": messageID,
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
return func() {}, fmt.Errorf("feishu react api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
|
||||
var reactionID string
|
||||
if resp.Data != nil && resp.Data.ReactionId != nil {
|
||||
reactionID = *resp.Data.ReactionId
|
||||
}
|
||||
if reactionID == "" {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
var undone atomic.Bool
|
||||
undo := func() {
|
||||
if !undone.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
delReq := larkim.NewDeleteMessageReactionReqBuilder().
|
||||
MessageId(messageID).
|
||||
ReactionId(reactionID).
|
||||
Build()
|
||||
_, _ = c.client.Im.V1.MessageReaction.Delete(context.Background(), delReq)
|
||||
}
|
||||
return undo, nil
|
||||
}
|
||||
|
||||
// SendMedia implements channels.MediaSender.
|
||||
// Uploads images/files via Feishu API then sends as messages.
|
||||
func (c *FeishuChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
}
|
||||
|
||||
if msg.ChatID == "" {
|
||||
return fmt.Errorf("chat ID is empty: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
store := c.GetMediaStore()
|
||||
if store == nil {
|
||||
return fmt.Errorf("no media store available: %w", channels.ErrSendFailed)
|
||||
}
|
||||
|
||||
for _, part := range msg.Parts {
|
||||
if err := c.sendMediaPart(ctx, msg.ChatID, part, store); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMediaPart resolves and sends a single media part.
|
||||
func (c *FeishuChannel) sendMediaPart(
|
||||
ctx context.Context,
|
||||
chatID string,
|
||||
part bus.MediaPart,
|
||||
store media.MediaStore,
|
||||
) error {
|
||||
localPath, err := store.Resolve(part.Ref)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to resolve media ref", map[string]any{
|
||||
"ref": part.Ref,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil // skip this part
|
||||
}
|
||||
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to open media file", map[string]any{
|
||||
"path": localPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil // skip this part
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
switch part.Type {
|
||||
case "image":
|
||||
err = c.sendImage(ctx, chatID, file)
|
||||
default:
|
||||
filename := part.Filename
|
||||
if filename == "" {
|
||||
filename = "file"
|
||||
}
|
||||
err = c.sendFile(ctx, chatID, file, filename, part.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to send media", map[string]any{
|
||||
"type": part.Type,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return fmt.Errorf("feishu send media: %w", channels.ErrTemporary)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Inbound message handling ---
|
||||
|
||||
func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil {
|
||||
return nil
|
||||
@@ -151,34 +335,68 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
senderID = "unknown"
|
||||
}
|
||||
|
||||
content := extractFeishuMessageContent(message)
|
||||
messageType := stringValue(message.MessageType)
|
||||
messageID := stringValue(message.MessageId)
|
||||
rawContent := stringValue(message.Content)
|
||||
|
||||
// Check allowlist early to avoid downloading media for rejected senders.
|
||||
// BaseChannel.HandleMessage will check again, but this avoids wasted network I/O.
|
||||
senderInfo := bus.SenderInfo{
|
||||
Platform: "feishu",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
|
||||
}
|
||||
if !c.IsAllowedSender(senderInfo) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content based on message type
|
||||
content := extractContent(messageType, rawContent)
|
||||
|
||||
// Handle media messages (download and store)
|
||||
var mediaRefs []string
|
||||
if store := c.GetMediaStore(); store != nil && messageID != "" {
|
||||
mediaRefs = c.downloadInboundMedia(ctx, chatID, messageID, messageType, rawContent, store)
|
||||
}
|
||||
|
||||
// Append media tags to content (like Telegram does)
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
metadata := map[string]string{}
|
||||
messageID := ""
|
||||
if mid := stringValue(message.MessageId); mid != "" {
|
||||
messageID = mid
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if messageType := stringValue(message.MessageType); messageType != "" {
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
if chatType := stringValue(message.ChatType); chatType != "" {
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
chatType := stringValue(message.ChatType)
|
||||
var peer bus.Peer
|
||||
if chatType == "p2p" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
|
||||
// Check if bot was mentioned
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
|
||||
// Strip mention placeholders from content before group trigger check
|
||||
if len(message.Mentions) > 0 {
|
||||
content = stripMentionPlaceholders(content, message.Mentions)
|
||||
}
|
||||
|
||||
// In group chats, apply unified group trigger filtering
|
||||
respond, cleaned := c.ShouldRespondInGroup(false, content)
|
||||
respond, cleaned := c.ShouldRespondInGroup(isMentioned, content)
|
||||
if !respond {
|
||||
return nil
|
||||
}
|
||||
@@ -186,22 +404,398 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
}
|
||||
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
|
||||
senderInfo := bus.SenderInfo{
|
||||
Platform: "feishu",
|
||||
PlatformID: senderID,
|
||||
CanonicalID: identity.BuildCanonicalID("feishu", senderID),
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// fetchBotOpenID calls the Feishu bot info API to retrieve and store the bot's open_id.
|
||||
func (c *FeishuChannel) fetchBotOpenID(ctx context.Context) error {
|
||||
resp, err := c.client.Do(ctx, &larkcore.ApiReq{
|
||||
HttpMethod: http.MethodGet,
|
||||
ApiPath: "/open-apis/bot/v3/info",
|
||||
SupportedAccessTokenTypes: []larkcore.AccessTokenType{larkcore.AccessTokenTypeTenant},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("bot info request: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsAllowedSender(senderInfo) {
|
||||
return nil
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Bot struct {
|
||||
OpenID string `json:"open_id"`
|
||||
} `json:"bot"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.RawBody, &result); err != nil {
|
||||
return fmt.Errorf("bot info parse: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return fmt.Errorf("bot info api error (code=%d)", result.Code)
|
||||
}
|
||||
if result.Bot.OpenID == "" {
|
||||
return fmt.Errorf("bot info: empty open_id")
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo)
|
||||
c.botOpenID.Store(result.Bot.OpenID)
|
||||
logger.InfoCF("feishu", "Fetched bot open_id from API", map[string]any{
|
||||
"open_id": result.Bot.OpenID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot was @mentioned in the message.
|
||||
func (c *FeishuChannel) isBotMentioned(message *larkim.EventMessage) bool {
|
||||
if message.Mentions == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
knownID, _ := c.botOpenID.Load().(string)
|
||||
if knownID == "" {
|
||||
logger.DebugCF("feishu", "Bot open_id unknown, cannot detect @mention", nil)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, m := range message.Mentions {
|
||||
if m.Id == nil {
|
||||
continue
|
||||
}
|
||||
if m.Id.OpenId != nil && *m.Id.OpenId == knownID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractContent extracts text content from different message types.
|
||||
func extractContent(messageType, rawContent string) string {
|
||||
if rawContent == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeText:
|
||||
var textPayload struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(rawContent), &textPayload); err == nil {
|
||||
return textPayload.Text
|
||||
}
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypePost:
|
||||
// Pass raw JSON to LLM — structured rich text is more informative than flattened plain text
|
||||
return rawContent
|
||||
|
||||
case larkim.MsgTypeImage:
|
||||
// Image messages don't have text content
|
||||
return ""
|
||||
|
||||
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
|
||||
// File/audio/video messages may have a filename
|
||||
name := extractFileName(rawContent)
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return ""
|
||||
|
||||
default:
|
||||
return rawContent
|
||||
}
|
||||
}
|
||||
|
||||
// downloadInboundMedia downloads media from inbound messages and stores in MediaStore.
|
||||
func (c *FeishuChannel) downloadInboundMedia(
|
||||
ctx context.Context,
|
||||
chatID, messageID, messageType, rawContent string,
|
||||
store media.MediaStore,
|
||||
) []string {
|
||||
var refs []string
|
||||
scope := channels.BuildMediaScope("feishu", chatID, messageID)
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
imageKey := extractImageKey(rawContent)
|
||||
if imageKey == "" {
|
||||
return nil
|
||||
}
|
||||
ref := c.downloadResource(ctx, messageID, imageKey, "image", ".jpg", store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
|
||||
case larkim.MsgTypeFile, larkim.MsgTypeAudio, larkim.MsgTypeMedia:
|
||||
fileKey := extractFileKey(rawContent)
|
||||
if fileKey == "" {
|
||||
return nil
|
||||
}
|
||||
// Derive a fallback extension from the message type.
|
||||
var ext string
|
||||
switch messageType {
|
||||
case larkim.MsgTypeAudio:
|
||||
ext = ".ogg"
|
||||
case larkim.MsgTypeMedia:
|
||||
ext = ".mp4"
|
||||
default:
|
||||
ext = "" // generic file — rely on resp.FileName
|
||||
}
|
||||
ref := c.downloadResource(ctx, messageID, fileKey, "file", ext, store, scope)
|
||||
if ref != "" {
|
||||
refs = append(refs, ref)
|
||||
}
|
||||
}
|
||||
|
||||
return refs
|
||||
}
|
||||
|
||||
// downloadResource downloads a message resource (image/file) from Feishu,
|
||||
// writes it to the project media directory, and stores the reference in MediaStore.
|
||||
// fallbackExt (e.g. ".jpg") is appended when the resolved filename has no extension.
|
||||
func (c *FeishuChannel) downloadResource(
|
||||
ctx context.Context,
|
||||
messageID, fileKey, resourceType, fallbackExt string,
|
||||
store media.MediaStore,
|
||||
scope string,
|
||||
) string {
|
||||
req := larkim.NewGetMessageResourceReqBuilder().
|
||||
MessageId(messageID).
|
||||
FileKey(fileKey).
|
||||
Type(resourceType).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.MessageResource.Get(ctx, req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to download resource", map[string]any{
|
||||
"message_id": messageID,
|
||||
"file_key": fileKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
if !resp.Success() {
|
||||
logger.ErrorCF("feishu", "Resource download api error", map[string]any{
|
||||
"code": resp.Code,
|
||||
"msg": resp.Msg,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.File == nil {
|
||||
return ""
|
||||
}
|
||||
// Safely close the underlying reader if it implements io.Closer (e.g. HTTP response body).
|
||||
if closer, ok := resp.File.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
|
||||
filename := resp.FileName
|
||||
if filename == "" {
|
||||
filename = fileKey
|
||||
}
|
||||
// If filename still has no extension, append the fallback (like Telegram's ext parameter).
|
||||
if filepath.Ext(filename) == "" && fallbackExt != "" {
|
||||
filename += fallbackExt
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
"error": mkdirErr.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
ext := filepath.Ext(filename)
|
||||
localPath := filepath.Join(mediaDir, utils.SanitizeFilename(messageID+"-"+fileKey+ext))
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create local file for resource", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if _, copyErr := io.Copy(out, resp.File); copyErr != nil {
|
||||
out.Close()
|
||||
os.Remove(localPath)
|
||||
logger.ErrorCF("feishu", "Failed to write resource to file", map[string]any{
|
||||
"error": copyErr.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
out.Close()
|
||||
|
||||
ref, err := store.Store(localPath, media.MediaMeta{
|
||||
Filename: filename,
|
||||
Source: "feishu",
|
||||
}, scope)
|
||||
if err != nil {
|
||||
logger.ErrorCF("feishu", "Failed to store downloaded resource", map[string]any{
|
||||
"file_key": fileKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
os.Remove(localPath)
|
||||
return ""
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
// appendMediaTags appends media type tags to content (like Telegram's "[image: photo]").
|
||||
func appendMediaTags(content, messageType string, mediaRefs []string) string {
|
||||
if len(mediaRefs) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
var tag string
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
tag = "[image: photo]"
|
||||
case larkim.MsgTypeAudio:
|
||||
tag = "[audio]"
|
||||
case larkim.MsgTypeMedia:
|
||||
tag = "[video]"
|
||||
case larkim.MsgTypeFile:
|
||||
tag = "[file]"
|
||||
default:
|
||||
tag = "[attachment]"
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
return tag
|
||||
}
|
||||
return content + " " + tag
|
||||
}
|
||||
|
||||
// sendCard sends an interactive card message to a chat.
|
||||
func (c *FeishuChannel) sendCard(ctx context.Context, chatID, cardContent string) error {
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeInteractive).
|
||||
Content(cardContent).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu send card: %w", channels.ErrTemporary)
|
||||
}
|
||||
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary)
|
||||
}
|
||||
|
||||
logger.DebugCF("feishu", "Feishu card message sent", map[string]any{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendImage uploads an image and sends it as a message.
|
||||
func (c *FeishuChannel) sendImage(ctx context.Context, chatID string, file *os.File) error {
|
||||
// Upload image to get image_key
|
||||
uploadReq := larkim.NewCreateImageReqBuilder().
|
||||
Body(larkim.NewCreateImageReqBodyBuilder().
|
||||
ImageType("message").
|
||||
Image(file).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
uploadResp, err := c.client.Im.V1.Image.Create(ctx, uploadReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu image upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
return fmt.Errorf("feishu image upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.ImageKey == nil {
|
||||
return fmt.Errorf("feishu image upload: no image_key returned")
|
||||
}
|
||||
|
||||
imageKey := *uploadResp.Data.ImageKey
|
||||
|
||||
// Send image message
|
||||
content, _ := json.Marshal(map[string]string{"image_key": imageKey})
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeImage).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu image send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu image send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendFile uploads a file and sends it as a message.
|
||||
func (c *FeishuChannel) sendFile(ctx context.Context, chatID string, file *os.File, filename, fileType string) error {
|
||||
// Map part type to Feishu file type
|
||||
feishuFileType := "stream"
|
||||
switch fileType {
|
||||
case "audio":
|
||||
feishuFileType = "opus"
|
||||
case "video":
|
||||
feishuFileType = "mp4"
|
||||
}
|
||||
|
||||
// Upload file to get file_key
|
||||
uploadReq := larkim.NewCreateFileReqBuilder().
|
||||
Body(larkim.NewCreateFileReqBodyBuilder().
|
||||
FileType(feishuFileType).
|
||||
FileName(filename).
|
||||
File(file).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
uploadResp, err := c.client.Im.V1.File.Create(ctx, uploadReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu file upload: %w", err)
|
||||
}
|
||||
if !uploadResp.Success() {
|
||||
return fmt.Errorf("feishu file upload api error (code=%d msg=%s)", uploadResp.Code, uploadResp.Msg)
|
||||
}
|
||||
if uploadResp.Data == nil || uploadResp.Data.FileKey == nil {
|
||||
return fmt.Errorf("feishu file upload: no file_key returned")
|
||||
}
|
||||
|
||||
fileKey := *uploadResp.Data.FileKey
|
||||
|
||||
// Send file message
|
||||
content, _ := json.Marshal(map[string]string{"file_key": fileKey})
|
||||
req := larkim.NewCreateMessageReqBuilder().
|
||||
ReceiveIdType(larkim.ReceiveIdTypeChatId).
|
||||
Body(larkim.NewCreateMessageReqBodyBuilder().
|
||||
ReceiveId(chatID).
|
||||
MsgType(larkim.MsgTypeFile).
|
||||
Content(string(content)).
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Create(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("feishu file send: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
return fmt.Errorf("feishu file send api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -222,20 +816,3 @@ func extractFeishuSenderID(sender *larkim.EventSender) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractFeishuMessageContent(message *larkim.EventMessage) string {
|
||||
if message == nil || message.Content == nil || *message.Content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText {
|
||||
var textPayload struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil {
|
||||
return textPayload.Text
|
||||
}
|
||||
}
|
||||
|
||||
return *message.Content
|
||||
}
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestExtractContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messageType string
|
||||
rawContent string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "text message",
|
||||
messageType: "text",
|
||||
rawContent: `{"text": "hello world"}`,
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "text message invalid JSON",
|
||||
messageType: "text",
|
||||
rawContent: `not json`,
|
||||
want: "not json",
|
||||
},
|
||||
{
|
||||
name: "post message returns raw JSON",
|
||||
messageType: "post",
|
||||
rawContent: `{"title": "test post"}`,
|
||||
want: `{"title": "test post"}`,
|
||||
},
|
||||
{
|
||||
name: "image message returns empty",
|
||||
messageType: "image",
|
||||
rawContent: `{"image_key": "img_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "file message with filename",
|
||||
messageType: "file",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "report.pdf"}`,
|
||||
want: "report.pdf",
|
||||
},
|
||||
{
|
||||
name: "file message without filename",
|
||||
messageType: "file",
|
||||
rawContent: `{"file_key": "file_xxx"}`,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "audio message with filename",
|
||||
messageType: "audio",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "recording.ogg"}`,
|
||||
want: "recording.ogg",
|
||||
},
|
||||
{
|
||||
name: "media message with filename",
|
||||
messageType: "media",
|
||||
rawContent: `{"file_key": "file_xxx", "file_name": "video.mp4"}`,
|
||||
want: "video.mp4",
|
||||
},
|
||||
{
|
||||
name: "unknown message type returns raw",
|
||||
messageType: "sticker",
|
||||
rawContent: `{"sticker_id": "sticker_xxx"}`,
|
||||
want: `{"sticker_id": "sticker_xxx"}`,
|
||||
},
|
||||
{
|
||||
name: "empty raw content",
|
||||
messageType: "text",
|
||||
rawContent: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractContent(tt.messageType, tt.rawContent)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractContent(%q, %q) = %q, want %q", tt.messageType, tt.rawContent, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendMediaTags(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
messageType string
|
||||
mediaRefs []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no refs returns content unchanged",
|
||||
content: "hello",
|
||||
messageType: "image",
|
||||
mediaRefs: nil,
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "empty refs returns content unchanged",
|
||||
content: "hello",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{},
|
||||
want: "hello",
|
||||
},
|
||||
{
|
||||
name: "image with content",
|
||||
content: "check this",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "check this [image: photo]",
|
||||
},
|
||||
{
|
||||
name: "image empty content",
|
||||
content: "",
|
||||
messageType: "image",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "[image: photo]",
|
||||
},
|
||||
{
|
||||
name: "audio",
|
||||
content: "listen",
|
||||
messageType: "audio",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "listen [audio]",
|
||||
},
|
||||
{
|
||||
name: "media/video",
|
||||
content: "watch",
|
||||
messageType: "media",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "watch [video]",
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
content: "report.pdf",
|
||||
messageType: "file",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "report.pdf [file]",
|
||||
},
|
||||
{
|
||||
name: "unknown type",
|
||||
content: "something",
|
||||
messageType: "sticker",
|
||||
mediaRefs: []string{"ref1"},
|
||||
want: "something [attachment]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := appendMediaTags(tt.content, tt.messageType, tt.mediaRefs)
|
||||
if got != tt.want {
|
||||
t.Errorf(
|
||||
"appendMediaTags(%q, %q, %v) = %q, want %q",
|
||||
tt.content,
|
||||
tt.messageType,
|
||||
tt.mediaRefs,
|
||||
got,
|
||||
tt.want,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFeishuSenderID(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sender *larkim.EventSender
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil sender",
|
||||
sender: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil sender ID",
|
||||
sender: &larkim.EventSender{SenderId: nil},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "userId preferred",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr("u_abc123"),
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "u_abc123",
|
||||
},
|
||||
{
|
||||
name: "openId fallback",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "ou_def456",
|
||||
},
|
||||
{
|
||||
name: "unionId fallback",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr(""),
|
||||
UnionId: strPtr("on_ghi789"),
|
||||
},
|
||||
},
|
||||
want: "on_ghi789",
|
||||
},
|
||||
{
|
||||
name: "all empty strings",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: strPtr(""),
|
||||
OpenId: strPtr(""),
|
||||
UnionId: strPtr(""),
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "nil userId pointer falls through",
|
||||
sender: &larkim.EventSender{
|
||||
SenderId: &larkim.UserId{
|
||||
UserId: nil,
|
||||
OpenId: strPtr("ou_def456"),
|
||||
UnionId: nil,
|
||||
},
|
||||
},
|
||||
want: "ou_def456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractFeishuSenderID(tt.sender)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractFeishuSenderID() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
@@ -41,7 +41,7 @@ var (
|
||||
type TelegramChannel struct {
|
||||
*channels.BaseChannel
|
||||
bot *telego.Bot
|
||||
bh *telegohandler.BotHandler
|
||||
bh *th.BotHandler
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
@@ -72,6 +72,10 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann
|
||||
}))
|
||||
}
|
||||
|
||||
if baseURL := strings.TrimRight(strings.TrimSpace(telegramCfg.BaseURL), "/"); baseURL != "" {
|
||||
opts = append(opts, telego.WithAPIServer(baseURL))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
@@ -101,6 +105,12 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.initBotCommands(c.ctx); err != nil {
|
||||
logger.WarnCF("telegram", "Failed to initialize bot commands", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
@@ -109,20 +119,19 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
bh, err := th.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
c.cancel()
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
c.bh = bh
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Help(ctx, message)
|
||||
}, th.CommandEqual("help"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
@@ -141,7 +150,13 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
go func() {
|
||||
if err = bh.Start(); err != nil {
|
||||
logger.ErrorCF("telegram", "Bot handler failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -152,7 +167,7 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
|
||||
// Stop the bot handler
|
||||
if c.bh != nil {
|
||||
c.bh.Stop()
|
||||
_ = c.bh.StopWithContext(ctx)
|
||||
}
|
||||
|
||||
// Cancel our context (stops long polling)
|
||||
@@ -163,6 +178,51 @@ func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) initBotCommands(ctx context.Context) error {
|
||||
currentCommands, err := c.bot.GetMyCommands(ctx, &telego.GetMyCommandsParams{
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("get commands: %w", err)
|
||||
}
|
||||
|
||||
commands := []telego.BotCommand{
|
||||
{
|
||||
Command: "start",
|
||||
Description: "Start the bot",
|
||||
},
|
||||
{
|
||||
Command: "help",
|
||||
Description: "Show a help message",
|
||||
},
|
||||
{
|
||||
Command: "show",
|
||||
Description: "Show current configuration",
|
||||
},
|
||||
{
|
||||
Command: "list",
|
||||
Description: "List available options",
|
||||
},
|
||||
}
|
||||
|
||||
// Setting commands on each start will hit the rate limit very quickly, that's why we check if an update is needed
|
||||
if !slices.Equal(currentCommands, commands) {
|
||||
logger.InfoC("telegram", "Updating bot commands")
|
||||
|
||||
err = c.bot.SetMyCommands(ctx, &telego.SetMyCommandsParams{
|
||||
Commands: commands,
|
||||
Scope: tu.ScopeDefault(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set commands: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.DebugC("telegram", "Bot commands are up to date")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return channels.ErrNotRunning
|
||||
|
||||
@@ -38,8 +38,7 @@ type WeComAppChannel struct {
|
||||
tokenMu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComXMLMessage represents the XML message structure from WeCom
|
||||
@@ -144,7 +143,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -607,23 +606,12 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
// As per WeCom documentation, use msg_id for deduplication
|
||||
msgID := fmt.Sprintf("%d", msg.MsgId)
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
// Clean up old messages while still holding the lock to avoid a data race
|
||||
// on len(). Reset the map but re-insert the current msgID so it remains
|
||||
// deduplicated.
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.processedMsgs[msgID] = true
|
||||
}
|
||||
c.msgMu.Unlock()
|
||||
|
||||
senderID := msg.FromUserName
|
||||
chatID := senderID // WeCom App uses user ID as chat ID for direct messages
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -28,8 +27,7 @@ type WeComBotChannel struct {
|
||||
client *http.Client
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
processedMsgs map[string]bool // Message deduplication: msg_id -> processed
|
||||
msgMu sync.RWMutex
|
||||
processedMsgs *MessageDeduplicator
|
||||
}
|
||||
|
||||
// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
|
||||
@@ -108,7 +106,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We
|
||||
client: &http.Client{Timeout: clientTimeout},
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
processedMsgs: make(map[string]bool),
|
||||
processedMsgs: NewMessageDeduplicator(wecomMaxProcessedMessages),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -330,23 +328,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag
|
||||
|
||||
// Message deduplication: Use msg_id to prevent duplicate processing
|
||||
msgID := msg.MsgID
|
||||
c.msgMu.Lock()
|
||||
if c.processedMsgs[msgID] {
|
||||
c.msgMu.Unlock()
|
||||
if !c.processedMsgs.MarkMessageProcessed(msgID) {
|
||||
logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
|
||||
"msg_id": msgID,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.processedMsgs[msgID] = true
|
||||
// Clean up old messages while still holding the lock to avoid a data race
|
||||
// on len(). Reset the map but re-insert the current msgID so it remains
|
||||
// deduplicated.
|
||||
if len(c.processedMsgs) > 1000 {
|
||||
c.processedMsgs = make(map[string]bool)
|
||||
c.processedMsgs[msgID] = true
|
||||
}
|
||||
c.msgMu.Unlock()
|
||||
|
||||
senderID := msg.From.UserID
|
||||
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package wecom
|
||||
|
||||
import "sync"
|
||||
|
||||
const wecomMaxProcessedMessages = 1000
|
||||
|
||||
// MessageDeduplicator provides thread-safe message deduplication using a circular queue (ring buffer)
|
||||
// combined with a hash map. This ensures fast O(1) lookups while naturally evicting the oldest
|
||||
// messages without causing "amnesia cliffs" when the limit is reached.
|
||||
type MessageDeduplicator struct {
|
||||
mu sync.Mutex
|
||||
msgs map[string]bool
|
||||
ring []string
|
||||
idx int
|
||||
max int
|
||||
}
|
||||
|
||||
// NewMessageDeduplicator creates a new deduplicator with the specified capacity.
|
||||
func NewMessageDeduplicator(maxEntries int) *MessageDeduplicator {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = wecomMaxProcessedMessages
|
||||
}
|
||||
return &MessageDeduplicator{
|
||||
msgs: make(map[string]bool, maxEntries),
|
||||
ring: make([]string, maxEntries),
|
||||
max: maxEntries,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkMessageProcessed marks msgID as processed and returns false for duplicates.
|
||||
func (d *MessageDeduplicator) MarkMessageProcessed(msgID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 1. Check for duplicate
|
||||
if d.msgs[msgID] {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. Evict the oldest message at our current ring position (if any)
|
||||
oldestID := d.ring[d.idx]
|
||||
if oldestID != "" {
|
||||
delete(d.msgs, oldestID)
|
||||
}
|
||||
|
||||
// 3. Store the new message
|
||||
d.msgs[msgID] = true
|
||||
d.ring[d.idx] = msgID
|
||||
|
||||
// 4. Advance the circle queue index
|
||||
d.idx = (d.idx + 1) % d.max
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageDeduplicator_DuplicateDetection(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("first message should be accepted")
|
||||
}
|
||||
|
||||
if ok := d.MarkMessageProcessed("msg-1"); ok {
|
||||
t.Fatalf("duplicate message should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_ConcurrentSameMessage(t *testing.T) {
|
||||
d := NewMessageDeduplicator(wecomMaxProcessedMessages)
|
||||
|
||||
const goroutines = 64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
results := make(chan bool, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results <- d.MarkMessageProcessed("msg-concurrent")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successes := 0
|
||||
for ok := range results {
|
||||
if ok {
|
||||
successes++
|
||||
}
|
||||
}
|
||||
|
||||
if successes != 1 {
|
||||
t.Fatalf("expected exactly 1 successful mark, got %d", successes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageDeduplicator_CircularQueueEviction(t *testing.T) {
|
||||
// Create a deduplicator with a very small capacity to test eviction easily.
|
||||
capacity := 3
|
||||
d := NewMessageDeduplicator(capacity)
|
||||
|
||||
// Fill the queue.
|
||||
d.MarkMessageProcessed("msg-1")
|
||||
d.MarkMessageProcessed("msg-2")
|
||||
d.MarkMessageProcessed("msg-3")
|
||||
|
||||
// At this point, the queue is full. msg-1 is the oldest.
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to be 3, got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// This should evict msg-1 and add msg-4.
|
||||
if ok := d.MarkMessageProcessed("msg-4"); !ok {
|
||||
t.Fatalf("msg-4 should be accepted")
|
||||
}
|
||||
|
||||
if len(d.msgs) != 3 {
|
||||
t.Fatalf("expected map size to remain at max capacity (3), got %d", len(d.msgs))
|
||||
}
|
||||
|
||||
// msg-1 should now be forgotten (evicted).
|
||||
if ok := d.MarkMessageProcessed("msg-1"); !ok {
|
||||
t.Fatalf("msg-1 should be accepted again because it was evicted")
|
||||
}
|
||||
|
||||
// msg-2 should have been evicted when we added msg-1 back.
|
||||
if ok := d.MarkMessageProcessed("msg-2"); !ok {
|
||||
t.Fatalf("msg-2 should be accepted again because it was evicted")
|
||||
}
|
||||
}
|
||||
@@ -180,6 +180,16 @@ type AgentDefaults struct {
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
|
||||
}
|
||||
|
||||
const DefaultMaxMediaSize = 20 * 1024 * 1024 // 20 MB
|
||||
|
||||
func (d *AgentDefaults) GetMaxMediaSize() int {
|
||||
if d.MaxMediaSize > 0 {
|
||||
return d.MaxMediaSize
|
||||
}
|
||||
return DefaultMaxMediaSize
|
||||
}
|
||||
|
||||
// GetModelName returns the effective model name for the agent defaults.
|
||||
@@ -237,6 +247,7 @@ type WhatsAppConfig struct {
|
||||
type TelegramConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
|
||||
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
|
||||
BaseURL string `json:"base_url" env:"PICOCLAW_CHANNELS_TELEGRAM_BASE_URL"`
|
||||
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
@@ -253,6 +264,7 @@ type FeishuConfig struct {
|
||||
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
|
||||
GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"`
|
||||
Placeholder PlaceholderConfig `json:"placeholder,omitempty"`
|
||||
ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"`
|
||||
}
|
||||
|
||||
@@ -399,6 +411,7 @@ type DevicesConfig struct {
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI OpenAIProviderConfig `json:"openai"`
|
||||
LiteLLM ProviderConfig `json:"litellm"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
@@ -423,6 +436,7 @@ type ProvidersConfig struct {
|
||||
func (p ProvidersConfig) IsEmpty() bool {
|
||||
return p.Anthropic.APIKey == "" && p.Anthropic.APIBase == "" &&
|
||||
p.OpenAI.APIKey == "" && p.OpenAI.APIBase == "" &&
|
||||
p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" &&
|
||||
p.OpenRouter.APIKey == "" && p.OpenRouter.APIBase == "" &&
|
||||
p.Groq.APIKey == "" && p.Groq.APIBase == "" &&
|
||||
p.Zhipu.APIKey == "" && p.Zhipu.APIBase == "" &&
|
||||
@@ -567,6 +581,7 @@ type ToolsConfig struct {
|
||||
Exec ExecConfig `json:"exec"`
|
||||
Skills SkillsToolsConfig `json:"skills"`
|
||||
MediaCleanup MediaCleanupConfig `json:"media_cleanup"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
type SkillsToolsConfig struct {
|
||||
@@ -596,6 +611,34 @@ type ClawHubRegistryConfig struct {
|
||||
MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines configuration for a single MCP server
|
||||
type MCPServerConfig struct {
|
||||
// Enabled indicates whether this MCP server is active
|
||||
Enabled bool `json:"enabled"`
|
||||
// Command is the executable to run (e.g., "npx", "python", "/path/to/server")
|
||||
Command string `json:"command"`
|
||||
// Args are the arguments to pass to the command
|
||||
Args []string `json:"args,omitempty"`
|
||||
// Env are environment variables to set for the server process (stdio only)
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
// EnvFile is the path to a file containing environment variables (stdio only)
|
||||
EnvFile string `json:"env_file,omitempty"`
|
||||
// Type is "stdio", "sse", or "http" (default: stdio if command is set, sse if url is set)
|
||||
Type string `json:"type,omitempty"`
|
||||
// URL is used for SSE/HTTP transport
|
||||
URL string `json:"url,omitempty"`
|
||||
// Headers are HTTP headers to send with requests (sse/http only)
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// MCPConfig defines configuration for all MCP servers
|
||||
type MCPConfig struct {
|
||||
// Enabled globally enables/disables MCP integration
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
@@ -361,6 +361,10 @@ func DefaultConfig() *Config {
|
||||
TTLSeconds: 300,
|
||||
},
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -88,6 +88,23 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"litellm"},
|
||||
protocol: "litellm",
|
||||
buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
|
||||
if p.LiteLLM.APIKey == "" && p.LiteLLM.APIBase == "" {
|
||||
return ModelConfig{}, false
|
||||
}
|
||||
return ModelConfig{
|
||||
ModelName: "litellm",
|
||||
Model: "litellm/auto",
|
||||
APIKey: p.LiteLLM.APIKey,
|
||||
APIBase: p.LiteLLM.APIBase,
|
||||
Proxy: p.LiteLLM.Proxy,
|
||||
RequestTimeout: p.LiteLLM.RequestTimeout,
|
||||
}, true
|
||||
},
|
||||
},
|
||||
{
|
||||
providerNames: []string{"openrouter"},
|
||||
protocol: "openrouter",
|
||||
|
||||
@@ -63,6 +63,33 @@ func TestConvertProvidersToModelList_Anthropic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_LiteLLM(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
LiteLLM: ProviderConfig{
|
||||
APIKey: "litellm-key",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ConvertProvidersToModelList(cfg)
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("len(result) = %d, want 1", len(result))
|
||||
}
|
||||
|
||||
if result[0].ModelName != "litellm" {
|
||||
t.Errorf("ModelName = %q, want %q", result[0].ModelName, "litellm")
|
||||
}
|
||||
if result[0].Model != "litellm/auto" {
|
||||
t.Errorf("Model = %q, want %q", result[0].Model, "litellm/auto")
|
||||
}
|
||||
if result[0].APIBase != "http://localhost:4000/v1" {
|
||||
t.Errorf("APIBase = %q, want %q", result[0].APIBase, "http://localhost:4000/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertProvidersToModelList_Multiple(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
@@ -115,6 +142,7 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Providers: ProvidersConfig{
|
||||
OpenAI: OpenAIProviderConfig{ProviderConfig: ProviderConfig{APIKey: "key1"}},
|
||||
LiteLLM: ProviderConfig{APIKey: "key-litellm", APIBase: "http://localhost:4000/v1"},
|
||||
Anthropic: ProviderConfig{APIKey: "key2"},
|
||||
OpenRouter: ProviderConfig{APIKey: "key3"},
|
||||
Groq: ProviderConfig{APIKey: "key4"},
|
||||
|
||||
@@ -0,0 +1,532 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// headerTransport is an http.RoundTripper that adds custom headers to requests
|
||||
type headerTransport struct {
|
||||
base http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range t.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Use the base transport
|
||||
base := t.base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// loadEnvFile loads environment variables from a file in .env format
|
||||
// Each line should be in the format: KEY=value
|
||||
// Lines starting with # are comments
|
||||
// Empty lines are ignored
|
||||
func loadEnvFile(path string) (map[string]string, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open env file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
envVars := make(map[string]string)
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid format at line %d: %s", lineNum, line)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("invalid format at line %d: empty key", lineNum)
|
||||
}
|
||||
|
||||
// Remove surrounding quotes if present
|
||||
if len(value) >= 2 {
|
||||
if (value[0] == '"' && value[len(value)-1] == '"') ||
|
||||
(value[0] == '\'' && value[len(value)-1] == '\'') {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
envVars[key] = value
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading env file: %w", err)
|
||||
}
|
||||
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
}
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
func (m *Manager) LoadFromConfig(ctx context.Context, cfg *config.Config) error {
|
||||
return m.LoadFromMCPConfig(ctx, cfg.Tools.MCP, cfg.WorkspacePath())
|
||||
}
|
||||
|
||||
// LoadFromMCPConfig loads MCP servers from MCP configuration and workspace path.
|
||||
// This is the minimal dependency version that doesn't require the full Config object.
|
||||
func (m *Manager) LoadFromMCPConfig(
|
||||
ctx context.Context,
|
||||
mcpCfg config.MCPConfig,
|
||||
workspacePath string,
|
||||
) error {
|
||||
if !mcpCfg.Enabled {
|
||||
logger.InfoCF("mcp", "MCP integration is disabled", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(mcpCfg.Servers) == 0 {
|
||||
logger.InfoCF("mcp", "No MCP servers configured", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Initializing MCP servers",
|
||||
map[string]any{
|
||||
"count": len(mcpCfg.Servers),
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, len(mcpCfg.Servers))
|
||||
enabledCount := 0
|
||||
|
||||
for name, serverCfg := range mcpCfg.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
logger.DebugCF("mcp", "Skipping disabled server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
enabledCount++
|
||||
wg.Add(1)
|
||||
go func(name string, serverCfg config.MCPServerConfig, workspace string) {
|
||||
defer wg.Done()
|
||||
|
||||
// Resolve relative envFile paths relative to workspace
|
||||
if serverCfg.EnvFile != "" && !filepath.IsAbs(serverCfg.EnvFile) {
|
||||
if workspace == "" {
|
||||
err := fmt.Errorf(
|
||||
"workspace path is empty while resolving relative envFile %q for server %s",
|
||||
serverCfg.EnvFile,
|
||||
name,
|
||||
)
|
||||
logger.ErrorCF("mcp", "Invalid MCP server configuration",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"env_file": serverCfg.EnvFile,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- err
|
||||
return
|
||||
}
|
||||
serverCfg.EnvFile = filepath.Join(workspace, serverCfg.EnvFile)
|
||||
}
|
||||
|
||||
if err := m.ConnectServer(ctx, name, serverCfg); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to connect to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs <- fmt.Errorf("failed to connect to server %s: %w", name, err)
|
||||
}
|
||||
}(name, serverCfg, workspacePath)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
// Collect errors
|
||||
var allErrors []error
|
||||
for err := range errs {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
||||
connectedCount := len(m.GetServers())
|
||||
|
||||
// If all enabled servers failed to connect, return aggregated error
|
||||
if enabledCount > 0 && connectedCount == 0 {
|
||||
logger.ErrorCF("mcp", "All MCP servers failed to connect",
|
||||
map[string]any{
|
||||
"failed": len(allErrors),
|
||||
"total": enabledCount,
|
||||
})
|
||||
return errors.Join(allErrors...)
|
||||
}
|
||||
|
||||
if len(allErrors) > 0 {
|
||||
logger.WarnCF("mcp", "Some MCP servers failed to connect",
|
||||
map[string]any{
|
||||
"failed": len(allErrors),
|
||||
"connected": connectedCount,
|
||||
"total": enabledCount,
|
||||
})
|
||||
// Don't fail completely if some servers successfully connected
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "MCP server initialization complete",
|
||||
map[string]any{
|
||||
"connected": connectedCount,
|
||||
"total": enabledCount,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectServer connects to a single MCP server
|
||||
func (m *Manager) ConnectServer(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) error {
|
||||
logger.InfoCF("mcp", "Connecting to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
"args_count": len(cfg.Args),
|
||||
})
|
||||
|
||||
// Create client
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: "picoclaw",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
// Create transport based on configuration
|
||||
// Auto-detect transport type if not explicitly specified
|
||||
var transport mcp.Transport
|
||||
transportType := cfg.Type
|
||||
|
||||
// Auto-detect: if URL is provided, use SSE; if command is provided, use stdio
|
||||
if transportType == "" {
|
||||
if cfg.URL != "" {
|
||||
transportType = "sse"
|
||||
} else if cfg.Command != "" {
|
||||
transportType = "stdio"
|
||||
} else {
|
||||
return fmt.Errorf("either URL or command must be provided")
|
||||
}
|
||||
}
|
||||
|
||||
switch transportType {
|
||||
case "sse", "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("URL is required for SSE/HTTP transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using SSE/HTTP transport",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"url": cfg.URL,
|
||||
})
|
||||
|
||||
sseTransport := &mcp.StreamableClientTransport{
|
||||
Endpoint: cfg.URL,
|
||||
}
|
||||
|
||||
// Add custom headers if provided
|
||||
if len(cfg.Headers) > 0 {
|
||||
// Create a custom HTTP client with header-injecting transport
|
||||
sseTransport.HTTPClient = &http.Client{
|
||||
Transport: &headerTransport{
|
||||
base: http.DefaultTransport,
|
||||
headers: cfg.Headers,
|
||||
},
|
||||
}
|
||||
logger.DebugCF("mcp", "Added custom HTTP headers",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"header_count": len(cfg.Headers),
|
||||
})
|
||||
}
|
||||
|
||||
transport = sseTransport
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("command is required for stdio transport")
|
||||
}
|
||||
logger.DebugCF("mcp", "Using stdio transport",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"command": cfg.Command,
|
||||
})
|
||||
// Create command with context
|
||||
cmd := exec.CommandContext(ctx, cfg.Command, cfg.Args...)
|
||||
|
||||
// Build environment variables with proper override semantics
|
||||
// Use a map to ensure config variables override file variables
|
||||
envMap := make(map[string]string)
|
||||
|
||||
// Start with parent process environment
|
||||
for _, e := range cmd.Environ() {
|
||||
if idx := strings.Index(e, "="); idx > 0 {
|
||||
envMap[e[:idx]] = e[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Load environment variables from file if specified
|
||||
if cfg.EnvFile != "" {
|
||||
envVars, err := loadEnvFile(cfg.EnvFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load env file %s: %w", cfg.EnvFile, err)
|
||||
}
|
||||
for k, v := range envVars {
|
||||
envMap[k] = v
|
||||
}
|
||||
logger.DebugCF("mcp", "Loaded environment variables from file",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"envFile": cfg.EnvFile,
|
||||
"var_count": len(envVars),
|
||||
})
|
||||
}
|
||||
|
||||
// Environment variables from config override those from file
|
||||
for k, v := range cfg.Env {
|
||||
envMap[k] = v
|
||||
}
|
||||
|
||||
// Convert map to slice
|
||||
env := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
cmd.Env = env
|
||||
|
||||
transport = &mcp.CommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"unsupported transport type: %s (supported: stdio, sse, http)",
|
||||
transportType,
|
||||
)
|
||||
}
|
||||
|
||||
// Connect to server
|
||||
session, err := client.Connect(ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// Get server info
|
||||
initResult := session.InitializeResult()
|
||||
logger.InfoCF("mcp", "Connected to MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"serverName": initResult.ServerInfo.Name,
|
||||
"serverVersion": initResult.ServerInfo.Version,
|
||||
"protocol": initResult.ProtocolVersion,
|
||||
})
|
||||
|
||||
// List available tools if supported
|
||||
var tools []*mcp.Tool
|
||||
if initResult.Capabilities.Tools != nil {
|
||||
for tool, err := range session.Tools(ctx, nil) {
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "Error listing tool",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
logger.InfoCF("mcp", "Listed tools from MCP server",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"toolCount": len(tools),
|
||||
})
|
||||
}
|
||||
|
||||
// Store connection
|
||||
m.mu.Lock()
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: tools,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServers returns all connected servers
|
||||
func (m *Manager) GetServers() map[string]*ServerConnection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ServerConnection, len(m.servers))
|
||||
for k, v := range m.servers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetServer returns a specific server connection
|
||||
func (m *Manager) GetServer(name string) (*ServerConnection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
conn, ok := m.servers[name]
|
||||
return conn, ok
|
||||
}
|
||||
|
||||
// CallTool calls a tool on a specific server
|
||||
func (m *Manager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
// Check if closed before acquiring lock (fast path)
|
||||
if m.closed.Load() {
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
// Double-check after acquiring lock to prevent TOCTOU race
|
||||
if m.closed.Load() {
|
||||
m.mu.RUnlock()
|
||||
return nil, fmt.Errorf("manager is closed")
|
||||
}
|
||||
conn, ok := m.servers[serverName]
|
||||
if ok {
|
||||
m.wg.Add(1) // Add to WaitGroup while holding the lock
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server %s not found", serverName)
|
||||
}
|
||||
defer m.wg.Done()
|
||||
|
||||
params := &mcp.CallToolParams{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
result, err := conn.Session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call tool: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close closes all server connections
|
||||
func (m *Manager) Close() error {
|
||||
// Use Swap to atomically set closed=true and get the previous value
|
||||
// This prevents TOCTOU race with CallTool's closed check
|
||||
if m.closed.Swap(true) {
|
||||
return nil // already closed
|
||||
}
|
||||
|
||||
// Wait for all in-flight CallTool calls to finish before closing sessions
|
||||
// After closed=true is set, no new CallTool can start (they check closed first)
|
||||
m.wg.Wait()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
logger.InfoCF("mcp", "Closing all MCP server connections",
|
||||
map[string]any{
|
||||
"count": len(m.servers),
|
||||
})
|
||||
|
||||
var errs []error
|
||||
for name, conn := range m.servers {
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
logger.ErrorCF("mcp", "Failed to close server connection",
|
||||
map[string]any{
|
||||
"server": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
errs = append(errs, fmt.Errorf("server %s: %w", name, err))
|
||||
}
|
||||
}
|
||||
|
||||
m.servers = make(map[string]*ServerConnection)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("failed to close %d server(s): %w", len(errs), errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllTools returns all tools from all connected servers
|
||||
func (m *Manager) GetAllTools() map[string][]*mcp.Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string][]*mcp.Tool)
|
||||
for name, conn := range m.servers {
|
||||
if len(conn.Tools) > 0 {
|
||||
result[name] = conn.Tools
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "basic env file",
|
||||
content: `API_KEY=secret123
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with comments and empty lines",
|
||||
content: `# This is a comment
|
||||
API_KEY=secret123
|
||||
|
||||
# Another comment
|
||||
DATABASE_URL=postgres://localhost/db
|
||||
|
||||
PORT=8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with quoted values",
|
||||
content: `API_KEY="secret with spaces"
|
||||
NAME='single quoted'
|
||||
PLAIN=no-quotes`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret with spaces",
|
||||
"NAME": "single quoted",
|
||||
"PLAIN": "no-quotes",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "with spaces around equals",
|
||||
content: `API_KEY = secret123
|
||||
DATABASE_URL= postgres://localhost/db
|
||||
PORT =8080`,
|
||||
expected: map[string]string{
|
||||
"API_KEY": "secret123",
|
||||
"DATABASE_URL": "postgres://localhost/db",
|
||||
"PORT": "8080",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format - no equals",
|
||||
content: `INVALID_LINE`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
content: ``,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "only comments",
|
||||
content: `# Comment 1
|
||||
# Comment 2`,
|
||||
expected: map[string]string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(tt.content), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
result, err := loadEnvFile(envFile)
|
||||
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d variables, got %d", len(tt.expected), len(result))
|
||||
}
|
||||
|
||||
for key, expectedValue := range tt.expected {
|
||||
if actualValue, ok := result[key]; !ok {
|
||||
t.Errorf("Expected key %s not found", key)
|
||||
} else if actualValue != expectedValue {
|
||||
t.Errorf("For key %s: expected %q, got %q", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadEnvFileNotFound(t *testing.T) {
|
||||
_, err := loadEnvFile("/nonexistent/file.env")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvFilePriority(t *testing.T) {
|
||||
// Create a temporary .env file
|
||||
tmpDir := t.TempDir()
|
||||
envFile := filepath.Join(tmpDir, ".env")
|
||||
|
||||
envContent := `API_KEY=from_file
|
||||
DATABASE_URL=from_file
|
||||
SHARED_VAR=from_file`
|
||||
|
||||
if err := os.WriteFile(envFile, []byte(envContent), 0o644); err != nil {
|
||||
t.Fatalf("Failed to create .env file: %v", err)
|
||||
}
|
||||
|
||||
// Load envFile
|
||||
envVars, err := loadEnvFile(envFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load env file: %v", err)
|
||||
}
|
||||
|
||||
// Verify envFile variables
|
||||
if envVars["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", envVars["API_KEY"])
|
||||
}
|
||||
|
||||
// Simulate config.Env overriding envFile
|
||||
configEnv := map[string]string{
|
||||
"SHARED_VAR": "from_config",
|
||||
"NEW_VAR": "from_config",
|
||||
}
|
||||
|
||||
// Merge: envFile first, then config overrides
|
||||
merged := make(map[string]string)
|
||||
for k, v := range envVars {
|
||||
merged[k] = v
|
||||
}
|
||||
for k, v := range configEnv {
|
||||
merged[k] = v
|
||||
}
|
||||
|
||||
// Verify priority: config.Env should override envFile
|
||||
if merged["SHARED_VAR"] != "from_config" {
|
||||
t.Errorf(
|
||||
"Expected SHARED_VAR=from_config (config should override file), got %s",
|
||||
merged["SHARED_VAR"],
|
||||
)
|
||||
}
|
||||
if merged["API_KEY"] != "from_file" {
|
||||
t.Errorf("Expected API_KEY=from_file, got %s", merged["API_KEY"])
|
||||
}
|
||||
if merged["NEW_VAR"] != "from_config" {
|
||||
t.Errorf("Expected NEW_VAR=from_config, got %s", merged["NEW_VAR"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_EmptyWorkspaceWithRelativeEnvFile(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
mcpCfg := config.MCPConfig{
|
||||
Enabled: true,
|
||||
Servers: map[string]config.MCPServerConfig{
|
||||
"test-server": {
|
||||
Enabled: true,
|
||||
Command: "echo",
|
||||
Args: []string{"ok"},
|
||||
EnvFile: ".env",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), mcpCfg, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for relative env_file with empty workspace path, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "workspace path is empty") {
|
||||
t.Fatalf("expected workspace path validation error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_InitialState(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if mgr == nil {
|
||||
t.Fatal("expected manager instance, got nil")
|
||||
}
|
||||
if len(mgr.GetServers()) != 0 {
|
||||
t.Fatalf("expected no servers on new manager, got %d", len(mgr.GetServers()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
err := mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: false}, "/tmp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when MCP disabled, got: %v", err)
|
||||
}
|
||||
|
||||
err = mgr.LoadFromMCPConfig(context.Background(), config.MCPConfig{Enabled: true}, "/tmp")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error when no servers configured, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServers_ReturnsCopy(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.servers["s1"] = &ServerConnection{Name: "s1"}
|
||||
|
||||
servers := mgr.GetServers()
|
||||
delete(servers, "s1")
|
||||
|
||||
if _, ok := mgr.GetServer("s1"); !ok {
|
||||
t.Fatal("expected internal manager state to remain unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllTools_FiltersEmptyTools(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.servers["empty"] = &ServerConnection{Name: "empty", Tools: nil}
|
||||
mgr.servers["with-tools"] = &ServerConnection{Name: "with-tools", Tools: []*sdkmcp.Tool{{}}}
|
||||
|
||||
all := mgr.GetAllTools()
|
||||
if _, ok := all["empty"]; ok {
|
||||
t.Fatal("expected server without tools to be excluded")
|
||||
}
|
||||
if _, ok := all["with-tools"]; !ok {
|
||||
t.Fatal("expected server with tools to be included")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallTool_ErrorsForClosedOrMissingServer(t *testing.T) {
|
||||
t.Run("manager closed", func(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.closed.Store(true)
|
||||
|
||||
_, err := mgr.CallTool(context.Background(), "s1", "tool", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "manager is closed") {
|
||||
t.Fatalf("expected manager closed error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server missing", func(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
_, err := mgr.CallTool(context.Background(), "missing", "tool", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "not found") {
|
||||
t.Fatalf("expected server not found error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose_IdempotentOnEmptyManager(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Fatalf("first close should succeed, got: %v", err)
|
||||
}
|
||||
if err := mgr.Close(); err != nil {
|
||||
t.Fatalf("second close should be idempotent, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -102,6 +102,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "litellm":
|
||||
if cfg.Providers.LiteLLM.APIKey != "" || cfg.Providers.LiteLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.LiteLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.LiteLLM.APIBase
|
||||
sel.proxy = cfg.Providers.LiteLLM.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:4000/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
|
||||
@@ -53,7 +53,7 @@ func ExtractProtocol(model string) (protocol, modelID string) {
|
||||
|
||||
// CreateProviderFromConfig creates a provider based on the ModelConfig.
|
||||
// It uses the protocol prefix in the Model field to determine which provider to create.
|
||||
// Supported protocols: openai, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Supported protocols: openai, litellm, anthropic, antigravity, claude-cli, codex-cli, github-copilot
|
||||
// Returns the provider, the model ID (without protocol prefix), and any error.
|
||||
func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, error) {
|
||||
if cfg == nil {
|
||||
@@ -92,7 +92,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"volcengine", "vllm", "qwen", "mistral", "opencode":
|
||||
// All other OpenAI-compatible HTTP providers
|
||||
@@ -180,6 +180,8 @@ func getDefaultAPIBase(protocol string) string {
|
||||
return "https://api.openai.com/v1"
|
||||
case "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
case "litellm":
|
||||
return "http://localhost:4000/v1"
|
||||
case "groq":
|
||||
return "https://api.groq.com/openai/v1"
|
||||
case "zhipu":
|
||||
|
||||
@@ -136,6 +136,32 @@ func TestCreateProviderFromConfig_DefaultAPIBase(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultAPIBase_LiteLLM(t *testing.T) {
|
||||
if got := getDefaultAPIBase("litellm"); got != "http://localhost:4000/v1" {
|
||||
t.Fatalf("getDefaultAPIBase(%q) = %q, want %q", "litellm", got, "http://localhost:4000/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_LiteLLM(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-litellm",
|
||||
Model: "litellm/my-proxy-alias",
|
||||
APIKey: "test-key",
|
||||
APIBase: "http://localhost:4000/v1",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-proxy-alias" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-proxy-alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Anthropic(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-anthropic",
|
||||
|
||||
@@ -17,6 +17,27 @@ func TestResolveProviderSelection(t *testing.T) {
|
||||
wantProxy string
|
||||
wantErrSubstr string
|
||||
}{
|
||||
{
|
||||
name: "explicit litellm provider uses configured base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
cfg.Providers.LiteLLM.APIBase = "http://localhost:4000/v1"
|
||||
cfg.Providers.LiteLLM.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit litellm provider defaults base when only key is configured",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "litellm"
|
||||
cfg.Providers.LiteLLM.APIKey = "litellm-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:4000/v1",
|
||||
},
|
||||
{
|
||||
name: "explicit claude-cli provider routes to cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
|
||||
@@ -124,7 +124,7 @@ func (p *Provider) Chat(
|
||||
|
||||
requestBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": stripSystemParts(messages),
|
||||
"messages": serializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -310,19 +310,55 @@ type openaiMessage struct {
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// stripSystemParts converts []Message to []openaiMessage, dropping the
|
||||
// SystemParts field so it doesn't leak into the JSON payload sent to
|
||||
// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
|
||||
func stripSystemParts(messages []Message) []openaiMessage {
|
||||
out := make([]openaiMessage, len(messages))
|
||||
for i, m := range messages {
|
||||
out[i] = openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
// serializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func serializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -339,7 +375,7 @@ func normalizeModel(model, apiBase string) string {
|
||||
|
||||
prefix := strings.ToLower(before)
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
case "litellm", "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
|
||||
return after
|
||||
default:
|
||||
return model
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
@@ -258,6 +260,11 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
input string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "strips litellm prefix and preserves proxy model name",
|
||||
input: "litellm/my-proxy-alias",
|
||||
wantModel: "my-proxy-alias",
|
||||
},
|
||||
{
|
||||
name: "strips groq prefix and keeps nested model",
|
||||
input: "groq/openai/gpt-oss-120b",
|
||||
@@ -489,3 +496,97 @@ func TestProviderChat_KimiCodeUserAgent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Fatalf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Fatalf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
|
||||
textPart := content[0].(map[string]any)
|
||||
if textPart["type"] != "text" || textPart["text"] != "describe this" {
|
||||
t.Fatalf("text part mismatch: %v", textPart)
|
||||
}
|
||||
|
||||
imgPart := content[1].(map[string]any)
|
||||
if imgPart["type"] != "image_url" {
|
||||
t.Fatalf("expected image_url type, got %v", imgPart["type"])
|
||||
}
|
||||
imgURL := imgPart["image_url"].(map[string]any)
|
||||
if imgURL["url"] != "data:image/png;base64,abc123" {
|
||||
t.Fatalf("image url mismatch: %v", imgURL["url"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Fatalf("tool_call_id not preserved with media, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
// Content should be multipart array
|
||||
if _, ok := msgs[0]["content"].([]any); !ok {
|
||||
t.Fatalf("expected array content, got %T", msgs[0]["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
raw := string(data)
|
||||
if strings.Contains(raw, "system_parts") {
|
||||
t.Fatal("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ type ContentBlock struct {
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
@@ -64,6 +64,29 @@ type SkillsLoader struct {
|
||||
builtinSkills string // builtin skills
|
||||
}
|
||||
|
||||
// SkillRoots returns all unique skill root directories used by this loader.
|
||||
// The order follows resolution priority: workspace > global > builtin.
|
||||
func (sl *SkillsLoader) SkillRoots() []string {
|
||||
roots := []string{sl.workspaceSkills, sl.globalSkills, sl.builtinSkills}
|
||||
seen := make(map[string]struct{}, len(roots))
|
||||
out := make([]string, 0, len(roots))
|
||||
|
||||
for _, root := range roots {
|
||||
trimmed := strings.TrimSpace(root)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clean := filepath.Clean(trimmed)
|
||||
if _, ok := seen[clean]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clean] = struct{}{}
|
||||
out = append(out, clean)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
|
||||
return &SkillsLoader{
|
||||
workspace: workspace,
|
||||
|
||||
@@ -326,3 +326,19 @@ func TestStripFrontmatter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkillRootsTrimsWhitespaceAndDedups(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
workspace := filepath.Join(tmp, "workspace")
|
||||
global := filepath.Join(tmp, "global")
|
||||
builtin := filepath.Join(tmp, "builtin")
|
||||
|
||||
sl := NewSkillsLoader(workspace, " "+global+" ", "\t"+builtin+"\n")
|
||||
roots := sl.SkillRoots()
|
||||
|
||||
assert.Equal(t, []string{
|
||||
filepath.Join(workspace, "skills"),
|
||||
global,
|
||||
builtin,
|
||||
}, roots)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,246 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MCPManager defines the interface for MCP manager operations
|
||||
// This allows for easier testing with mock implementations
|
||||
type MCPManager interface {
|
||||
CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
// MCPTool wraps an MCP tool to implement the Tool interface
|
||||
type MCPTool struct {
|
||||
manager MCPManager
|
||||
serverName string
|
||||
tool *mcp.Tool
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(manager MCPManager, serverName string, tool *mcp.Tool) *MCPTool {
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
serverName: serverName,
|
||||
tool: tool,
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
// as part of a tool/function identifier for downstream providers.
|
||||
// It:
|
||||
// - lowercases the string
|
||||
// - replaces any character not in [a-z0-9_-] with '_'
|
||||
// - collapses multiple consecutive '_' into a single '_'
|
||||
// - trims leading/trailing '_'
|
||||
// - falls back to "unnamed" if the result is empty
|
||||
// - truncates overly long components to a reasonable length
|
||||
func sanitizeIdentifierComponent(s string) string {
|
||||
const maxLen = 64
|
||||
|
||||
s = strings.ToLower(s)
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
prevUnderscore := false
|
||||
for _, r := range s {
|
||||
isAllowed := (r >= 'a' && r <= 'z') ||
|
||||
(r >= '0' && r <= '9') ||
|
||||
r == '_' || r == '-'
|
||||
|
||||
if !isAllowed {
|
||||
// Normalize any disallowed character to '_'
|
||||
if !prevUnderscore {
|
||||
b.WriteRune('_')
|
||||
prevUnderscore = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if r == '_' {
|
||||
if prevUnderscore {
|
||||
continue
|
||||
}
|
||||
prevUnderscore = true
|
||||
} else {
|
||||
prevUnderscore = false
|
||||
}
|
||||
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
result := strings.Trim(b.String(), "_")
|
||||
if result == "" {
|
||||
result = "unnamed"
|
||||
}
|
||||
|
||||
if len(result) > maxLen {
|
||||
result = result[:maxLen]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Name returns the tool name, prefixed with the server name.
|
||||
// The total length is capped at 64 characters (OpenAI-compatible API limit).
|
||||
// A short hash of the original (unsanitized) server and tool names is appended
|
||||
// whenever sanitization is lossy or the name is truncated, ensuring that two
|
||||
// names which differ only in disallowed characters remain distinct after sanitization.
|
||||
func (t *MCPTool) Name() string {
|
||||
// Prefix with server name to avoid conflicts, and sanitize components
|
||||
sanitizedServer := sanitizeIdentifierComponent(t.serverName)
|
||||
sanitizedTool := sanitizeIdentifierComponent(t.tool.Name)
|
||||
full := fmt.Sprintf("mcp_%s_%s", sanitizedServer, sanitizedTool)
|
||||
|
||||
// Check if sanitization was lossless (only lowercasing, no char replacement/truncation)
|
||||
lossless := strings.ToLower(t.serverName) == sanitizedServer &&
|
||||
strings.ToLower(t.tool.Name) == sanitizedTool
|
||||
|
||||
const maxTotal = 64
|
||||
if lossless && len(full) <= maxTotal {
|
||||
return full
|
||||
}
|
||||
|
||||
// Sanitization was lossy or name too long: append hash of the ORIGINAL names
|
||||
// (not the sanitized names) so different originals always yield different hashes.
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(t.serverName + "\x00" + t.tool.Name))
|
||||
suffix := fmt.Sprintf("%08x", h.Sum32()) // 8 chars
|
||||
|
||||
base := full
|
||||
if len(base) > maxTotal-9 {
|
||||
base = strings.TrimRight(full[:maxTotal-9], "_")
|
||||
}
|
||||
return base + "_" + suffix
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *MCPTool) Description() string {
|
||||
desc := t.tool.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("MCP tool from %s server", t.serverName)
|
||||
}
|
||||
// Add server info to description
|
||||
return fmt.Sprintf("[MCP:%s] %s", t.serverName, desc)
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *MCPTool) Parameters() map[string]any {
|
||||
// The InputSchema is already a JSON Schema object
|
||||
schema := t.tool.InputSchema
|
||||
|
||||
// Handle nil schema
|
||||
if schema == nil {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Try direct conversion first (fast path)
|
||||
if schemaMap, ok := schema.(map[string]any); ok {
|
||||
return schemaMap
|
||||
}
|
||||
|
||||
// Handle json.RawMessage and []byte - unmarshal directly
|
||||
var jsonData []byte
|
||||
if rawMsg, ok := schema.(json.RawMessage); ok {
|
||||
jsonData = rawMsg
|
||||
} else if bytes, ok := schema.([]byte); ok {
|
||||
jsonData = bytes
|
||||
}
|
||||
|
||||
if jsonData != nil {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err == nil {
|
||||
return result
|
||||
}
|
||||
// Fallback on error
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// For other types (structs, etc.), convert via JSON marshal/unmarshal
|
||||
var err error
|
||||
jsonData, err = json.Marshal(schema)
|
||||
if err != nil {
|
||||
// Fallback to empty schema if marshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonData, &result); err != nil {
|
||||
// Fallback to empty schema if unmarshaling fails
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
nilErr := fmt.Errorf("MCP tool returned nil result without error")
|
||||
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
// Extract text content from result
|
||||
output := extractContentText(result.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
for _, c := range content {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
parts = append(parts, v.Text)
|
||||
case *mcp.ImageContent:
|
||||
// For images, just indicate that an image was returned
|
||||
parts = append(parts, fmt.Sprintf("[Image: %s]", v.MIMEType))
|
||||
default:
|
||||
// For other content types, use string representation
|
||||
parts = append(parts, fmt.Sprintf("[Content: %T]", v))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -0,0 +1,492 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MockMCPManager is a mock implementation of MCPManager interface for testing
|
||||
type MockMCPManager struct {
|
||||
callToolFunc func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error)
|
||||
}
|
||||
|
||||
func (m *MockMCPManager) CallTool(
|
||||
ctx context.Context,
|
||||
serverName, toolName string,
|
||||
arguments map[string]any,
|
||||
) (*mcp.CallToolResult, error) {
|
||||
if m.callToolFunc != nil {
|
||||
return m.callToolFunc(ctx, serverName, toolName, arguments)
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "mock result"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewMCPTool verifies MCP tool creation
|
||||
func TestNewMCPTool(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Test input",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
if mcpTool == nil {
|
||||
t.Fatal("NewMCPTool should not return nil")
|
||||
}
|
||||
// Verify tool properties we can access
|
||||
if mcpTool.Name() != "mcp_test_server_test_tool" {
|
||||
t.Errorf("Expected tool name with prefix, got '%s'", mcpTool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Name verifies tool name with server prefix
|
||||
func TestMCPTool_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple name",
|
||||
serverName: "github",
|
||||
toolName: "create_issue",
|
||||
expected: "mcp_github_create_issue",
|
||||
},
|
||||
{
|
||||
name: "filesystem server",
|
||||
serverName: "filesystem",
|
||||
toolName: "read_file",
|
||||
expected: "mcp_filesystem_read_file",
|
||||
},
|
||||
{
|
||||
name: "remote server",
|
||||
serverName: "remote-api",
|
||||
toolName: "fetch_data",
|
||||
expected: "mcp_remote-api_fetch_data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: tt.toolName}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Name()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected name '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Description verifies tool description generation
|
||||
func TestMCPTool_Description(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
toolDescription string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "with description",
|
||||
serverName: "github",
|
||||
toolDescription: "Create a GitHub issue",
|
||||
expectContains: []string{"[MCP:github]", "Create a GitHub issue"},
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
serverName: "filesystem",
|
||||
toolDescription: "",
|
||||
expectContains: []string{"[MCP:filesystem]", "MCP tool from filesystem server"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
Description: tt.toolDescription,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, tt.serverName, tool)
|
||||
|
||||
result := mcpTool.Description()
|
||||
|
||||
for _, expected := range tt.expectContains {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Description should contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters verifies parameter schema conversion
|
||||
func TestMCPTool_Parameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputSchema any
|
||||
expectType string
|
||||
checkProperty string
|
||||
expectProperty bool
|
||||
}{
|
||||
{
|
||||
name: "map schema",
|
||||
inputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
expectType: "object",
|
||||
checkProperty: "query",
|
||||
expectProperty: true,
|
||||
},
|
||||
{
|
||||
name: "nil schema",
|
||||
inputSchema: nil,
|
||||
expectType: "object",
|
||||
expectProperty: false,
|
||||
},
|
||||
{
|
||||
name: "json.RawMessage schema",
|
||||
inputSchema: []byte(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"stars": {
|
||||
"type": "integer",
|
||||
"description": "Minimum stars"
|
||||
}
|
||||
},
|
||||
"required": ["repo"]
|
||||
}`),
|
||||
expectType: "object",
|
||||
checkProperty: "repo",
|
||||
expectProperty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: tt.inputSchema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
if params == nil {
|
||||
t.Fatal("Parameters should not be nil")
|
||||
}
|
||||
|
||||
if params["type"] != tt.expectType {
|
||||
t.Errorf("Expected type '%s', got '%v'", tt.expectType, params["type"])
|
||||
}
|
||||
|
||||
// Check if property exists when expected
|
||||
if tt.checkProperty != "" {
|
||||
properties, ok := params["properties"].(map[string]any)
|
||||
if !ok && tt.expectProperty {
|
||||
t.Errorf("Expected properties to be a map")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
_, hasProperty := properties[tt.checkProperty]
|
||||
if hasProperty != tt.expectProperty {
|
||||
t.Errorf("Expected property '%s' existence: %v, got: %v",
|
||||
tt.checkProperty, tt.expectProperty, hasProperty)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_Success tests successful tool execution
|
||||
func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
// Verify correct parameters passed
|
||||
if serverName != "github" {
|
||||
t.Errorf("Expected serverName 'github', got '%s'", serverName)
|
||||
}
|
||||
if toolName != "search_repos" {
|
||||
t.Errorf("Expected toolName 'search_repos', got '%s'", toolName)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Found 3 repositories"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "search_repos",
|
||||
Description: "Search GitHub repositories",
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "github", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]any{
|
||||
"query": "golang mcp",
|
||||
}
|
||||
|
||||
result := mcpTool.Execute(ctx, args)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got error: %s", result.ForLLM)
|
||||
}
|
||||
if result.ForLLM != "Found 3 repositories" {
|
||||
t.Errorf("Expected 'Found 3 repositories', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return nil, fmt.Errorf("connection failed")
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool execution failed") {
|
||||
t.Errorf("Error message should mention execution failure, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "connection failed") {
|
||||
t.Errorf("Error message should include original error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ServerError tests execution when server returns error
|
||||
func TestMCPTool_Execute_ServerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Invalid API key"},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "test_tool"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "MCP tool returned error") {
|
||||
t.Errorf("Error message should mention server error, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Invalid API key") {
|
||||
t.Errorf("Error message should include server message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_MultipleContent tests execution with multiple content items
|
||||
func TestMCPTool_Execute_MultipleContent(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
callToolFunc: func(ctx context.Context, serverName, toolName string, arguments map[string]any) (*mcp.CallToolResult, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "First line"},
|
||||
&mcp.TextContent{Text: "Second line"},
|
||||
&mcp.TextContent{Text: "Third line"},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{Name: "multi_output"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
ctx := context.Background()
|
||||
result := mcpTool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected no error, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
expected := "First line\nSecond line\nThird line"
|
||||
if result.ForLLM != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_TextContent tests text content extraction
|
||||
func TestExtractContentText_TextContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello World"},
|
||||
&mcp.TextContent{Text: "Second message"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
expected := "Hello World\nSecond message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_ImageContent tests image content extraction
|
||||
func TestExtractContentText_ImageContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("base64data"),
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Expected image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "image/png") {
|
||||
t.Errorf("Expected MIME type in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_MixedContent tests mixed content types
|
||||
func TestExtractContentText_MixedContent(t *testing.T) {
|
||||
content := []mcp.Content{
|
||||
&mcp.TextContent{Text: "Description"},
|
||||
&mcp.ImageContent{
|
||||
Data: []byte("data"),
|
||||
MIMEType: "image/jpeg",
|
||||
},
|
||||
&mcp.TextContent{Text: "More text"},
|
||||
}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if !strings.Contains(result, "Description") {
|
||||
t.Errorf("Should contain text content, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "[Image:") {
|
||||
t.Errorf("Should contain image indicator, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "More text") {
|
||||
t.Errorf("Should contain second text, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractContentText_EmptyContent tests empty content array
|
||||
func TestExtractContentText_EmptyContent(t *testing.T) {
|
||||
content := []mcp.Content{}
|
||||
|
||||
result := extractContentText(content)
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for empty content, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_InterfaceCompliance verifies MCPTool implements Tool interface
|
||||
func TestMCPTool_InterfaceCompliance(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
tool := &mcp.Tool{Name: "test"}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
// Verify it implements Tool interface
|
||||
var _ Tool = mcpTool
|
||||
}
|
||||
|
||||
// TestMCPTool_Parameters_MapSchema tests schema that's already a map
|
||||
func TestMCPTool_Parameters_MapSchema(t *testing.T) {
|
||||
manager := &MockMCPManager{}
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The name parameter",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
}
|
||||
|
||||
tool := &mcp.Tool{
|
||||
Name: "test_tool",
|
||||
InputSchema: schema,
|
||||
}
|
||||
mcpTool := NewMCPTool(manager, "test_server", tool)
|
||||
|
||||
params := mcpTool.Parameters()
|
||||
|
||||
// Should return the schema as-is when it's already a map
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got '%v'", params["type"])
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Properties should be a map")
|
||||
}
|
||||
|
||||
nameParam, ok := props["name"].(map[string]any)
|
||||
if !ok {
|
||||
t.Error("Name parameter should exist")
|
||||
}
|
||||
|
||||
if nameParam["type"] != "string" {
|
||||
t.Errorf("Name type should be 'string', got '%v'", nameParam["type"])
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,12 @@ func NewToolRegistry() *ToolRegistry {
|
||||
func (r *ToolRegistry) Register(tool Tool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.tools[tool.Name()] = tool
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.WarnCF("tools", "Tool registration overwrites existing tool",
|
||||
map[string]any{"name": name})
|
||||
}
|
||||
r.tools[name] = tool
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
|
||||
@@ -109,6 +109,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("brave api error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
|
||||
Executable
+49
@@ -0,0 +1,49 @@
|
||||
#!/bin/sh
|
||||
# Test script for MCP tools in Docker (full-featured image)
|
||||
|
||||
set -e
|
||||
|
||||
COMPOSE_FILE="docker/docker-compose.full.yml"
|
||||
SERVICE="picoclaw-agent"
|
||||
|
||||
echo "🧪 Testing MCP tools in Docker container (full-featured image)..."
|
||||
echo ""
|
||||
|
||||
# Build the image
|
||||
echo "📦 Building Docker image..."
|
||||
docker compose -f "$COMPOSE_FILE" build "$SERVICE"
|
||||
|
||||
# Test npx
|
||||
echo "✅ Testing npx..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npx --version'
|
||||
|
||||
# Test npm
|
||||
echo "✅ Testing npm..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'npm --version'
|
||||
|
||||
# Test node
|
||||
echo "✅ Testing Node.js..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'node --version'
|
||||
|
||||
# Test git
|
||||
echo "✅ Testing git..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'git --version'
|
||||
|
||||
# Test python
|
||||
echo "✅ Testing Python..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'python3 --version'
|
||||
|
||||
# Test uv
|
||||
echo "✅ Testing uv..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c 'uv --version'
|
||||
|
||||
# Test MCP server installation (quick)
|
||||
echo "✅ Testing @modelcontextprotocol/server-filesystem MCP server install with npx..."
|
||||
docker compose -f "$COMPOSE_FILE" run --rm --entrypoint sh "$SERVICE" -c '</dev/null timeout 5 npx -y @modelcontextprotocol/server-filesystem /tmp || true'
|
||||
|
||||
echo ""
|
||||
echo "🎉 All MCP tools are working correctly!"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo " 1. Configure MCP servers in config/config.json"
|
||||
echo " 2. Run: docker compose -f $COMPOSE_FILE --profile gateway up"
|
||||
Reference in New Issue
Block a user