chore: merge main branch into mcp-tools-support

Resolved conflicts in:
- config/config.example.json: Added empty MCP config block
- pkg/config/config.go: Added MCP config structures to new ToolsConfig
- pkg/agent/loop.go: Integrated MCP tools with new AgentRegistry architecture

MCP tools now register to all agents in the registry during startup.
This commit is contained in:
yuchou87
2026-02-19 19:06:37 +08:00
80 changed files with 10275 additions and 1721 deletions
+28
View File
@@ -0,0 +1,28 @@
---
name: Bug report
about: Report a bug or unexpected behavior
title: "[BUG]"
labels: bug
assignees: ''
---
## Quick Summary
## Environment & Tools
- **PicoClaw Version:** (e.g., v0.1.2 or commit hash)
- **Go Version:** (e.g., go 1.22)
- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow)
- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux)
- **Channels:** (e.g., Discord, Telegram, Feishu, ...)
## 📸 Steps to Reproduce
1.
2.
3.
## ❌ Actual Behavior
## ✅ Expected Behavior
## 💬 Additional Context
+23
View File
@@ -0,0 +1,23 @@
---
name: Feature request
about: Suggest a new idea or improvement
title: "[Feature]"
labels: enhancement
assignees: ''
---
## 🎯 The Goal / Use Case
## 💡 Proposed Solution
## 🛠 Potential Implementation (Optional)
## 🚦 Impact & Roadmap Alignment
- [ ] This is a Core Feature
- [ ] This is a Nice-to-Have / Enhancement
- [ ] This aligns with the current Roadmap
## 🔄 Alternatives Considered
## 💬 Additional Context
@@ -0,0 +1,26 @@
---
name: General Task / Todo
about: A specific piece of work like doc, refactoring, or maintenance.
title: "[Task]"
labels: ''
assignees: ''
---
## 📝 Objective
## 📋 To-Do List
- [ ] Step 1
- [ ] Step 2
- [ ] Step 3
## 🎯 Definition of Done (Acceptance Criteria)
- [ ] Documentation is updated in the README/docs folder.
- [ ] Code follows project linting standards.
- [ ] (If applicable) Basic tests pass.
## 💡 Context / Motivation
## 🔗 Related Issues / PRs
- Fixes #
- Relates to #
+43
View File
@@ -0,0 +1,43 @@
## 📝 Description
<!-- Please briefly describe the changes and purpose of this PR -->
## 🗣️ Type of Change
- [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
- [ ] ✨ New feature (non-breaking change which adds functionality)
- [ ] 📖 Documentation update
- [ ] ⚡ Code refactoring (no functional changes, no api changes)
## 🤖 AI Code Generation
- [ ] 🤖 Fully AI-generated (100% AI, 0% Human)
- [ ] 🛠️ Mostly AI-generated (AI draft, Human verified/modified)
- [ ] 👨‍💻 Mostly Human-written (Human lead, AI assisted or none)
## 🔗 Related Issue
<!-- Please link the related issue(s) (e.g., Fixes #123, Closes #456) -->
## 📚 Technical Context (Skip for Docs)
- **Reference URL:**
- **Reasoning:**
## 🧪 Test Environment
- **Hardware:** <!-- e.g. Raspberry Pi 5, Orange Pi, PC-->
- **OS:** <!-- e.g. Debian 12, Ubuntu 22.04 -->
- **Model/Provider:** <!-- e.g. OpenAI GPT-4o, Kimi k2, DeepSeek-V3 -->
- **Channels:** <!-- e.g. Discord, Telegram, Feishu, ... -->
## 📸 Evidence (Optional)
<details>
<summary>Click to view Logs/Screenshots</summary>
<!-- Please paste relevant screenshots or logs here -->
</details>
## ☑️ Checklist
- [ ] My code/docs follow the style of this project.
- [ ] I have performed a self-review of my own changes.
- [ ] I have updated the documentation accordingly.
+2 -2
View File
@@ -9,10 +9,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
+1 -1
View File
@@ -25,7 +25,7 @@ jobs:
steps:
# ── Checkout ──────────────────────────────
- name: 📥 Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
ref: ${{ inputs.tag }}
+34 -10
View File
@@ -1,17 +1,39 @@
name: pr-check
name: PR
on:
pull_request:
pull_request: { }
jobs:
fmt-check:
lint:
name: Linter
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
- name: Run go generate
run: go generate ./...
- name: Golangci Lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.10.1
# TODO: Remove once linter is properly configured
fmt-check:
name: Formatting
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version-file: go.mod
@@ -20,15 +42,17 @@ jobs:
make fmt
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
# TODO: Remove once linter is properly configured
vet:
name: Vet
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
@@ -39,14 +63,15 @@ jobs:
run: go vet ./...
test:
name: Tests
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
@@ -55,4 +80,3 @@ jobs:
- name: Run go test
run: go test ./...
+9 -5
View File
@@ -26,17 +26,19 @@ jobs:
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Create and push tag
shell: bash
env:
RELEASE_TAG: ${{ inputs.tag }}
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
git push origin "${{ inputs.tag }}"
git tag -a "$RELEASE_TAG" -m "Release $RELEASE_TAG"
git push origin "$RELEASE_TAG"
release:
name: GoReleaser Release
@@ -47,13 +49,14 @@ jobs:
packages: write
steps:
- name: Checkout tag
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.tag }}
- name: Setup Go from go.mod
uses: actions/setup-go@v5
id: setup-go
uses: actions/setup-go@v6
with:
go-version-file: go.mod
@@ -87,6 +90,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
GOVERSION: ${{ steps.setup-go.outputs.go-version }}
- name: Apply release flags
shell: bash
+184
View File
@@ -0,0 +1,184 @@
version: "2"
linters:
default: all
disable:
# TODO: Tweak for current project needs
- containedctx
- cyclop
- depguard
- dupl
- dupword
- err113
- exhaustruct
- funcorder
- gochecknoglobals
- godot
- intrange
- ireturn
- nlreturn
- noctx
- noinlineerr
- nonamedreturns
- tagliatelle
- testpackage
- varnamelen
- wrapcheck
- wsl
- wsl_v5
# TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
- bodyclose
- contextcheck
- dogsled
- embeddedstructfieldcheck
- errcheck
- errchkjson
- errorlint
- exhaustive
- forbidigo
- forcetypeassert
- funlen
- gochecknoinits
- gocognit
- goconst
- gocritic
- gocyclo
- godox
- goprintffuncname
- gosec
- govet
- ineffassign
- lll
- maintidx
- misspell
- mnd
- modernize
- nakedret
- nestif
- nilnil
- paralleltest
- perfsprint
- prealloc
- predeclared
- revive
- staticcheck
- tagalign
- testifylint
- thelper
- unparam
- unused
- usestdlibvars
- usetesting
- wastedassign
- whitespace
settings:
errcheck:
check-type-assertions: true
check-blank: true
exhaustive:
default-signifies-exhaustive: true
funlen:
lines: 120
statements: 40
gocognit:
min-complexity: 25
gocyclo:
min-complexity: 20
govet:
enable-all: true
disable:
- fieldalignment
lll:
line-length: 120
tab-width: 4
misspell:
locale: US
mnd:
checks:
- argument
- assign
- case
- condition
- operation
- return
nakedret:
max-func-lines: 3
revive:
enable-all-rules: true
rules:
- name: add-constant
disabled: true
- name: argument-limit
arguments:
- 7
severity: warning
- name: banned-characters
disabled: true
- name: cognitive-complexity
disabled: true
- name: comment-spacings
arguments:
- nolint
severity: warning
- name: cyclomatic
disabled: true
- name: file-header
disabled: true
- name: function-result-limit
arguments:
- 3
severity: warning
- name: function-length
disabled: true
- name: line-length-limit
disabled: true
- name: max-public-structs
disabled: true
- name: modifies-value-receiver
disabled: true
- name: package-comments
disabled: true
- name: unused-receiver
disabled: true
exclusions:
generated: lax
rules:
- linters:
- lll
source: '^//go:generate '
- linters:
- funlen
- maintidx
- gocognit
- gocyclo
path: _test\.go$
issues:
max-issues-per-linter: 0
max-same-issues: 0
formatters:
enable:
- goimports
# TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
# - gci
# - gofmt
# - gofumpt
# - golines
settings:
gci:
sections:
- standard
- default
- localmodule
custom-order: true
gofmt:
simplify: true
rewrite-rules:
- pattern: "interface{}"
replacement: "any"
- pattern: "a[b:len(a)]"
replacement: "a[b:]"
golines:
max-len: 120
+8
View File
@@ -11,6 +11,14 @@ builds:
- id: picoclaw
env:
- CGO_ENABLED=0
tags:
- stdjson
ldflags:
- -s -w
- -X main.version={{ .Version }}
- -X main.gitCommit={{ .ShortCommit }}
- -X main.buildTime={{ .Date }}
- -X main.goVersion={{ .Env.GOVERSION }}
goos:
- linux
- windows
+8 -1
View File
@@ -29,7 +29,14 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
# Copy binary
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
# Create picoclaw home directory
# Create non-root user and group
RUN addgroup -g 1000 picoclaw && \
adduser -D -u 1000 -G picoclaw picoclaw
# Switch to non-root user
USER picoclaw
# Run onboard to create initial directories and config
RUN /usr/local/bin/picoclaw onboard
ENTRYPOINT ["picoclaw"]
+5 -2
View File
@@ -11,11 +11,11 @@ VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w"
# Go variables
GO?=go
GOFLAGS?=-v
GOFLAGS?=-v -tags stdjson
# Installation
INSTALL_PREFIX?=$(HOME)/.local
@@ -39,6 +39,8 @@ ifeq ($(UNAME_S),Linux)
ARCH=amd64
else ifeq ($(UNAME_M),aarch64)
ARCH=arm64
else ifeq ($(UNAME_M),loongarch64)
ARCH=loong64
else ifeq ($(UNAME_M),riscv64)
ARCH=riscv64
else
@@ -84,6 +86,7 @@ build-all: generate
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
+14 -8
View File
@@ -3,7 +3,7 @@
<h1>PicoClaw: Go で書かれた超効率 AI アシスタント</h1>
<h3>$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走</h3>
<h3>$10 ハードウェア · 10MB RAM · 1秒起動 · 行くぜ、シャコ</h3>
<h3></h3>
<p>
@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
</p>
**日本語** | [English](README.md)
[中文](README.zh.md) | **日本語** | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md)
</div>
@@ -39,7 +39,7 @@
</table>
## 📢 ニュース
2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走
2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 行くぜ、シャコ
## ✨ 特徴
@@ -195,6 +195,9 @@ picoclaw onboard
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
},
"cron": {
"exec_timeout_minutes": 5
}
},
"heartbeat": {
@@ -250,7 +253,7 @@ Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -290,7 +293,7 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -673,7 +676,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allowFrom": ["123456789"]
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
@@ -689,7 +692,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
"appSecret": "xxx",
"encryptKey": "",
"verificationToken": "",
"allowFrom": []
"allow_from": []
}
},
"tools": {
@@ -697,6 +700,9 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
"search": {
"apiKey": "BSA..."
}
},
"cron": {
"exec_timeout_minutes": 5
}
},
"heartbeat": {
@@ -729,7 +735,7 @@ Discord: https://discord.gg/V4sAZ9XWpN
## 🐛 トラブルシューティング
### Web 検索で「API 配置问题」と表示される
### Web 検索で「API 設定の問題」と表示される
検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。
+33 -3
View File
@@ -14,7 +14,7 @@
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
[中文](README.zh.md) | [日本語](README.ja.md) | **English**
[中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | **English**
</div>
---
@@ -45,8 +45,11 @@
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (1020MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
## 📢 News
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we cant wait to have you on board!
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
@@ -96,6 +99,20 @@
</tr>
</table>
### 📱 Run on old Android Phones
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start:
1. **Install Termux** (Available on F-Droid or Google Play).
2. **Execute cmds**
```bash
# Note: Replace v0.1.1 with the latest version from the Releases page
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
And then follow the instructions in the "Quick Start" section to complete the configuration!
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
@@ -266,7 +283,7 @@ Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -309,7 +326,7 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -662,6 +679,16 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
### Provider Architecture
PicoClaw routes providers by protocol family:
- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints.
- Anthropic protocol: Claude-native API behavior.
- Codex/OAuth path: OpenAI OAuth/token authentication route.
This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`).
<details>
<summary><b>Zhipu</b></summary>
@@ -757,6 +784,9 @@ picoclaw agent -m "Hello"
"enabled": true,
"max_results": 5
}
},
"cron": {
"exec_timeout_minutes": 5
}
},
"heartbeat": {
+882
View File
@@ -0,0 +1,882 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: Assistente de IA Ultra-Eficiente em Go</h1>
<h3>Hardware de $10 · 10MB de RAM · Boot em 1s · 皮皮虾,我们走!</h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
[中文](README.zh.md) | [日本語](README.ja.md) | [English](README.md) | **Português**
</div>
---
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini!
<table align="center">
<tr align="center">
<td align="center" valign="top">
<p align="center">
<img src="assets/picoclaw_mem.gif" width="360" height="240">
</p>
</td>
<td align="center" valign="top">
<p align="center">
<img src="assets/licheervnano.png" width="400" height="240">
</p>
</td>
</tr>
</table>
> [!CAUTION]
> **🚨 DECLARAÇÃO DE SEGURANÇA & CANAIS OFICIAIS**
>
> * **SEM CRIPTOMOEDAS:** O PicoClaw **NÃO** possui nenhum token/moeda oficial. Todas as alegações no `pump.fun` ou outras plataformas de negociação são **GOLPES**.
> * **DOMÍNIO OFICIAL:** O **ÚNICO** site oficial é o **[picoclaw.io](https://picoclaw.io)**, e o site da empresa é o **[sipeed.com](https://sipeed.com)**.
> * **Aviso:** Muitos domínios `.ai/.org/.com/.net/...` foram registrados por terceiros, não são nossos.
> * **Aviso:** O PicoClaw está em fase inicial de desenvolvimento e pode ter problemas de segurança de rede não resolvidos. Não implante em ambientes de produção antes da versão v1.0.
> * **Nota:** O PicoClaw recentemente fez merge de muitos PRs, o que pode resultar em maior consumo de memória (10-20MB) nas versões mais recentes. Planejamos priorizar a otimização de recursos assim que o conjunto de funcionalidades estiver estável.
## 📢 Novidades
2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/picoclaw_community_roadmap_260216.md) — estamos ansiosos para ter você a bordo!
2026-02-13 🎉 PicoClaw atingiu 5000 stars em 4 dias! Obrigado à comunidade! Estamos finalizando o **Roadmap do Projeto** e configurando o **Grupo de Desenvolvedores** para acelerar o desenvolvimento do PicoClaw.
🚀 **Chamada para Ação:** Envie suas solicitações de funcionalidades nas GitHub Discussions. Revisaremos e priorizaremos na próxima reunião semanal.
2026-02-09 🎉 PicoClaw lançado oficialmente! Construído em 1 dia para trazer Agentes de IA para hardware de $10 com <10MB de RAM. 🦐 PicoClaw, Partiu!
## ✨ Funcionalidades
🪶 **Ultra-Leve**: Consumo de memória <10MB — 99% menor que o Clawdbot para funcionalidades essenciais.
💰 **Custo Mínimo**: Eficiente o suficiente para rodar em hardware de $10 — 98% mais barato que um Mac mini.
⚡️ **Inicialização Relámpago**: Tempo de inicialização 400X mais rápido, boot em 1 segundo mesmo em CPU single-core de 0.6GHz.
🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM e x86. Um clique e já era!
🤖 **Auto-Construído por IA**: Implementação nativa em Go de forma autônoma — 95% do núcleo gerado pelo Agente com refinamento humano no loop.
| | OpenClaw | NanoBot | **PicoClaw** |
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
| **Linguagem** | TypeScript | Python | **Go** |
| **RAM** | >1GB | >100MB | **< 10MB** |
| **Inicialização**</br>(CPU 0.8GHz) | >500s | >30s | **<1s** |
| **Custo** | Mac Mini $599 | Maioria dos SBC Linux </br>~$50 | **Qualquer placa Linux**</br>**A partir de $10** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 Demonstração
### 🛠️ Fluxos de Trabalho Padrão do Assistente
<table align="center">
<tr align="center">
<th><p align="center">🧩 Engenharia Full-Stack</p></th>
<th><p align="center">🗂️ Gerenciamento de Logs & Planejamento</p></th>
<th><p align="center">🔎 Busca Web & Aprendizado</p></th>
</tr>
<tr>
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
</tr>
<tr>
<td align="center">Desenvolver • Implantar • Escalar</td>
<td align="center">Agendar • Automatizar • Memorizar</td>
<td align="center">Descobrir • Analisar • Tendências</td>
</tr>
</table>
### 📱 Rode em celulares Android antigos
Dê uma segunda vida ao seu celular de dez anos atrás! Transforme-o em um assistente de IA inteligente com o PicoClaw. Início rápido:
1. **Instale o Termux** (Disponível no F-Droid ou Google Play).
2. **Execute os comandos**
```bash
# Nota: Substitua v0.1.1 pela versao mais recente da pagina de Releases
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
Depois siga as instruções na seção "Início Rápido" para completar a configuração!
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
### 🐜 Implantação Inovadora com Baixo Consumo
O PicoClaw pode ser implantado em praticamente qualquer dispositivo Linux!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versão E (Ethernet) ou W (WiFi6), para Assistente Doméstico Minimalista
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) para Manutenção Automatizada de Servidores
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) para Monitoramento Inteligente
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
🌟 Mais cenários de implantação aguardam você!
## 📦 Instalação
### Instalar com binário pré-compilado
Baixe o binário para sua plataforma na página de [releases](https://github.com/sipeed/picoclaw/releases).
### Instalar a partir do código-fonte (funcionalidades mais recentes, recomendado para desenvolvimento)
```bash
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
make deps
# Build, sem necessidade de instalar
make build
# Build para multiplas plataformas
make build-all
# Build e Instalar
make install
```
## 🐳 Docker Compose
Você tambêm pode rodar o PicoClaw usando Docker Compose sem instalar nada localmente.
```bash
# 1. Clone este repositorio
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. Configure suas API keys
cp config/config.example.json config/config.json
vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc.
# 3. Build & Iniciar
docker compose --profile gateway up -d
# 4. Ver logs
docker compose logs -f picoclaw-gateway
# 5. Parar
docker compose --profile gateway down
```
### Modo Agente (Execução única)
```bash
# Fazer uma pergunta
docker compose run --rm picoclaw-agent -m "Quanto e 2+2?"
# Modo interativo
docker compose run --rm picoclaw-agent
```
### Rebuild
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 Início Rápido
> [!TIP]
> Configure sua API key em `~/.picoclaw/config.json`.
> Obtenha API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> Busca web e **opcional** — obtenha a [Brave Search API](https://brave.com/search/api) gratuita (2000 consultas grátis/mês) ou use o fallback automático integrado.
**1. Inicializar**
```bash
picoclaw onboard
```
**2. Configurar** (`~/.picoclaw/config.json`)
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
"web": {
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
}
```
**3. Obter API Keys**
* **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
* **Busca Web** (opcional): [Brave Search](https://brave.com/search/api) - Plano gratuito disponível (2000 consultas/mês)
> **Nota**: Veja `config.example.json` para um modelo de configuração completo.
**4. Conversar**
```bash
picoclaw agent -m "Quanto e 2+2?"
```
Pronto! Você tem um assistente de IA funcionando em 2 minutos.
---
## 💬 Integração com Apps de Chat
Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE.
| Canal | Nível de Configuração |
| --- | --- |
| **Telegram** | Fácil (apenas um token) |
| **Discord** | Fácil (bot token + intents) |
| **QQ** | Fácil (AppID + AppSecret) |
| **DingTalk** | Médio (credenciais do app) |
| **LINE** | Médio (credenciais + webhook URL) |
<details>
<summary><b>Telegram</b> (Recomendado)</summary>
**1. Criar o bot**
* Abra o Telegram, busque `@BotFather`
* Envie `/newbot`, siga as instruções
* Copie o token
**2. Configurar**
```json
{
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
> Obtenha seu User ID pelo `@userinfobot` no Telegram.
**3. Executar**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>Discord</b></summary>
**1. Criar o bot**
* Acesse <https://discord.com/developers/applications>
* Crie um aplicativo → Bot → Add Bot
* Copie o token do bot
**2. Habilitar Intents**
* Nas configurações do Bot, habilite **MESSAGE CONTENT INTENT**
* (Opcional) Habilite **SERVER MEMBERS INTENT** se quiser usar lista de permissões baseada em dados dos membros
**3. Obter seu User ID**
* Configurações do Discord → Avançado → habilite **Modo Desenvolvedor**
* Clique com botão direito no seu avatar → **Copiar ID do Usuário**
**4. Configurar**
```json
{
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
**5. Convidar o bot**
* OAuth2 → URL Generator
* Scopes: `bot`
* Bot Permissions: `Send Messages`, `Read Message History`
* Abra a URL de convite gerada e adicione o bot ao seu servidor
**6. Executar**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>QQ</b></summary>
**1. Criar o bot**
- Acesse a [QQ Open Platform](https://q.qq.com/#)
- Crie um aplicativo → Obtenha **AppID** e **AppSecret**
**2. Configurar**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique números QQ para restringir o acesso.
**3. Executar**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>DingTalk</b></summary>
**1. Criar o bot**
* Acesse a [Open Platform](https://open.dingtalk.com/)
* Crie um app interno
* Copie o Client ID e Client Secret
**2. Configurar**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique IDs para restringir o acesso.
**3. Executar**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>LINE</b></summary>
**1. Criar uma Conta Oficial LINE**
- Acesse o [LINE Developers Console](https://developers.line.biz/)
- Crie um provider → Crie um canal Messaging API
- Copie o **Channel Secret** e o **Channel Access Token**
**2. Configurar**
```json
{
"channels": {
"line": {
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
"webhook_host": "0.0.0.0",
"webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
}
}
```
**3. Configurar URL do Webhook**
O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel:
```bash
# Exemplo com ngrok
ngrok http 18791
```
Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**.
**4. Executar**
```bash
picoclaw gateway
```
> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original.
> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook.
</details>
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Junte-se a Rede Social de Agentes
Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única mensagem via CLI ou qualquer App de Chat integrado.
**Leia `https://clawdchat.ai/skill.md` e siga as instruções para entrar no [ClawdChat.ai](https://clawdchat.ai)**
## ⚙️ Configuração Detalhada
Arquivo de configuração: `~/.picoclaw/config.json`
### Estrutura do Workspace
O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`):
```
~/.picoclaw/workspace/
├── sessions/ # Sessoes de conversa e historico
├── memory/ # Memoria de longo prazo (MEMORY.md)
├── state/ # Estado persistente (ultimo canal, etc.)
├── cron/ # Banco de dados de tarefas agendadas
├── skills/ # Skills personalizadas
├── AGENTS.md # Guia de comportamento do Agente
├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min)
├── IDENTITY.md # Identidade do Agente
├── SOUL.md # Alma do Agente
├── TOOLS.md # Descrição das ferramentas
└── USER.md # Preferencias do usuario
```
### 🔒 Sandbox de Segurança
O PicoClaw roda em um ambiente sandbox por padrão. O agente so pode acessar arquivos e executar comandos dentro do workspace configurado.
#### Configuração Padrão
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| Opção | Padrão | Descrição |
|-------|--------|-----------|
| `workspace` | `~/.picoclaw/workspace` | Diretório de trabalho do agente |
| `restrict_to_workspace` | `true` | Restringir acesso de arquivos/comandos ao workspace |
#### Ferramentas Protegidas
Quando `restrict_to_workspace: true`, as seguintes ferramentas são restritas ao sandbox:
| Ferramenta | Função | Restrição |
|------------|--------|-----------|
| `read_file` | Ler arquivos | Apenas arquivos dentro do workspace |
| `write_file` | Escrever arquivos | Apenas arquivos dentro do workspace |
| `list_dir` | Listar diretorios | Apenas diretorios dentro do workspace |
| `edit_file` | Editar arquivos | Apenas arquivos dentro do workspace |
| `append_file` | Adicionar a arquivos | Apenas arquivos dentro do workspace |
| `exec` | Executar comandos | Caminhos dos comandos devem estar dentro do workspace |
#### Proteção Adicional do Exec
Mesmo com `restrict_to_workspace: false`, a ferramenta `exec` bloqueia estes comandos perigosos:
* `rm -rf`, `del /f`, `rmdir /s` — Exclusão em massa
* `format`, `mkfs`, `diskpart` — Formatação de disco
* `dd if=` — Criação de imagem de disco
* Escrita em `/dev/sd[a-z]` — Escrita direta no disco
* `shutdown`, `reboot`, `poweroff` — Desligamento do sistema
* Fork bomb `:(){ :|:& };:`
#### Exemplos de Erro
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### Desabilitar Restrições (Risco de Segurança)
Se você precisa que o agente acesse caminhos fora do workspace:
**Método 1: Arquivo de configuração**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**Método 2: Variável de ambiente**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **Aviso**: Desabilitar esta restrição permite que o agente acesse qualquer caminho no seu sistema. Use com cuidado apenas em ambientes controlados.
#### Consistência do Limite de Segurança
A configuração `restrict_to_workspace` se aplica consistentemente em todos os caminhos de execução:
| Caminho de Execução | Limite de Segurança |
|----------------------|---------------------|
| Agente Principal | `restrict_to_workspace` ✅ |
| Subagente / Spawn | Herda a mesma restrição ✅ |
| Tarefas Heartbeat | Herda a mesma restrição ✅ |
Todos os caminhos compartilham a mesma restrição de workspace — nao há como contornar o limite de segurança por meio de subagentes ou tarefas agendadas.
### Heartbeat (Tarefas Periódicas)
O PicoClaw pode executar tarefas periódicas automaticamente. Crie um arquivo `HEARTBEAT.md` no seu workspace:
```markdown
# Tarefas Periodicas
- Verificar meu email para mensagens importantes
- Revisar minha agenda para proximos eventos
- Verificar a previsao do tempo
```
O agente lerá este arquivo a cada 30 minutos (configurável) e executará as tarefas usando as ferramentas disponíveis.
#### Tarefas Assincronas com Spawn
Para tarefas de longa duração (busca web, chamadas de API), use a ferramenta `spawn` para criar um **subagente**:
```markdown
# Tarefas Periódicas
## Tarefas Rápidas (resposta direta)
- Informar hora atual
## Tarefas Longas (usar spawn para async)
- Buscar notícias de IA na web e resumir
- Verificar email e reportar mensagens importantes
```
**Comportamentos principais:**
| Funcionalidade | Descrição |
|----------------|-----------|
| **spawn** | Cria subagente assíncrono, não bloqueia o heartbeat |
| **Contexto independente** | Subagente tem seu próprio contexto, sem histórico de sessão |
| **Ferramenta message** | Subagente se comunica diretamente com o usuário via ferramenta message |
| **Não-bloqueante** | Após o spawn, o heartbeat continua para a próxima tarefa |
#### Como Funciona a Comunicação do Subagente
```
Heartbeat dispara
Agente lê HEARTBEAT.md
Para tarefa longa: spawn subagente
↓ ↓
Continua próxima tarefa Subagente trabalha independentemente
↓ ↓
Todas tarefas concluídas Subagente usa ferramenta "message"
↓ ↓
Responde HEARTBEAT_OK Usuário recebe resultado diretamente
```
O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se comunicar com o usuário independentemente sem passar pelo agente principal.
**Configuração:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Opção | Padrão | Descrição |
|-------|--------|-----------|
| `enabled` | `true` | Habilitar/desabilitar heartbeat |
| `interval` | `30` | Intervalo de verificação em minutos (min: 5) |
**Variáveis de ambiente:**
* `PICOCLAW_HEARTBEAT_ENABLED=false` para desabilitar
* `PICOCLAW_HEARTBEAT_INTERVAL=60` para alterar o intervalo
### Provedores
> [!NOTE]
> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas.
| Provedor | Finalidade | Obter API Key |
| --- | --- | --- |
| `gemini` | LLM (Gemini direto) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (Zhipu direto) | [bigmodel.cn](bigmodel.cn) |
| `openrouter` (Em teste) | LLM (recomendado, acesso a todos os modelos) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` (Em teste) | LLM (Claude direto) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` (Em teste) | LLM (GPT direto) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` (Em teste) | LLM (DeepSeek direto) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Transcrição de voz** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>Configuração Zhipu</b></summary>
**1. Obter API key**
* Obtenha a [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. Configurar**
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"zhipu": {
"api_key": "Sua API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
}
}
}
```
**3. Executar**
```bash
picoclaw agent -m "Ola, como vai?"
```
</details>
<details>
<summary><b>Exemplo de configuraçao completa</b></summary>
```json
{
"agents": {
"defaults": {
"model": "anthropic/claude-opus-4-5"
}
},
"providers": {
"openrouter": {
"api_key": "sk-or-v1-xxx"
},
"groq": {
"api_key": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
"token": "",
"allow_from": [""]
},
"whatsapp": {
"enabled": false
},
"feishu": {
"enabled": false,
"app_id": "cli_xxx",
"app_secret": "xxx",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
},
"qq": {
"enabled": false,
"app_id": "",
"app_secret": "",
"allow_from": []
}
},
"tools": {
"web": {
"brave": {
"enabled": false,
"api_key": "BSA...",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
},
"cron": {
"exec_timeout_minutes": 5
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
</details>
## Referência CLI
| Comando | Descrição |
| --- | --- |
| `picoclaw onboard` | Inicializar configuração & workspace |
| `picoclaw agent -m "..."` | Conversar com o agente |
| `picoclaw agent` | Modo de chat interativo |
| `picoclaw gateway` | Iniciar o gateway (para bots de chat) |
| `picoclaw status` | Mostrar status |
| `picoclaw cron list` | Listar todas as tarefas agendadas |
| `picoclaw cron add ...` | Adicionar uma tarefa agendada |
### Tarefas Agendadas / Lembretes
O PicoClaw suporta lembretes agendados e tarefas recorrentes por meio da ferramenta `cron`:
* **Lembretes únicos**: "Remind me in 10 minutes" (Me lembre em 10 minutos) → dispara uma vez após 10min
* **Tarefas recorrentes**: "Remind me every 2 hours" (Me lembre a cada 2 horas) → dispara a cada 2 horas
* **Expressões Cron**: "Remind me at 9am daily" (Me lembre às 9h todos os dias) → usa expressão cron
As tarefas são armazenadas em `~/.picoclaw/workspace/cron/` e processadas automaticamente.
## 🤝 Contribuir & Roadmap
PRs são bem-vindos! O código-fonte é intencionalmente pequeno e legível. 🤗
Roadmap em breve...
Grupo de desenvolvedores em formação. Requisito de entrada: Pelo menos 1 PR com merge.
Grupos de usuários:
Discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 Solução de Problemas
### Busca web mostra "API 配置问题"
Isso é normal se você ainda não configurou uma API key de busca. O PicoClaw fornecerá links úteis para busca manual.
Para habilitar a busca web:
1. **Opção 1 (Recomendado)**: Obtenha uma API key gratuita em [https://brave.com/search/api](https://brave.com/search/api) (2000 consultas grátis/mês) para os melhores resultados.
2. **Opção 2 (Sem Cartão de Crédito)**: Se você não tem uma key, o sistema automaticamente usa o **DuckDuckGo** como fallback (sem necessidade de key).
Adicione a key em `~/.picoclaw/config.json` se usar o Brave:
```json
{
"tools": {
"web": {
"brave": {
"enabled": true,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
}
```
### Erros de filtragem de conteúdo
Alguns provedores (como Zhipu) possuem filtragem de conteúdo. Tente reformular sua pergunta ou use um modelo diferente.
### Bot do Telegram diz "Conflict: terminated by other getUpdates"
Isso acontece quando outra instância do bot está em execução. Certifique-se de que apenas um `picoclaw gateway` esteja rodando por vez.
---
## 📝 Comparação de API Keys
| Serviço | Plano Gratuito | Caso de Uso |
| --- | --- | --- |
| **OpenRouter** | 200K tokens/mês | Múltiplos modelos (Claude, GPT-4, etc.) |
| **Zhipu** | 200K tokens/mês | Melhor para usuários chineses |
| **Brave Search** | 2000 consultas/mês | Funcionalidade de busca web |
| **Groq** | Plano gratuito disponível | Inferência ultra-rápida (Llama, Mixtral) |
+859
View File
@@ -0,0 +1,859 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go</h1>
<h3>Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!</h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
**Tiếng Việt** | [中文](README.zh.md) | [日本語](README.ja.md) | [English](README.md)
</div>
---
🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini!
<table align="center">
<tr align="center">
<td align="center" valign="top">
<p align="center">
<img src="assets/picoclaw_mem.gif" width="360" height="240">
</p>
</td>
<td align="center" valign="top">
<p align="center">
<img src="assets/licheervnano.png" width="400" height="240">
</p>
</td>
</tr>
</table>
> [!CAUTION]
> **🚨 TUYÊN BỐ BẢO MẬT & KÊNH CHÍNH THỨC**
>
> * **KHÔNG CÓ CRYPTO:** PicoClaw **KHÔNG** có bất kỳ token/coin chính thức nào. Mọi thông tin trên `pump.fun` hoặc các sàn giao dịch khác đều là **LỪA ĐẢO**.
> * **DOMAIN CHÍNH THỨC:** Website chính thức **DUY NHẤT** là **[picoclaw.io](https://picoclaw.io)**, website công ty là **[sipeed.com](https://sipeed.com)**.
> * **Cảnh báo:** Nhiều tên miền `.ai/.org/.com/.net/...` đã bị bên thứ ba đăng ký, không phải của chúng tôi.
> * **Cảnh báo:** PicoClaw đang trong giai đoạn phát triển sớm và có thể còn các vấn đề bảo mật mạng chưa được giải quyết. Không nên triển khai lên môi trường production trước phiên bản v1.0.
> * **Lưu ý:** PicoClaw gần đây đã merge nhiều PR, dẫn đến bộ nhớ sử dụng có thể lớn hơn (1020MB) ở các phiên bản mới nhất. Chúng tôi sẽ ưu tiên tối ưu tài nguyên khi bộ tính năng đã ổn định.
## 📢 Tin tức
2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/picoclaw_community_roadmap_260216.md) — rất mong đón nhận sự tham gia của bạn!
2026-02-13 🎉 PicoClaw đạt 5000 stars trong 4 ngày! Cảm ơn cộng đồng! Chúng tôi đang hoàn thiện **Lộ trình dự án (Roadmap)** và thiết lập **Nhóm phát triển** để đẩy nhanh tốc độ phát triển PicoClaw.
🚀 **Kêu gọi hành động:** Vui lòng gửi yêu cầu tính năng tại GitHub Discussions. Chúng tôi sẽ xem xét và ưu tiên trong cuộc họp hàng tuần.
2026-02-09 🎉 PicoClaw chính thức ra mắt! Được xây dựng trong 1 ngày để mang AI Agent đến phần cứng $10 với RAM <10MB. 🦐 PicoClaw, Lên Đường!
## ✨ Tính năng nổi bật
🪶 **Siêu nhẹ**: Bộ nhớ sử dụng <10MB — nhỏ hơn 99% so với Clawdbot (chức năng cốt lõi).
💰 **Chi phí tối thiểu**: Đủ hiệu quả để chạy trên phần cứng $10 — rẻ hơn 98% so với Mac mini.
⚡️ **Khởi động siêu nhanh**: Nhanh gấp 400 lần, khởi động trong 1 giây ngay cả trên CPU đơn nhân 0.6GHz.
🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM và x86. Một click là chạy!
🤖 **AI tự xây dựng**: Triển khai Go-native tự động — 95% mã nguồn cốt lõi được Agent tạo ra, với sự tinh chỉnh của con người.
| | OpenClaw | NanoBot | **PicoClaw** |
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
| **Ngôn ngữ** | TypeScript | Python | **Go** |
| **RAM** | >1GB | >100MB | **< 10MB** |
| **Thời gian khởi động**</br>(CPU 0.8GHz) | >500s | >30s | **<1s** |
| **Chi phí** | Mac Mini $599 | Hầu hết SBC Linux ~$50 | **Mọi bo mạch Linux**</br>**Chỉ từ $10** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 Demo
### 🛠️ Quy trình trợ lý tiêu chuẩn
<table align="center">
<tr align="center">
<th><p align="center">🧩 Lập trình Full-Stack</p></th>
<th><p align="center">🗂️ Quản lý Nhật ký & Kế hoạch</p></th>
<th><p align="center">🔎 Tìm kiếm Web & Học hỏi</p></th>
</tr>
<tr>
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
</tr>
<tr>
<td align="center">Phát triển • Triển khai • Mở rộng</td>
<td align="center">Lên lịch • Tự động hóa • Ghi nhớ</td>
<td align="center">Khám phá • Phân tích • Xu hướng</td>
</tr>
</table>
### 🐜 Triển khai sáng tạo trên phần cứng tối thiểu
PicoClaw có thể triển khai trên hầu hết mọi thiết bị Linux!
* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) phiên bản E (Ethernet) hoặc W (WiFi6), dùng làm Trợ lý Gia đình tối giản.
* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), hoặc $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html), dùng cho quản trị Server tự động.
* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) hoặc $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera), dùng cho Giám sát thông minh.
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
🌟 Nhiều hình thức triển khai hơn đang chờ bạn khám phá!
## 📦 Cài đặt
### Cài đặt bằng binary biên dịch sẵn
Tải file binary cho nền tảng của bạn từ [trang Release](https://github.com/sipeed/picoclaw/releases).
### Cài đặt từ mã nguồn (có tính năng mới nhất, khuyên dùng cho phát triển)
```bash
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
make deps
# Build (không cần cài đặt)
make build
# Build cho nhiều nền tảng
make build-all
# Build và cài đặt
make install
```
## 🐳 Docker Compose
Bạn cũng có thể chạy PicoClaw bằng Docker Compose mà không cần cài đặt gì trên máy.
```bash
# 1. Clone repo
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. Thiết lập API Key
cp config/config.example.json config/config.json
vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v.
# 3. Build & Khởi động
docker compose --profile gateway up -d
# 4. Xem logs
docker compose logs -f picoclaw-gateway
# 5. Dừng
docker compose --profile gateway down
```
### Chế độ Agent (chạy một lần)
```bash
# Đặt câu hỏi
docker compose run --rm picoclaw-agent -m "2+2 bằng mấy?"
# Chế độ tương tác
docker compose run --rm picoclaw-agent
```
### Build lại
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 Bắt đầu nhanh
> [!TIP]
> Thiết lập API key trong `~/.picoclaw/config.json`.
> Lấy API key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> Tìm kiếm web là **tùy chọn** — lấy [Brave Search API](https://brave.com/search/api) miễn phí (2000 truy vấn/tháng) hoặc dùng tính năng auto fallback tích hợp sẵn.
**1. Khởi tạo**
```bash
picoclaw onboard
```
**2. Cấu hình** (`~/.picoclaw/config.json`)
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
"web": {
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
}
```
**3. Lấy API Key**
* **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
* **Tìm kiếm Web** (tùy chọn): [Brave Search](https://brave.com/search/api) — Có gói miễn phí (2000 truy vấn/tháng)
> **Lưu ý**: Xem `config.example.json` để có mẫu cấu hình đầy đủ.
**4. Trò chuyện**
```bash
picoclaw agent -m "Xin chào, bạn là ai?"
```
Vậy là xong! Bạn đã có một trợ lý AI hoạt động chỉ trong 2 phút.
---
## 💬 Tích hợp ứng dụng Chat
Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk hoặc LINE.
| Kênh | Mức độ thiết lập |
| --- | --- |
| **Telegram** | Dễ (chỉ cần token) |
| **Discord** | Dễ (bot token + intents) |
| **QQ** | Dễ (AppID + AppSecret) |
| **DingTalk** | Trung bình (app credentials) |
| **LINE** | Trung bình (credentials + webhook URL) |
<details>
<summary><b>Telegram</b> (Khuyên dùng)</summary>
**1. Tạo bot**
* Mở Telegram, tìm `@BotFather`
* Gửi `/newbot`, làm theo hướng dẫn
* Sao chép token
**2. Cấu hình**
```json
{
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allow_from": ["YOUR_USER_ID"]
}
}
}
```
> Lấy User ID từ `@userinfobot` trên Telegram.
**3. Chạy**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>Discord</b></summary>
**1. Tạo bot**
* Truy cập <https://discord.com/developers/applications>
* Create an application → Bot → Add Bot
* Sao chép bot token
**2. Bật Intents**
* Trong phần Bot settings, bật **MESSAGE CONTENT INTENT**
* (Tùy chọn) Bật **SERVER MEMBERS INTENT** nếu muốn dùng danh sách cho phép theo thông tin thành viên
**3. Lấy User ID**
* Discord Settings → Advanced → bật **Developer Mode**
* Click chuột phải vào avatar → **Copy User ID**
**4. Cấu hình**
```json
{
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allow_from": ["YOUR_USER_ID"]
}
}
}
```
**5. Mời bot vào server**
* OAuth2 → URL Generator
* Scopes: `bot`
* Bot Permissions: `Send Messages`, `Read Message History`
* Mở URL mời được tạo và thêm bot vào server của bạn
**6. Chạy**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>QQ</b></summary>
**1. Tạo bot**
* Truy cập [QQ Open Platform](https://q.qq.com/#)
* Tạo ứng dụng → Lấy **AppID****AppSecret**
**2. Cấu hình**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định số QQ để giới hạn quyền truy cập.
**3. Chạy**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>DingTalk</b></summary>
**1. Tạo bot**
* Truy cập [Open Platform](https://open.dingtalk.com/)
* Tạo ứng dụng nội bộ
* Sao chép Client ID và Client Secret
**2. Cấu hình**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định ID để giới hạn quyền truy cập.
**3. Chạy**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>LINE</b></summary>
**1. Tạo tài khoản LINE Official**
- Truy cập [LINE Developers Console](https://developers.line.biz/)
- Tạo provider → Tạo Messaging API channel
- Sao chép **Channel Secret****Channel Access Token**
**2. Cấu hình**
```json
{
"channels": {
"line": {
"enabled": true,
"channel_secret": "YOUR_CHANNEL_SECRET",
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
"webhook_host": "0.0.0.0",
"webhook_port": 18791,
"webhook_path": "/webhook/line",
"allow_from": []
}
}
}
```
**3. Thiết lập Webhook URL**
LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel:
```bash
# Ví dụ với ngrok
ngrok http 18791
```
Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**.
**4. Chạy**
```bash
picoclaw gateway
```
> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc.
> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook.
</details>
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Tham gia Mạng xã hội Agent
Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một tin nhắn qua CLI hoặc bất kỳ ứng dụng Chat nào đã tích hợp.
**Đọc `https://clawdchat.ai/skill.md` và làm theo hướng dẫn để tham gia [ClawdChat.ai](https://clawdchat.ai)**
## ⚙️ Cấu hình chi tiết
File cấu hình: `~/.picoclaw/config.json`
### Cấu trúc Workspace
PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`):
```
~/.picoclaw/workspace/
├── sessions/ # Phiên hội thoại và lịch sử
├── memory/ # Bộ nhớ dài hạn (MEMORY.md)
├── state/ # Trạng thái lưu trữ (kênh cuối cùng, v.v.)
├── cron/ # Cơ sở dữ liệu tác vụ định kỳ
├── skills/ # Kỹ năng tùy chỉnh
├── AGENTS.md # Hướng dẫn hành vi Agent
├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút)
├── IDENTITY.md # Danh tính Agent
├── SOUL.md # Tâm hồn/Tính cách Agent
├── TOOLS.md # Mô tả công cụ
└── USER.md # Tùy chọn người dùng
```
### 🔒 Hộp cát bảo mật (Security Sandbox)
PicoClaw chạy trong môi trường sandbox theo mặc định. Agent chỉ có thể truy cập file và thực thi lệnh trong phạm vi workspace.
#### Cấu hình mặc định
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| Tùy chọn | Mặc định | Mô tả |
|----------|---------|-------|
| `workspace` | `~/.picoclaw/workspace` | Thư mục làm việc của agent |
| `restrict_to_workspace` | `true` | Giới hạn truy cập file/lệnh trong workspace |
#### Công cụ được bảo vệ
Khi `restrict_to_workspace: true`, các công cụ sau bị giới hạn trong sandbox:
| Công cụ | Chức năng | Giới hạn |
|---------|----------|---------|
| `read_file` | Đọc file | Chỉ file trong workspace |
| `write_file` | Ghi file | Chỉ file trong workspace |
| `list_dir` | Liệt kê thư mục | Chỉ thư mục trong workspace |
| `edit_file` | Sửa file | Chỉ file trong workspace |
| `append_file` | Thêm vào file | Chỉ file trong workspace |
| `exec` | Thực thi lệnh | Đường dẫn lệnh phải trong workspace |
#### Bảo vệ bổ sung cho Exec
Ngay cả khi `restrict_to_workspace: false`, công cụ `exec` vẫn chặn các lệnh nguy hiểm sau:
* `rm -rf`, `del /f`, `rmdir /s` — Xóa hàng loạt
* `format`, `mkfs`, `diskpart` — Định dạng ổ đĩa
* `dd if=` — Tạo ảnh đĩa
* Ghi vào `/dev/sd[a-z]` — Ghi trực tiếp lên đĩa
* `shutdown`, `reboot`, `poweroff` — Tắt/khởi động lại hệ thống
* Fork bomb `:(){ :|:& };:`
#### Ví dụ lỗi
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### Tắt giới hạn (Rủi ro bảo mật)
Nếu bạn cần agent truy cập đường dẫn ngoài workspace:
**Cách 1: File cấu hình**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**Cách 2: Biến môi trường**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **Cảnh báo**: Tắt giới hạn này cho phép agent truy cập mọi đường dẫn trên hệ thống. Chỉ sử dụng cẩn thận trong môi trường được kiểm soát.
#### Tính nhất quán của ranh giới bảo mật
Cài đặt `restrict_to_workspace` áp dụng nhất quán trên mọi đường thực thi:
| Đường thực thi | Ranh giới bảo mật |
|----------------|-------------------|
| Agent chính | `restrict_to_workspace` ✅ |
| Subagent / Spawn | Kế thừa cùng giới hạn ✅ |
| Tác vụ Heartbeat | Kế thừa cùng giới hạn ✅ |
Tất cả đường thực thi chia sẻ cùng giới hạn workspace — không có cách nào vượt qua ranh giới bảo mật thông qua subagent hoặc tác vụ định kỳ.
### Heartbeat (Tác vụ định kỳ)
PicoClaw có thể tự động thực hiện các tác vụ định kỳ. Tạo file `HEARTBEAT.md` trong workspace:
```markdown
# Tác vụ định kỳ
- Kiểm tra email xem có tin nhắn quan trọng không
- Xem lại lịch cho các sự kiện sắp tới
- Kiểm tra dự báo thời tiết
```
Agent sẽ đọc file này mỗi 30 phút (có thể cấu hình) và thực hiện các tác vụ bằng công cụ có sẵn.
#### Tác vụ bất đồng bộ với Spawn
Đối với các tác vụ chạy lâu (tìm kiếm web, gọi API), sử dụng công cụ `spawn` để tạo **subagent**:
```markdown
# Tác vụ định kỳ
## Tác vụ nhanh (trả lời trực tiếp)
- Báo cáo thời gian hiện tại
## Tác vụ lâu (dùng spawn cho async)
- Tìm kiếm tin tức AI trên web và tóm tắt
- Kiểm tra email và báo cáo tin nhắn quan trọng
```
**Hành vi chính:**
| Tính năng | Mô tả |
|-----------|-------|
| **spawn** | Tạo subagent bất đồng bộ, không chặn heartbeat |
| **Context độc lập** | Subagent có context riêng, không có lịch sử phiên |
| **message tool** | Subagent giao tiếp trực tiếp với người dùng qua công cụ message |
| **Không chặn** | Sau khi spawn, heartbeat tiếp tục tác vụ tiếp theo |
#### Cách Subagent giao tiếp
```
Heartbeat kích hoạt
Agent đọc HEARTBEAT.md
Tác vụ lâu: spawn subagent
↓ ↓
Tiếp tục tác vụ tiếp theo Subagent làm việc độc lập
↓ ↓
Tất cả tác vụ hoàn thành Subagent dùng công cụ "message"
↓ ↓
Phản hồi HEARTBEAT_OK Người dùng nhận kết quả trực tiếp
```
Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và có thể giao tiếp với người dùng một cách độc lập mà không cần thông qua agent chính.
**Cấu hình:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Tùy chọn | Mặc định | Mô tả |
|----------|---------|-------|
| `enabled` | `true` | Bật/tắt heartbeat |
| `interval` | `30` | Khoảng thời gian kiểm tra (phút, tối thiểu: 5) |
**Biến môi trường:**
* `PICOCLAW_HEARTBEAT_ENABLED=false` để tắt
* `PICOCLAW_HEARTBEAT_INTERVAL=60` để thay đổi khoảng thời gian
### Nhà cung cấp (Providers)
> [!NOTE]
> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản.
| Nhà cung cấp | Mục đích | Lấy API Key |
| --- | --- | --- |
| `gemini` | LLM (Gemini trực tiếp) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (Zhipu trực tiếp) | [bigmodel.cn](bigmodel.cn) |
| `openrouter` (Đang thử nghiệm) | LLM (khuyên dùng, truy cập mọi model) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` (Đang thử nghiệm) | LLM (Claude trực tiếp) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` (Đang thử nghiệm) | LLM (GPT trực tiếp) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` (Đang thử nghiệm) | LLM (DeepSeek trực tiếp) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Chuyển giọng nói** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>Cấu hình Zhipu</b></summary>
**1. Lấy API key**
* Lấy [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. Cấu hình**
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"zhipu": {
"api_key": "Your API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
}
}
}
```
**3. Chạy**
```bash
picoclaw agent -m "Xin chào"
```
</details>
<details>
<summary><b>Ví dụ cấu hình đầy đủ</b></summary>
```json
{
"agents": {
"defaults": {
"model": "anthropic/claude-opus-4-5"
}
},
"providers": {
"openrouter": {
"api_key": "sk-or-v1-xxx"
},
"groq": {
"api_key": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
"token": "",
"allow_from": [""]
},
"whatsapp": {
"enabled": false
},
"feishu": {
"enabled": false,
"app_id": "cli_xxx",
"app_secret": "xxx",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
},
"qq": {
"enabled": false,
"app_id": "",
"app_secret": "",
"allow_from": []
}
},
"tools": {
"web": {
"brave": {
"enabled": false,
"api_key": "BSA...",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
</details>
## Tham chiếu CLI
| Lệnh | Mô tả |
| --- | --- |
| `picoclaw onboard` | Khởi tạo cấu hình & workspace |
| `picoclaw agent -m "..."` | Trò chuyện với agent |
| `picoclaw agent` | Chế độ chat tương tác |
| `picoclaw gateway` | Khởi động gateway (cho bot chat) |
| `picoclaw status` | Hiển thị trạng thái |
| `picoclaw cron list` | Liệt kê tất cả tác vụ định kỳ |
| `picoclaw cron add ...` | Thêm tác vụ định kỳ |
### Tác vụ định kỳ / Nhắc nhở
PicoClaw hỗ trợ nhắc nhở theo lịch và tác vụ lặp lại thông qua công cụ `cron`:
* **Nhắc nhở một lần**: "Remind me in 10 minutes" (Nhắc tôi sau 10 phút) → kích hoạt một lần sau 10 phút
* **Tác vụ lặp lại**: "Remind me every 2 hours" (Nhắc tôi mỗi 2 giờ) → kích hoạt mỗi 2 giờ
* **Biểu thức Cron**: "Remind me at 9am daily" (Nhắc tôi lúc 9 giờ sáng mỗi ngày) → sử dụng biểu thức cron
Các tác vụ được lưu trong `~/.picoclaw/workspace/cron/` và được xử lý tự động.
## 🤝 Đóng góp & Lộ trình
Chào đón mọi PR! Mã nguồn được thiết kế nhỏ gọn và dễ đọc. 🤗
Lộ trình sắp được công bố...
Nhóm phát triển đang được xây dựng. Điều kiện tham gia: Ít nhất 1 PR đã được merge.
Nhóm người dùng:
Discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 Xử lý sự cố
### Tìm kiếm web hiện "API 配置问题"
Điều này là bình thường nếu bạn chưa cấu hình API key cho tìm kiếm. PicoClaw sẽ cung cấp các liên kết hữu ích để tìm kiếm thủ công.
Để bật tìm kiếm web:
1. **Tùy chọn 1 (Khuyên dùng)**: Lấy API key miễn phí tại [https://brave.com/search/api](https://brave.com/search/api) (2000 truy vấn miễn phí/tháng) để có kết quả tốt nhất.
2. **Tùy chọn 2 (Không cần thẻ tín dụng)**: Nếu không có key, hệ thống tự động chuyển sang dùng **DuckDuckGo** (không cần key).
Thêm key vào `~/.picoclaw/config.json` nếu dùng Brave:
```json
{
"tools": {
"web": {
"brave": {
"enabled": true,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
}
```
### Gặp lỗi lọc nội dung (Content Filtering)
Một số nhà cung cấp (như Zhipu) có bộ lọc nội dung nghiêm ngặt. Thử diễn đạt lại câu hỏi hoặc sử dụng model khác.
### Telegram bot báo "Conflict: terminated by other getUpdates"
Điều này xảy ra khi có một instance bot khác đang chạy. Đảm bảo chỉ có một tiến trình `picoclaw gateway` chạy tại một thời điểm.
---
## 📝 So sánh API Key
| Dịch vụ | Gói miễn phí | Trường hợp sử dụng |
| --- | --- | --- |
| **OpenRouter** | 200K tokens/tháng | Đa model (Claude, GPT-4, v.v.) |
| **Zhipu** | 200K tokens/tháng | Tốt nhất cho người dùng Trung Quốc |
| **Brave Search** | 2000 truy vấn/tháng | Chức năng tìm kiếm web |
| **Groq** | Có gói miễn phí | Suy luận siêu nhanh (Llama, Mixtral) |
+28 -3
View File
@@ -14,7 +14,7 @@
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
**中文** | [日本語](README.ja.md) | [English](README.md)
**中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md)
</div>
---
@@ -46,9 +46,11 @@
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
> * **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
## 📢 新闻 (News)
2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与!
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
@@ -98,6 +100,23 @@
</tr>
</table>
### 📱 在手机上轻松运行
picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南:
1. 先去应用商店下载安装Termux
2. 打开后执行指令
```bash
# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
@@ -217,6 +236,9 @@ picoclaw onboard
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
},
"cron": {
"exec_timeout_minutes": 5
}
}
}
@@ -269,7 +291,7 @@ picoclaw agent -m "2+2 等于几?"
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -314,7 +336,7 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
"allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -625,6 +647,9 @@ picoclaw agent -m "你好"
"search": {
"api_key": "BSA..."
}
},
"cron": {
"exec_timeout_minutes": 5
}
},
"heartbeat": {
+116
View File
@@ -0,0 +1,116 @@
# 🦐 PicoClaw Roadmap
> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure.automate the mundane, unleash your creativity
---
## 🚀 1. Core Optimization: Extreme Lightweight
*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
* **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
* **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
* **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
## 🛡️ 2. Security Hardening: Defense in Depth
*Paying off early technical debt. We invite security experts to help build a "Secure-by-Default" agent.*
* **Input Defense & Permission Control**
* **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation.
* **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries.
* **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services).
* **Sandboxing & Isolation**
* **Filesystem Sandbox**: Restrict file R/W operations to specific directories only.
* **Context Isolation**: Prevent data leakage between different user sessions or channels.
* **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs.
* **Authentication & Secrets**
* **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage.
* **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows.
## 🔌 3. Connectivity: Protocol-First Architecture
*Connect every model, reach every platform.*
* **Provider**
* [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)*
* **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference).
* **Online Models**: Continued support for frontier closed-source models.
* **Channel**
* **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ...
* **Standards**: Support for the **OneBot** protocol.
* [**attachment**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments.
* **Skill Marketplace**
* [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries.
## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI
*Beyond conversation—focusing on action and collaboration.*
* **Operations**
* [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**.
* [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook.
* [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop).
* **Multi-Agent Collaboration**
* [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294) implement
* [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart).
* [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network.
* [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms.
## 📚 5. Developer Experience (DevEx) & Documentation
*Lowering the barrier to entry so anyone can deploy in minutes.*
* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350)
* Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step.
* **Comprehensive Documentation**
* **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android.
* **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels.
* **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations).
## 🤖 6. Engineering: AI-Powered Open Source
*Born from Vibe Coding, we continue to use AI to accelerate development.*
* **AI-Enhanced CI/CD**
* Integrate AI for automated Code Review, Linting, and PR Labeling.
* **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean.
* **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes.
## 🎨 7. Brand & Community
* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design!
* *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes."
---
### 🤝 Call for Contributions
We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together!
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 140 KiB

After

Width:  |  Height:  |  Size: 141 KiB

+10 -3
View File
@@ -562,7 +562,8 @@ func gatewayCmd() {
})
// Setup cron tool and service
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace)
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout, cfg)
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -622,6 +623,12 @@ func gatewayCmd() {
logger.InfoC("voice", "Groq transcription attached to Slack channel")
}
}
if onebotChannel, ok := channelManager.GetChannel("onebot"); ok {
if oc, ok := onebotChannel.(*channels.OneBotChannel); ok {
oc.SetTranscriber(transcriber)
logger.InfoC("voice", "Groq transcription attached to OneBot channel")
}
}
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -987,14 +994,14 @@ func getConfigPath() string {
return filepath.Join(home, ".picoclaw", "config.json")
}
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService {
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict)
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, config)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
+13 -55
View File
@@ -79,7 +79,8 @@
},
"openai": {
"api_key": "",
"api_base": ""
"api_base": "",
"web_search": true
},
"openrouter": {
"api_key": "sk-or-v1-xxx",
@@ -117,66 +118,23 @@
},
"tools": {
"web": {
"search": {
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"perplexity": {
"enabled": false,
"api_key": "pplx-xxx",
"max_results": 5
}
},
"cron": {
"exec_timeout_minutes": 5
},
"mcp": {
"enabled": false,
"servers": {
"filesystem": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-filesystem",
"/tmp"
],
"env": {}
},
"github": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-github"
],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
},
"envFile": ".env"
},
"brave-search": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-brave-search"
],
"env": {
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
}
},
"postgres": {
"enabled": false,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-postgres",
"postgresql://user:password@localhost/dbname"
]
},
"remote-http-example": {
"enabled": false,
"url": "https://mcp-server.example.com/stream",
"type": "sse",
"headers": {
"Authorization": "Bearer YOUR_TOKEN",
"X-Custom-Header": "custom-value"
}
}
}
"servers": {}
}
},
"heartbeat": {
+4 -4
View File
@@ -11,8 +11,8 @@ services:
profiles:
- agent
volumes:
- ./config/config.json:/root/.picoclaw/config.json:ro
- picoclaw-workspace:/root/.picoclaw/workspace
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
- picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
entrypoint: ["picoclaw", "agent"]
stdin_open: true
tty: true
@@ -31,9 +31,9 @@ services:
- gateway
volumes:
# Configuration file
- ./config/config.json:/root/.picoclaw/config.json:ro
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
# Persistent workspace (sessions, memory, logs)
- picoclaw-workspace:/root/.picoclaw/workspace
- picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
command: ["gateway"]
volumes:
+112
View File
@@ -0,0 +1,112 @@
## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
**Hello, PicoClaw Community!**
First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, weve seen a non-stop stream of high-quality PRs!
PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
### 🎁 Community Perks
To show our appreciation, developers who officially join our community operations will receive:
* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
### 🎥 Calling All Content Creators!
Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
We will be rewarding high-quality content creators with the same perks as our community developers!
---
## 🛠️ Urgent Volunteer Roles
We are looking for experts in the following areas:
1. **Issue/PR Reviewers**
* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
2. **Resource Optimization Experts**
* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
* **Focus:** Analyzing resource growth between releases and trimming redundancy.
* **Priority:** **RAM usage optimization** > Binary size reduction.
3. **Security Audit & Bug Fixes**
* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
4. **Documentation & DX (Developer Experience)**
* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
* **Focus:** Creating clear, user-friendly documentation for both setup and development.
5. **AI-Powered CI/CD Optimization**
* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
---
## 📍 The Roadmap
Interested in a specific feature? You can "claim" these tasks and start building:
###
* **Provider:**
* **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
* You can still submit code; Daming will merge it into the new implementation.
* **Channels:**
* Support for OneBot, additional platforms
* attachments (images, audio, video, files).
* **Skills:**
* Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
* **Operations:** * MCP Support.
* Android operations (e.g., botdrop).
* Browser automation via CDP or ActionBook.
* **Multi-Agent Ecosystem:**
* **Basic Model-Agnet** S
* **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
* **Swarm Mode.**
* **AIEOS Integration.**
* **Branding:**
* **Logo**: We need a cute logo! Were leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
This list will be updated continuously as we progress.
If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
---
## 🤝 How to Join
**Everything is open to your creativity!** If you have a wild idea, just PR it.
1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
2. **The Application Track:** If you havent submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
> Include the role you're interested in and any evidence of your development experience.
### Looking Ahead
Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
+122
View File
@@ -0,0 +1,122 @@
# Tools Configuration
PicoClaw's tools configuration is located in the `tools` field of `config.json`.
## Directory Structure
```json
{
"tools": {
"web": { ... },
"exec": { ... },
"approval": { ... },
"cron": { ... }
}
}
```
## Web Tools
Web tools are used for web search and fetching.
### Brave
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `enabled` | bool | false | Enable Brave search |
| `api_key` | string | - | Brave Search API key |
| `max_results` | int | 5 | Maximum number of results |
### DuckDuckGo
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `enabled` | bool | true | Enable DuckDuckGo search |
| `max_results` | int | 5 | Maximum number of results |
### Perplexity
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `enabled` | bool | false | Enable Perplexity search |
| `api_key` | string | - | Perplexity API key |
| `max_results` | int | 5 | Maximum number of results |
## Exec Tool
The exec tool is used to execute shell commands.
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
### Functionality
- **`enable_deny_patterns`**: Set to `false` to completely disable the default dangerous command blocking patterns
- **`custom_deny_patterns`**: Add custom deny regex patterns; commands matching these will be blocked
### Default Blocked Command Patterns
By default, PicoClaw blocks the following dangerous commands:
- Delete commands: `rm -rf`, `del /f/q`, `rmdir /s`
- Disk operations: `format`, `mkfs`, `diskpart`, `dd if=`, writing to `/dev/sd*`
- System operations: `shutdown`, `reboot`, `poweroff`
- Command substitution: `$()`, `${}`, backticks
- Pipe to shell: `| sh`, `| bash`
- Privilege escalation: `sudo`, `chmod`, `chown`
- Process control: `pkill`, `killall`, `kill -9`
- Remote operations: `curl | sh`, `wget | sh`, `ssh`
- Package management: `apt`, `yum`, `dnf`, `npm install -g`, `pip install --user`
- Containers: `docker run`, `docker exec`
- Git: `git push`, `git force`
- Other: `eval`, `source *.sh`
### Configuration Example
```json
{
"tools": {
"exec": {
"enable_deny_patterns": true,
"custom_deny_patterns": [
"\\brm\\s+-r\\b",
"\\bkillall\\s+python"
],
}
}
}
```
## Approval Tool
The approval tool controls permissions for dangerous operations.
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `enabled` | bool | true | Enable approval functionality |
| `write_file` | bool | true | Require approval for file writes |
| `edit_file` | bool | true | Require approval for file edits |
| `append_file` | bool | true | Require approval for file appends |
| `exec` | bool | true | Require approval for command execution |
| `timeout_minutes` | int | 5 | Approval timeout in minutes |
## Cron Tool
The cron tool is used for scheduling periodic tasks.
| Config | Type | Default | Description |
|--------|------|---------|-------------|
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
## Environment Variables
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_<SECTION>_<KEY>`:
For example:
- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true`
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
Note: Array-type environment variables are not currently supported and must be set via the config file.
+145
View File
@@ -0,0 +1,145 @@
package agent
import (
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// AgentInstance represents a fully configured agent with its own workspace,
// session manager, context builder, and tool registry.
type AgentInstance struct {
ID string
Name string
Model string
Fallbacks []string
Workspace string
MaxIterations int
ContextWindow int
Provider providers.LLMProvider
Sessions *session.SessionManager
ContextBuilder *ContextBuilder
Tools *tools.ToolRegistry
Subagents *config.SubagentsConfig
SkillsFilter []string
Candidates []providers.FallbackCandidate
}
// NewAgentInstance creates an agent instance from config.
func NewAgentInstance(
agentCfg *config.AgentConfig,
defaults *config.AgentDefaults,
cfg *config.Config,
provider providers.LLMProvider,
) *AgentInstance {
workspace := resolveAgentWorkspace(agentCfg, defaults)
os.MkdirAll(workspace, 0755)
model := resolveAgentModel(agentCfg, defaults)
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
restrict := defaults.RestrictToWorkspace
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
sessionsDir := filepath.Join(workspace, "sessions")
sessionsManager := session.NewSessionManager(sessionsDir)
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
agentID := routing.DefaultAgentID
agentName := ""
var subagents *config.SubagentsConfig
var skillsFilter []string
if agentCfg != nil {
agentID = routing.NormalizeAgentID(agentCfg.ID)
agentName = agentCfg.Name
subagents = agentCfg.Subagents
skillsFilter = agentCfg.Skills
}
maxIter := defaults.MaxToolIterations
if maxIter == 0 {
maxIter = 20
}
// Resolve fallback candidates
modelCfg := providers.ModelConfig{
Primary: model,
Fallbacks: fallbacks,
}
candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
return &AgentInstance{
ID: agentID,
Name: agentName,
Model: model,
Fallbacks: fallbacks,
Workspace: workspace,
MaxIterations: maxIter,
ContextWindow: defaults.MaxTokens,
Provider: provider,
Sessions: sessionsManager,
ContextBuilder: contextBuilder,
Tools: toolsRegistry,
Subagents: subagents,
SkillsFilter: skillsFilter,
Candidates: candidates,
}
}
// resolveAgentWorkspace determines the workspace directory for an agent.
func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
return expandHome(strings.TrimSpace(agentCfg.Workspace))
}
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
return expandHome(defaults.Workspace)
}
home, _ := os.UserHomeDir()
id := routing.NormalizeAgentID(agentCfg.ID)
return filepath.Join(home, ".picoclaw", "workspace-"+id)
}
// resolveAgentModel resolves the primary model for an agent.
func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" {
return strings.TrimSpace(agentCfg.Model.Primary)
}
return defaults.Model
}
// resolveAgentFallbacks resolves the fallback models for an agent.
func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string {
if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil {
return agentCfg.Model.Fallbacks
}
return defaults.ModelFallbacks
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
+335 -341
View File
File diff suppressed because it is too large Load Diff
+6 -2
View File
@@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
}
al.sessions.SetHistory(sessionKey, history)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("No default agent found")
}
defaultAgent.Sessions.SetHistory(sessionKey, history)
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
@@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
}
// Check final history length
finalHistory := al.sessions.GetHistory(sessionKey)
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
+114
View File
@@ -0,0 +1,114 @@
package agent
import (
"sync"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
)
// AgentRegistry manages multiple agent instances and routes messages to them.
type AgentRegistry struct {
agents map[string]*AgentInstance
resolver *routing.RouteResolver
mu sync.RWMutex
}
// NewAgentRegistry creates a registry from config, instantiating all agents.
func NewAgentRegistry(
cfg *config.Config,
provider providers.LLMProvider,
) *AgentRegistry {
registry := &AgentRegistry{
agents: make(map[string]*AgentInstance),
resolver: routing.NewRouteResolver(cfg),
}
agentConfigs := cfg.Agents.List
if len(agentConfigs) == 0 {
implicitAgent := &config.AgentConfig{
ID: "main",
Default: true,
}
instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider)
registry.agents["main"] = instance
logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil)
} else {
for i := range agentConfigs {
ac := &agentConfigs[i]
id := routing.NormalizeAgentID(ac.ID)
instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider)
registry.agents[id] = instance
logger.InfoCF("agent", "Registered agent",
map[string]interface{}{
"agent_id": id,
"name": ac.Name,
"workspace": instance.Workspace,
"model": instance.Model,
})
}
}
return registry
}
// GetAgent returns the agent instance for a given ID.
func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
id := routing.NormalizeAgentID(agentID)
agent, ok := r.agents[id]
return agent, ok
}
// ResolveRoute determines which agent handles the message.
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
return r.resolver.ResolveRoute(input)
}
// ListAgentIDs returns all registered agent IDs.
func (r *AgentRegistry) ListAgentIDs() []string {
r.mu.RLock()
defer r.mu.RUnlock()
ids := make([]string, 0, len(r.agents))
for id := range r.agents {
ids = append(ids, id)
}
return ids
}
// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID.
func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool {
parent, ok := r.GetAgent(parentAgentID)
if !ok {
return false
}
if parent.Subagents == nil || parent.Subagents.AllowAgents == nil {
return false
}
targetNorm := routing.NormalizeAgentID(targetAgentID)
for _, allowed := range parent.Subagents.AllowAgents {
if allowed == "*" {
return true
}
if routing.NormalizeAgentID(allowed) == targetNorm {
return true
}
}
return false
}
// GetDefaultAgent returns the default agent instance.
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
r.mu.RLock()
defer r.mu.RUnlock()
if agent, ok := r.agents["main"]; ok {
return agent
}
for _, agent := range r.agents {
return agent
}
return nil
}
+199
View File
@@ -0,0 +1,199 @@
package agent
import (
"context"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
type mockRegistryProvider struct{}
func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil
}
func (m *mockRegistryProvider) GetDefaultModel() string {
return "mock-model"
}
func testCfg(agents []config.AgentConfig) *config.Config {
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: "/tmp/picoclaw-test-registry",
Model: "gpt-4",
MaxTokens: 8192,
MaxToolIterations: 10,
},
List: agents,
},
}
}
func TestNewAgentRegistry_ImplicitMain(t *testing.T) {
cfg := testCfg(nil)
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
ids := registry.ListAgentIDs()
if len(ids) != 1 || ids[0] != "main" {
t.Errorf("expected implicit main agent, got %v", ids)
}
agent, ok := registry.GetAgent("main")
if !ok || agent == nil {
t.Fatal("expected to find 'main' agent")
}
if agent.ID != "main" {
t.Errorf("agent.ID = %q, want 'main'", agent.ID)
}
}
func TestNewAgentRegistry_ExplicitAgents(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "sales", Default: true, Name: "Sales Bot"},
{ID: "support", Name: "Support Bot"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
ids := registry.ListAgentIDs()
if len(ids) != 2 {
t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids)
}
sales, ok := registry.GetAgent("sales")
if !ok || sales == nil {
t.Fatal("expected to find 'sales' agent")
}
if sales.Name != "Sales Bot" {
t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name)
}
support, ok := registry.GetAgent("support")
if !ok || support == nil {
t.Fatal("expected to find 'support' agent")
}
}
func TestAgentRegistry_GetAgent_Normalize(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "my-agent", Default: true},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, ok := registry.GetAgent("My-Agent")
if !ok || agent == nil {
t.Fatal("expected to find agent with normalized ID")
}
if agent.ID != "my-agent" {
t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID)
}
}
func TestAgentRegistry_GetDefaultAgent(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "alpha"},
{ID: "beta", Default: true},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
// GetDefaultAgent first checks for "main", then returns any
agent := registry.GetDefaultAgent()
if agent == nil {
t.Fatal("expected a default agent")
}
}
func TestAgentRegistry_CanSpawnSubagent(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{
ID: "parent",
Default: true,
Subagents: &config.SubagentsConfig{
AllowAgents: []string{"child1", "child2"},
},
},
{ID: "child1"},
{ID: "child2"},
{ID: "restricted"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
if !registry.CanSpawnSubagent("parent", "child1") {
t.Error("expected parent to be allowed to spawn child1")
}
if !registry.CanSpawnSubagent("parent", "child2") {
t.Error("expected parent to be allowed to spawn child2")
}
if registry.CanSpawnSubagent("parent", "restricted") {
t.Error("expected parent to NOT be allowed to spawn restricted")
}
if registry.CanSpawnSubagent("child1", "child2") {
t.Error("expected child1 to NOT be allowed to spawn (no subagents config)")
}
}
func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{
ID: "admin",
Default: true,
Subagents: &config.SubagentsConfig{
AllowAgents: []string{"*"},
},
},
{ID: "any-agent"},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
if !registry.CanSpawnSubagent("admin", "any-agent") {
t.Error("expected wildcard to allow spawning any agent")
}
if !registry.CanSpawnSubagent("admin", "nonexistent") {
t.Error("expected wildcard to allow spawning even nonexistent agents")
}
}
func TestAgentInstance_Model(t *testing.T) {
model := &config.AgentModelConfig{Primary: "claude-opus"}
cfg := testCfg([]config.AgentConfig{
{ID: "custom", Default: true, Model: model},
})
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("custom")
if agent.Model != "claude-opus" {
t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model)
}
}
func TestAgentInstance_FallbackInheritance(t *testing.T) {
cfg := testCfg([]config.AgentConfig{
{ID: "inherit", Default: true},
})
cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"}
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("inherit")
if len(agent.Fallbacks) != 2 {
t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks))
}
}
func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) {
model := &config.AgentModelConfig{
Primary: "gpt-4",
Fallbacks: []string{}, // explicitly empty = disable
}
cfg := testCfg([]config.AgentConfig{
{ID: "no-fallback", Default: true, Model: model},
})
cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"}
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
agent, _ := registry.GetAgent("no-fallback")
if len(agent.Fallbacks) != 0 {
t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks)
}
}
+6 -11
View File
@@ -2,7 +2,6 @@ package channels
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
return
}
// Build session key: channel:chatID
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
msg := bus.InboundMessage{
Channel: c.name,
SenderID: senderID,
ChatID: chatID,
Content: content,
Media: media,
SessionKey: sessionKey,
Metadata: metadata,
Channel: c.name,
SenderID: senderID,
ChatID: chatID,
Content: content,
Media: media,
Metadata: metadata,
}
c.bus.PublishInbound(msg)
+10 -128
View File
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/bwmarrin/discordgo"
@@ -106,7 +105,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return nil
}
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars
for _, chunk := range chunks {
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
@@ -117,132 +116,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return nil
}
// splitMessage splits long messages into chunks, preserving code block integrity
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
func splitMessage(content string, limit int) []string {
var messages []string
for len(content) > 0 {
if len(content) <= limit {
messages = append(messages, content)
break
}
msgEnd := limit
// Find natural split point within the limit
msgEnd = findLastNewline(content[:limit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:limit], 100)
}
if msgEnd <= 0 {
msgEnd = limit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend to include the closing ``` (with some buffer)
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
if len(content) > extendedLimit {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= extendedLimit {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Can't find closing, split before the code block
msgEnd = findLastNewline(content[:unclosedIdx], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:unclosedIdx], 100)
}
if msgEnd <= 0 {
msgEnd = unclosedIdx
}
}
} else {
// Remaining content fits within extended limit
msgEnd = len(content)
}
}
if msgEnd <= 0 {
msgEnd = limit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
count := 0
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
if count == 0 {
lastOpenIdx = i
}
count++
i += 2
}
}
// If odd number of ``` markers, last one is unclosed
if count%2 == 1 {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
@@ -376,6 +249,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"preview": utils.Truncate(content, 50),
})
peerKind := "channel"
peerID := m.ChannelID
if m.GuildID == "" {
peerKind = "direct"
peerID = senderID
}
metadata := map[string]string{
"message_id": m.ID,
"user_id": senderID,
@@ -384,6 +264,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
"guild_id": m.GuildID,
"channel_id": m.ChannelID,
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
"peer_kind": peerKind,
"peer_id": peerID,
}
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
-2
View File
@@ -18,7 +18,6 @@ type MaixCamChannel struct {
listener net.Listener
clients map[net.Conn]bool
clientsMux sync.RWMutex
running bool
}
type MaixCamMessage struct {
@@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
BaseChannel: base,
config: cfg,
clients: make(map[net.Conn]bool),
running: false,
}, nil
}
+498 -209
View File
@@ -4,9 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
@@ -14,20 +16,28 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type OneBotChannel struct {
*BaseChannel
config config.OneBotConfig
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
dedup map[string]struct{}
dedupRing []string
dedupIdx int
mu sync.Mutex
writeMu sync.Mutex
echoCounter int64
config config.OneBotConfig
conn *websocket.Conn
ctx context.Context
cancel context.CancelFunc
dedup map[string]struct{}
dedupRing []string
dedupIdx int
mu sync.Mutex
writeMu sync.Mutex
echoCounter int64
selfID int64
pending map[string]chan json.RawMessage
pendingMu sync.Mutex
transcriber *voice.GroqTranscriber
lastMessageID sync.Map
pendingEmojiMsg sync.Map
}
type oneBotRawEvent struct {
@@ -43,9 +53,11 @@ type oneBotRawEvent struct {
SelfID json.RawMessage `json:"self_id"`
Time json.RawMessage `json:"time"`
MetaEventType string `json:"meta_event_type"`
NoticeType string `json:"notice_type"`
Echo string `json:"echo"`
RetCode json.RawMessage `json:"retcode"`
Status BotStatus `json:"status"`
Status json.RawMessage `json:"status"`
Data json.RawMessage `json:"data"`
}
type BotStatus struct {
@@ -53,42 +65,36 @@ type BotStatus struct {
Good bool `json:"good"`
}
func isAPIResponse(raw json.RawMessage) bool {
if len(raw) == 0 {
return false
}
var s string
if json.Unmarshal(raw, &s) == nil {
return s == "ok" || s == "failed"
}
var bs BotStatus
if json.Unmarshal(raw, &bs) == nil {
return bs.Online || bs.Good
}
return false
}
type oneBotSender struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
Card string `json:"card"`
}
type oneBotEvent struct {
PostType string
MessageType string
SubType string
MessageID string
UserID int64
GroupID int64
Content string
RawContent string
IsBotMentioned bool
Sender oneBotSender
SelfID int64
Time int64
MetaEventType string
}
type oneBotAPIRequest struct {
Action string `json:"action"`
Params interface{} `json:"params"`
Echo string `json:"echo,omitempty"`
}
type oneBotSendPrivateMsgParams struct {
UserID int64 `json:"user_id"`
Message string `json:"message"`
}
type oneBotSendGroupMsgParams struct {
GroupID int64 `json:"group_id"`
Message string `json:"message"`
type oneBotMessageSegment struct {
Type string `json:"type"`
Data map[string]interface{} `json:"data"`
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
@@ -101,9 +107,30 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One
dedup: make(map[string]struct{}, dedupSize),
dedupRing: make([]string, dedupSize),
dedupIdx: 0,
pending: make(map[string]chan json.RawMessage),
}, nil
}
func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
go func() {
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{
"message_id": messageID,
"emoji_id": emojiID,
"set": set,
}, 5*time.Second)
if err != nil {
logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{
"message_id": messageID,
"error": err.Error(),
})
}
}()
}
func (c *OneBotChannel) Start(ctx context.Context) error {
if c.config.WSUrl == "" {
return fmt.Errorf("OneBot ws_url not configured")
@@ -121,12 +148,12 @@ func (c *OneBotChannel) Start(ctx context.Context) error {
})
} else {
go c.listen()
c.fetchSelfID()
}
if c.config.ReconnectInterval > 0 {
go c.reconnectLoop()
} else {
// If reconnect is disabled but initial connection failed, we cannot recover
if c.conn == nil {
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
}
@@ -152,14 +179,141 @@ func (c *OneBotChannel) connect() error {
return err
}
conn.SetPongHandler(func(appData string) error {
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
go c.pinger(conn)
logger.InfoC("onebot", "WebSocket connected")
return nil
}
func (c *OneBotChannel) pinger(conn *websocket.Conn) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
c.writeMu.Lock()
err := conn.WriteMessage(websocket.PingMessage, nil)
c.writeMu.Unlock()
if err != nil {
logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{
"error": err.Error(),
})
return
}
}
}
}
func (c *OneBotChannel) fetchSelfID() {
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
if err != nil {
logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{
"error": err.Error(),
})
return
}
type loginInfo struct {
UserID json.RawMessage `json:"user_id"`
Nickname string `json:"nickname"`
}
for _, extract := range []func() (*loginInfo, error){
func() (*loginInfo, error) {
var w struct {
Data loginInfo `json:"data"`
}
err := json.Unmarshal(resp, &w)
return &w.Data, err
},
func() (*loginInfo, error) {
var f loginInfo
err := json.Unmarshal(resp, &f)
return &f, err
},
} {
info, err := extract()
if err != nil || len(info.UserID) == 0 {
continue
}
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
atomic.StoreInt64(&c.selfID, uid)
logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{
"self_id": uid,
"nickname": info.Nickname,
})
return
}
}
logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{
"response": string(resp),
})
}
func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
return nil, fmt.Errorf("WebSocket not connected")
}
echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1))
ch := make(chan json.RawMessage, 1)
c.pendingMu.Lock()
c.pending[echo] = ch
c.pendingMu.Unlock()
defer func() {
c.pendingMu.Lock()
delete(c.pending, echo)
c.pendingMu.Unlock()
}()
req := oneBotAPIRequest{
Action: action,
Params: params,
Echo: echo,
}
data, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal API request: %w", err)
}
c.writeMu.Lock()
err = conn.WriteMessage(websocket.TextMessage, data)
c.writeMu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to write API request: %w", err)
}
select {
case resp := <-ch:
return resp, nil
case <-time.After(timeout):
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
case <-c.ctx.Done():
return nil, fmt.Errorf("context cancelled")
}
}
func (c *OneBotChannel) reconnectLoop() {
interval := time.Duration(c.config.ReconnectInterval) * time.Second
if interval < 5*time.Second {
@@ -183,6 +337,7 @@ func (c *OneBotChannel) reconnectLoop() {
})
} else {
go c.listen()
c.fetchSelfID()
}
}
}
@@ -197,6 +352,13 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
c.cancel()
}
c.pendingMu.Lock()
for echo, ch := range c.pending {
close(ch)
delete(c.pending, echo)
}
c.pendingMu.Unlock()
c.mu.Lock()
if c.conn != nil {
c.conn.Close()
@@ -225,10 +387,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return err
}
c.writeMu.Lock()
c.echoCounter++
echo := fmt.Sprintf("send_%d", c.echoCounter)
c.writeMu.Unlock()
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
req := oneBotAPIRequest{
Action: action,
@@ -252,67 +411,78 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return err
}
if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok {
if mid, ok := msgID.(string); ok && mid != "" {
c.setMsgEmojiLike(mid, 289, false)
}
}
return nil
}
func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment {
var segments []oneBotMessageSegment
if lastMsgID, ok := c.lastMessageID.Load(chatID); ok {
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
segments = append(segments, oneBotMessageSegment{
Type: "reply",
Data: map[string]interface{}{"id": msgID},
})
}
}
segments = append(segments, oneBotMessageSegment{
Type: "text",
Data: map[string]interface{}{"text": content},
})
return segments
}
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
chatID := msg.ChatID
segments := c.buildMessageSegments(chatID, msg.Content)
if len(chatID) > 6 && chatID[:6] == "group:" {
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
}
return "send_group_msg", oneBotSendGroupMsgParams{
GroupID: groupID,
Message: msg.Content,
}, nil
var action, idKey string
var rawID string
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
action, idKey, rawID = "send_group_msg", "group_id", rest
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
action, idKey, rawID = "send_private_msg", "user_id", rest
} else {
action, idKey, rawID = "send_private_msg", "user_id", chatID
}
if len(chatID) > 8 && chatID[:8] == "private:" {
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
}
return "send_private_msg", oneBotSendPrivateMsgParams{
UserID: userID,
Message: msg.Content,
}, nil
}
userID, err := strconv.ParseInt(chatID, 10, 64)
id, err := strconv.ParseInt(rawID, 10, 64)
if err != nil {
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID)
}
return "send_private_msg", oneBotSendPrivateMsgParams{
UserID: userID,
Message: msg.Content,
}, nil
return action, map[string]interface{}{idKey: id, "message": segments}, nil
}
func (c *OneBotChannel) listen() {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
for {
select {
case <-c.ctx.Done():
return
default:
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
if conn == nil {
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
return
}
_, message, err := conn.ReadMessage()
if err != nil {
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
"error": err.Error(),
})
c.mu.Lock()
if c.conn != nil {
if c.conn == conn {
c.conn.Close()
c.conn = nil
}
@@ -320,10 +490,7 @@ func (c *OneBotChannel) listen() {
return
}
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
"length": len(message),
"payload": string(message),
})
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
var raw oneBotRawEvent
if err := json.Unmarshal(message, &raw); err != nil {
@@ -334,20 +501,37 @@ func (c *OneBotChannel) listen() {
continue
}
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
"echo": raw.Echo,
"status": raw.Status,
})
logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{
"length": len(message),
"post_type": raw.PostType,
"sub_type": raw.SubType,
})
if raw.Echo != "" {
c.pendingMu.Lock()
ch, ok := c.pending[raw.Echo]
c.pendingMu.Unlock()
if ok {
select {
case ch <- message:
default:
}
} else {
logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{
"echo": raw.Echo,
"status": string(raw.Status),
})
}
continue
}
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
"post_type": raw.PostType,
"message_type": raw.MessageType,
"sub_type": raw.SubType,
"meta_event_type": raw.MetaEventType,
})
if isAPIResponse(raw.Status) {
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{
"status": string(raw.Status),
})
continue
}
c.handleRawEvent(&raw)
}
@@ -386,9 +570,12 @@ func parseJSONString(raw json.RawMessage) string {
type parseMessageResult struct {
Text string
IsBotMentioned bool
Media []string
LocalFiles []string
ReplyTo string
}
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult {
if len(raw) == 0 {
return parseMessageResult{}
}
@@ -408,60 +595,155 @@ func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult
}
var segments []map[string]interface{}
if err := json.Unmarshal(raw, &segments); err == nil {
var text string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]interface{})
switch segType {
case "text":
if data != nil {
if t, ok := data["text"].(string); ok {
text += t
}
if err := json.Unmarshal(raw, &segments); err != nil {
return parseMessageResult{}
}
var textParts []string
mentioned := false
selfIDStr := strconv.FormatInt(selfID, 10)
var media []string
var localFiles []string
var replyTo string
for _, seg := range segments {
segType, _ := seg["type"].(string)
data, _ := seg["data"].(map[string]interface{})
switch segType {
case "text":
if data != nil {
if t, ok := data["text"].(string); ok {
textParts = append(textParts, t)
}
case "at":
if data != nil && selfID > 0 {
qqVal := fmt.Sprintf("%v", data["qq"])
if qqVal == selfIDStr || qqVal == "all" {
mentioned = true
}
case "at":
if data != nil && selfID > 0 {
qqVal := fmt.Sprintf("%v", data["qq"])
if qqVal == selfIDStr || qqVal == "all" {
mentioned = true
}
}
case "image", "video", "file":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"}
filename := defaults[segType]
if f, ok := data["file"].(string); ok && f != "" {
filename = f
} else if n, ok := data["name"].(string); ok && n != "" {
filename = n
}
localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
media = append(media, localPath)
localFiles = append(localFiles, localPath)
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
}
}
}
case "record":
if data != nil {
url, _ := data["url"].(string)
if url != "" {
localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{
LoggerPrefix: "onebot",
})
if localPath != "" {
localFiles = append(localFiles, localPath)
if c.transcriber != nil && c.transcriber.IsAvailable() {
tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second)
result, err := c.transcriber.Transcribe(tctx, localPath)
tcancel()
if err != nil {
logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
})
textParts = append(textParts, "[voice (transcription failed)]")
media = append(media, localPath)
} else {
textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text))
}
} else {
textParts = append(textParts, "[voice]")
media = append(media, localPath)
}
}
}
}
case "reply":
if data != nil {
if id, ok := data["id"]; ok {
replyTo = fmt.Sprintf("%v", id)
}
}
case "face":
if data != nil {
faceID, _ := data["id"]
textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID))
}
case "forward":
textParts = append(textParts, "[forward message]")
default:
}
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
}
return parseMessageResult{}
return parseMessageResult{
Text: strings.TrimSpace(strings.Join(textParts, "")),
IsBotMentioned: mentioned,
Media: media,
LocalFiles: localFiles,
ReplyTo: replyTo,
}
}
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
switch raw.PostType {
case "message":
evt, err := c.normalizeMessageEvent(raw)
if err != nil {
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
"error": err.Error(),
})
return
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
})
return
}
}
c.handleMessage(evt)
c.handleMessage(raw)
case "message_sent":
logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{
"message_type": raw.MessageType,
"message_id": parseJSONString(raw.MessageID),
})
case "meta_event":
c.handleMetaEvent(raw)
case "notice":
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
"sub_type": raw.SubType,
})
c.handleNoticeEvent(raw)
case "request":
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
"sub_type": raw.SubType,
})
case "":
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
"echo": raw.Echo,
"status": raw.Status,
})
default:
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
"post_type": raw.PostType,
@@ -469,18 +751,51 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
}
}
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
if raw.MetaEventType == "lifecycle" {
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType})
} else if raw.MetaEventType != "heartbeat" {
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
}
}
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
fields := map[string]interface{}{
"notice_type": raw.NoticeType,
"sub_type": raw.SubType,
"group_id": parseJSONString(raw.GroupID),
"user_id": parseJSONString(raw.UserID),
"message_id": parseJSONString(raw.MessageID),
}
switch raw.NoticeType {
case "group_recall", "group_increase", "group_decrease",
"friend_add", "group_admin", "group_ban":
logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields)
default:
logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields)
}
}
func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
// Parse fields from raw event
userID, err := parseJSONInt64(raw.UserID)
if err != nil {
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{
"error": err.Error(),
"raw": string(raw.UserID),
})
return
}
groupID, _ := parseJSONInt64(raw.GroupID)
selfID, _ := parseJSONInt64(raw.SelfID)
ts, _ := parseJSONInt64(raw.Time)
messageID := parseJSONString(raw.MessageID)
parsed := parseMessageContentEx(raw.Message, selfID)
if selfID == 0 {
selfID = atomic.LoadInt64(&c.selfID)
}
parsed := c.parseMessageSegments(raw.Message, selfID)
isBotMentioned := parsed.IsBotMentioned
content := raw.RawMessage
@@ -495,6 +810,10 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
}
}
if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") {
content = parsed.Text
}
var sender oneBotSender
if len(raw.Sender) > 0 {
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
@@ -505,137 +824,107 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
}
}
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
"message_type": raw.MessageType,
"user_id": userID,
"group_id": groupID,
"message_id": messageID,
"content_len": len(content),
"nickname": sender.Nickname,
})
return &oneBotEvent{
PostType: raw.PostType,
MessageType: raw.MessageType,
SubType: raw.SubType,
MessageID: messageID,
UserID: userID,
GroupID: groupID,
Content: content,
RawContent: raw.RawMessage,
IsBotMentioned: isBotMentioned,
Sender: sender,
SelfID: selfID,
Time: ts,
MetaEventType: raw.MetaEventType,
}, nil
}
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
switch raw.MetaEventType {
case "lifecycle":
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
"sub_type": raw.SubType,
})
case "heartbeat":
logger.DebugC("onebot", "Heartbeat received")
default:
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
"meta_event_type": raw.MetaEventType,
})
// Clean up temp files when done
if len(parsed.LocalFiles) > 0 {
defer func() {
for _, f := range parsed.LocalFiles {
if err := os.Remove(f); err != nil {
logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{
"path": f,
"error": err.Error(),
})
}
}
}()
}
}
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
if c.isDuplicate(evt.MessageID) {
if c.isDuplicate(messageID) {
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
"message_id": evt.MessageID,
"message_id": messageID,
})
return
}
content := evt.Content
if content == "" {
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
"message_id": evt.MessageID,
"message_id": messageID,
})
return
}
senderID := strconv.FormatInt(evt.UserID, 10)
senderID := strconv.FormatInt(userID, 10)
var chatID string
metadata := map[string]string{
"message_id": evt.MessageID,
"message_id": messageID,
}
switch evt.MessageType {
if parsed.ReplyTo != "" {
metadata["reply_to_message_id"] = parsed.ReplyTo
}
switch raw.MessageType {
case "private":
chatID = "private:" + senderID
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
"sender": senderID,
"message_id": evt.MessageID,
"length": len(content),
"content": truncate(content, 100),
})
case "group":
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
groupIDStr := strconv.FormatInt(groupID, 10)
chatID = "group:" + groupIDStr
metadata["group_id"] = groupIDStr
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
senderUserID, _ := parseJSONInt64(sender.UserID)
if senderUserID > 0 {
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
}
if evt.Sender.Card != "" {
metadata["sender_name"] = evt.Sender.Card
} else if evt.Sender.Nickname != "" {
metadata["sender_name"] = evt.Sender.Nickname
if sender.Card != "" {
metadata["sender_name"] = sender.Card
} else if sender.Nickname != "" {
metadata["sender_name"] = sender.Nickname
}
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
if !triggered {
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"is_mentioned": evt.IsBotMentioned,
"is_mentioned": isBotMentioned,
"content": truncate(content, 100),
})
return
}
content = strippedContent
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
"sender": senderID,
"group": groupIDStr,
"message_id": evt.MessageID,
"is_mentioned": evt.IsBotMentioned,
"length": len(content),
"content": truncate(content, 100),
})
default:
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
"type": evt.MessageType,
"message_id": evt.MessageID,
"user_id": evt.UserID,
"type": raw.MessageType,
"message_id": messageID,
"user_id": userID,
})
return
}
if evt.Sender.Nickname != "" {
metadata["nickname"] = evt.Sender.Nickname
}
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"content": truncate(content, 100),
logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{
"sender": senderID,
"chat_id": chatID,
"message_id": messageID,
"length": len(content),
"content": truncate(content, 100),
"media_count": len(parsed.Media),
})
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
if sender.Nickname != "" {
metadata["nickname"] = sender.Nickname
}
c.lastMessageID.Store(chatID, messageID)
if raw.MessageType == "group" && messageID != "" && messageID != "0" {
c.setMsgEmojiLike(messageID, 289, true)
c.pendingEmojiMsg.Store(chatID, messageID)
}
c.HandleMessage(senderID, chatID, content, parsed.Media, metadata)
}
func (c *OneBotChannel) isDuplicate(messageID string) bool {
+25
View File
@@ -25,6 +25,7 @@ type SlackChannel struct {
api *slack.Client
socketClient *socketmode.Client
botUserID string
teamID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
@@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
c.teamID = authResp.TeamID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
@@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return
}
peerKind := "channel"
peerID := channelID
if strings.HasPrefix(channelID, "D") {
peerKind = "direct"
peerID = senderID
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"peer_kind": peerKind,
"peer_id": peerID,
"team_id": c.teamID,
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
@@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
mentionPeerKind := "channel"
mentionPeerID := channelID
if strings.HasPrefix(channelID, "D") {
mentionPeerKind = "direct"
mentionPeerID = senderID
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
"peer_kind": mentionPeerKind,
"peer_id": mentionPeerID,
"team_id": c.teamID,
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
"peer_kind": "channel",
"peer_id": channelID,
"team_id": c.teamID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
+9
View File
@@ -347,12 +347,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
c.placeholders.Store(chatIDStr, pID)
}
peerKind := "direct"
peerID := fmt.Sprintf("%d", user.ID)
if message.Chat.Type != "private" {
peerKind = "group"
peerID = fmt.Sprintf("%d", chatID)
}
metadata := map[string]string{
"message_id": fmt.Sprintf("%d", message.MessageID),
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
"peer_kind": peerKind,
"peer_id": peerID,
}
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
+166 -23
View File
@@ -45,6 +45,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
type Config struct {
Agents AgentsConfig `json:"agents"`
Bindings []AgentBinding `json:"bindings,omitempty"`
Session SessionConfig `json:"session,omitempty"`
Channels ChannelsConfig `json:"channels"`
Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"`
@@ -56,16 +58,97 @@ type Config struct {
type AgentsConfig struct {
Defaults AgentDefaults `json:"defaults"`
List []AgentConfig `json:"list,omitempty"`
}
// AgentModelConfig supports both string and structured model config.
// String format: "gpt-4" (just primary, no fallbacks)
// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]}
type AgentModelConfig struct {
Primary string `json:"primary,omitempty"`
Fallbacks []string `json:"fallbacks,omitempty"`
}
func (m *AgentModelConfig) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
m.Primary = s
m.Fallbacks = nil
return nil
}
type raw struct {
Primary string `json:"primary"`
Fallbacks []string `json:"fallbacks"`
}
var r raw
if err := json.Unmarshal(data, &r); err != nil {
return err
}
m.Primary = r.Primary
m.Fallbacks = r.Fallbacks
return nil
}
func (m AgentModelConfig) MarshalJSON() ([]byte, error) {
if len(m.Fallbacks) == 0 && m.Primary != "" {
return json.Marshal(m.Primary)
}
type raw struct {
Primary string `json:"primary,omitempty"`
Fallbacks []string `json:"fallbacks,omitempty"`
}
return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks})
}
type AgentConfig struct {
ID string `json:"id"`
Default bool `json:"default,omitempty"`
Name string `json:"name,omitempty"`
Workspace string `json:"workspace,omitempty"`
Model *AgentModelConfig `json:"model,omitempty"`
Skills []string `json:"skills,omitempty"`
Subagents *SubagentsConfig `json:"subagents,omitempty"`
}
type SubagentsConfig struct {
AllowAgents []string `json:"allow_agents,omitempty"`
Model *AgentModelConfig `json:"model,omitempty"`
}
type PeerMatch struct {
Kind string `json:"kind"`
ID string `json:"id"`
}
type BindingMatch struct {
Channel string `json:"channel"`
AccountID string `json:"account_id,omitempty"`
Peer *PeerMatch `json:"peer,omitempty"`
GuildID string `json:"guild_id,omitempty"`
TeamID string `json:"team_id,omitempty"`
}
type AgentBinding struct {
AgentID string `json:"agent_id"`
Match BindingMatch `json:"match"`
}
type SessionConfig struct {
DMScope string `json:"dm_scope,omitempty"`
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}
type ChannelsConfig struct {
@@ -167,19 +250,19 @@ type DevicesConfig struct {
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI OpenAIProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
}
type ProviderConfig struct {
@@ -190,6 +273,11 @@ type ProviderConfig struct {
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
}
type OpenAIProviderConfig struct {
ProviderConfig
WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"`
}
type GatewayConfig struct {
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
@@ -206,14 +294,32 @@ type DuckDuckGoConfig struct {
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
}
type PerplexityConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
}
type CronToolsConfig struct {
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
}
type ExecConfig struct {
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
}
type ToolsConfig struct {
Web WebToolsConfig `json:"web"`
MCP MCPConfig `json:"mcp"`
Web WebToolsConfig `json:"web"`
Cron CronToolsConfig `json:"cron"`
Exec ExecConfig `json:"exec"`
MCP MCPConfig `json:"mcp"`
}
// MCPServerConfig defines configuration for a single MCP server
@@ -325,7 +431,7 @@ func DefaultConfig() *Config {
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenAI: OpenAIProviderConfig{WebSearch: true},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
@@ -350,6 +456,17 @@ func DefaultConfig() *Config {
Enabled: true,
MaxResults: 5,
},
Perplexity: PerplexityConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
},
Cron: CronToolsConfig{
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
},
Exec: ExecConfig{
EnableDenyPatterns: true,
},
MCP: MCPConfig{
Enabled: false,
@@ -460,6 +577,32 @@ func (c *Config) GetAPIBase() string {
return ""
}
// ModelConfig holds primary model and fallback list.
type ModelConfig struct {
Primary string
Fallbacks []string
}
// GetModelConfig returns the text model configuration with fallbacks.
func (c *Config) GetModelConfig() ModelConfig {
c.mu.RLock()
defer c.mu.RUnlock()
return ModelConfig{
Primary: c.Agents.Defaults.Model,
Fallbacks: c.Agents.Defaults.ModelFallbacks,
}
}
// GetImageModelConfig returns the image model configuration with fallbacks.
func (c *Config) GetImageModelConfig() ModelConfig {
c.mu.RLock()
defer c.mu.RUnlock()
return ModelConfig{
Primary: c.Agents.Defaults.ImageModel,
Fallbacks: c.Agents.Defaults.ImageModelFallbacks,
}
}
func expandHome(path string) string {
if path == "" {
return path
+220 -32
View File
@@ -1,12 +1,193 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"runtime"
"testing"
)
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
var m AgentModelConfig
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
t.Fatalf("unmarshal string: %v", err)
}
if m.Primary != "gpt-4" {
t.Errorf("Primary = %q, want 'gpt-4'", m.Primary)
}
if m.Fallbacks != nil {
t.Errorf("Fallbacks = %v, want nil", m.Fallbacks)
}
}
func TestAgentModelConfig_UnmarshalObject(t *testing.T) {
var m AgentModelConfig
data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}`
if err := json.Unmarshal([]byte(data), &m); err != nil {
t.Fatalf("unmarshal object: %v", err)
}
if m.Primary != "claude-opus" {
t.Errorf("Primary = %q, want 'claude-opus'", m.Primary)
}
if len(m.Fallbacks) != 2 {
t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks))
}
if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" {
t.Errorf("Fallbacks = %v", m.Fallbacks)
}
}
func TestAgentModelConfig_MarshalString(t *testing.T) {
m := AgentModelConfig{Primary: "gpt-4"}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
if string(data) != `"gpt-4"` {
t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data))
}
}
func TestAgentModelConfig_MarshalObject(t *testing.T) {
m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var result map[string]interface{}
json.Unmarshal(data, &result)
if result["primary"] != "claude-opus" {
t.Errorf("primary = %v", result["primary"])
}
}
func TestAgentConfig_FullParse(t *testing.T) {
jsonData := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"max_tool_iterations": 20
},
"list": [
{
"id": "sales",
"default": true,
"name": "Sales Bot",
"model": "gpt-4"
},
{
"id": "support",
"name": "Support Bot",
"model": {
"primary": "claude-opus",
"fallbacks": ["haiku"]
},
"subagents": {
"allow_agents": ["sales"]
}
}
]
},
"bindings": [
{
"agent_id": "support",
"match": {
"channel": "telegram",
"account_id": "*",
"peer": {"kind": "direct", "id": "user123"}
}
}
],
"session": {
"dm_scope": "per-peer",
"identity_links": {
"john": ["telegram:123", "discord:john#1234"]
}
}
}`
cfg := DefaultConfig()
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(cfg.Agents.List) != 2 {
t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List))
}
sales := cfg.Agents.List[0]
if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" {
t.Errorf("sales = %+v", sales)
}
if sales.Model == nil || sales.Model.Primary != "gpt-4" {
t.Errorf("sales.Model = %+v", sales.Model)
}
support := cfg.Agents.List[1]
if support.ID != "support" || support.Name != "Support Bot" {
t.Errorf("support = %+v", support)
}
if support.Model == nil || support.Model.Primary != "claude-opus" {
t.Errorf("support.Model = %+v", support.Model)
}
if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" {
t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks)
}
if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 {
t.Errorf("support.Subagents = %+v", support.Subagents)
}
if len(cfg.Bindings) != 1 {
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
}
binding := cfg.Bindings[0]
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
t.Errorf("binding = %+v", binding)
}
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
}
if cfg.Session.DMScope != "per-peer" {
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
}
if len(cfg.Session.IdentityLinks) != 1 {
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
}
links := cfg.Session.IdentityLinks["john"]
if len(links) != 2 {
t.Errorf("john links = %v", links)
}
}
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
jsonData := `{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"max_tool_iterations": 20
}
}
}`
cfg := DefaultConfig()
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(cfg.Agents.List) != 0 {
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
}
if len(cfg.Bindings) != 0 {
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
}
}
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
@@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
@@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) {
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
@@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) {
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
@@ -178,7 +328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
@@ -204,3 +353,42 @@ func TestConfig_Complete(t *testing.T) {
t.Error("Heartbeat should be enabled by default")
}
}
func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Providers.OpenAI.WebSearch {
t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true")
}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if !cfg.Providers.OpenAI.WebSearch {
t.Fatal("OpenAI codex web search should remain true when unset in config file")
}
}
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if cfg.Providers.OpenAI.WebSearch {
t.Fatal("OpenAI codex web search should be false when disabled in config file")
}
}
+7 -6
View File
@@ -1,15 +1,16 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// internalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
var internalChannels = map[string]struct{}{
"cli": {},
"system": {},
"subagent": {},
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
_, found := internalChannels[channel]
return found
}
+11 -1
View File
@@ -108,7 +108,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
ProviderConfig: pc,
WebSearch: getBoolOrDefault(pMap, "web_search", true),
}
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
@@ -363,6 +366,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) {
return b, ok
}
func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool {
if v, ok := getBool(data, key); ok {
return v
}
return defaultVal
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
+18
View File
@@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) {
})
}
func TestSupportedProvidersCompatibility(t *testing.T) {
expected := []string{
"anthropic",
"openai",
"openrouter",
"groq",
"zhipu",
"vllm",
"gemini",
}
for _, provider := range expected {
if !supportedProviders[provider] {
t.Fatalf("supportedProviders missing expected key %q", provider)
}
}
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
+248
View File
@@ -0,0 +1,248 @@
package anthropicprovider
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
const defaultBaseURL = "https://api.anthropic.com"
type Provider struct {
client *anthropic.Client
tokenSource func() (string, error)
baseURL string
}
func NewProvider(token string) *Provider {
return NewProviderWithBaseURL(token, "")
}
func NewProviderWithBaseURL(token, apiBase string) *Provider {
baseURL := normalizeBaseURL(apiBase)
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL(baseURL),
)
return &Provider{
client: &client,
baseURL: baseURL,
}
}
func NewProviderWithClient(client *anthropic.Client) *Provider {
return &Provider{
client: client,
baseURL: defaultBaseURL,
}
}
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
}
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
p := NewProviderWithBaseURL(token, apiBase)
p.tokenSource = tokenSource
return p
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseResponse(resp), nil
}
func (p *Provider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func (p *Provider) BaseURL() string {
return p.baseURL
}
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateTools(tools)
}
return params, nil
}
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func normalizeBaseURL(apiBase string) string {
base := strings.TrimSpace(apiBase)
if base == "" {
return defaultBaseURL
}
base = strings.TrimRight(base, "/")
if strings.HasSuffix(base, "/v1") {
base = strings.TrimSuffix(base, "/v1")
}
if base == "" {
return defaultBaseURL
}
return base
}
+265
View File
@@ -0,0 +1,265 @@
package anthropicprovider
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestProvider_GetDefaultModel(t *testing.T) {
p := NewProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
if got := p.BaseURL(); got != "https://api.anthropic.com" {
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
}
}
func TestProvider_ChatUsesTokenSource(t *testing.T) {
var requests int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
atomic.AddInt32(&requests, 1)
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "ok"},
},
"usage": map[string]interface{}{
"input_tokens": 1,
"output_tokens": 1,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
return "refreshed-token", nil
}, server.URL)
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if got := atomic.LoadInt32(&requests); got != 1 {
t.Fatalf("requests = %d, want 1", got)
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}
+28 -170
View File
@@ -2,200 +2,58 @@ package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
delegate *anthropicprovider.Provider
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
return &ClaudeProvider{
delegate: anthropicprovider.NewProvider(token),
}
}
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
}
}
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
}
}
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
return &ClaudeProvider{delegate: delegate}
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
return resp, nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
return p.delegate.GetDefaultModel()
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
cred, err := getCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
+3 -134
View File
@@ -8,140 +8,9 @@ import (
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
@@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
provider := newClaudeProviderWithDelegate(delegate)
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
@@ -0,0 +1,119 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealCodexCLI(t *testing.T) {
path, err := exec.LookPath("codex")
if err != nil {
t.Skip("codex CLI not found in PATH, skipping integration test")
}
t.Logf("Using codex CLI at: %s", path)
p := NewCodexCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage != nil {
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("codex"); err != nil {
t.Skip("codex CLI not found in PATH")
}
p := NewCodexCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
if _, err := exec.LookPath("codex"); err != nil {
t.Skip("codex CLI not found in PATH")
}
// Run codex directly and verify our parser handles real output
cmd := exec.Command("codex", "exec",
"--json",
"--dangerously-bypass-approvals-and-sandbox",
"--skip-git-repo-check",
"--color", "never",
"-C", t.TempDir(),
"-")
cmd.Stdin = strings.NewReader("Say hi")
output, err := cmd.Output()
if err != nil {
// codex may write diagnostic noise to stderr but still produce valid output
if len(output) == 0 {
t.Fatalf("codex CLI failed: %v", err)
}
}
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
// Verify our parser can handle real output
p := NewCodexCliProvider("")
resp, err := p.parseJSONLEvents(string(output))
if err != nil {
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}
+59 -18
View File
@@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2"
const codexDefaultInstructions = "You are Codex, a coding assistant."
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
client *openai.Client
accountID string
tokenSource func() (string, string, error)
enableWebSearch bool
}
const defaultCodexInstructions = "You are Codex, a coding assistant."
@@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
client: &client,
accountID: accountID,
enableWebSearch: true,
}
}
@@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
})
}
params := buildCodexParams(messages, tools, resolvedModel, options)
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
@@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "unsupported model family"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
@@ -217,12 +219,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
name, args, ok := resolveCodexToolCall(tc)
if !ok {
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
"call_id": tc.ID,
})
continue
}
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
Name: name,
Arguments: args,
},
})
}
@@ -260,20 +268,50 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
params.Instructions = openai.Opt(defaultCodexInstructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
if len(tools) > 0 || enableWebSearch {
params.Tools = translateToolsForCodex(tools, enableWebSearch)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
name = tc.Name
if name == "" && tc.Function != nil {
name = tc.Function.Name
}
if name == "" {
return "", "", false
}
if len(tc.Arguments) > 0 {
argsJSON, err := json.Marshal(tc.Arguments)
if err != nil {
return "", "", false
}
return name, string(argsJSON), true
}
if tc.Function != nil && tc.Function.Arguments != "" {
return name, tc.Function.Arguments, true
}
return name, "{}", true
}
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
capHint := len(tools)
if enableWebSearch {
capHint++
}
result := make([]responses.ToolUnionParam, 0, capHint)
for _, t := range tools {
if t.Type != "function" {
continue
}
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
continue
}
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
@@ -284,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
if enableWebSearch {
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
}
return result
}
+172 -5
View File
@@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
"temperature": 0.7,
})
}, true)
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
@@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
if params.Instructions.Or("") != defaultCodexInstructions {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
}
if params.MaxOutputTokens.Valid() {
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
@@ -36,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
@@ -56,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
@@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
}
}
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Read a file"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{
ID: "call_1",
Type: "function",
Function: &FunctionCall{
Name: "read_file",
Arguments: `{"path":"README.md"}`,
},
},
},
},
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
fc := params.Input.OfInputItemList[1].OfFunctionCall
if fc == nil {
t.Fatal("assistant tool call should be converted to function_call input item")
}
if fc.Name != "read_file" {
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
}
if fc.Arguments != `{"path":"README.md"}` {
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
@@ -81,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
@@ -94,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfWebSearch == nil {
t.Fatal("Tool should include built-in web_search")
}
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
}
}
func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
Parameters: map[string]interface{}{
"type": "object",
},
},
},
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "read_file",
Description: "read file",
Parameters: map[string]interface{}{
"type": "object",
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
if len(params.Tools) != 2 {
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
}
if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" {
t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0])
}
if params.Tools[1].OfWebSearch == nil {
t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1])
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
@@ -214,6 +305,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
if _, ok := reqBody["max_output_tokens"]; ok {
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
return
}
toolsAny, ok := reqBody["tools"].([]interface{})
if !ok || len(toolsAny) != 1 {
http.Error(w, "missing default web search tool", http.StatusBadRequest)
return
}
toolObj, ok := toolsAny[0].(map[string]interface{})
if !ok || toolObj["type"] != "web_search" {
http.Error(w, "expected web_search tool", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
@@ -261,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
}
}
func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if _, ok := reqBody["tools"]; ok {
http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 4,
"output_tokens": 3,
"total_tokens": 7,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.enableWebSearch = false
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
@@ -293,6 +456,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
http.Error(w, "temperature is not supported", http.StatusBadRequest)
return
}
if _, ok := reqBody["max_output_tokens"]; ok {
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
+207
View File
@@ -0,0 +1,207 @@
package providers
import (
"math"
"sync"
"time"
)
const (
defaultFailureWindow = 24 * time.Hour
)
// CooldownTracker manages per-provider cooldown state for the fallback chain.
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
type CooldownTracker struct {
mu sync.RWMutex
entries map[string]*cooldownEntry
failureWindow time.Duration
nowFunc func() time.Time // for testing
}
type cooldownEntry struct {
ErrorCount int
FailureCounts map[FailoverReason]int
CooldownEnd time.Time // standard cooldown expiry
DisabledUntil time.Time // billing-specific disable expiry
DisabledReason FailoverReason // reason for disable (billing)
LastFailure time.Time
}
// NewCooldownTracker creates a tracker with default 24h failure window.
func NewCooldownTracker() *CooldownTracker {
return &CooldownTracker{
entries: make(map[string]*cooldownEntry),
failureWindow: defaultFailureWindow,
nowFunc: time.Now,
}
}
// MarkFailure records a failure for a provider and sets appropriate cooldown.
// Resets error counts if last failure was more than failureWindow ago.
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
ct.mu.Lock()
defer ct.mu.Unlock()
now := ct.nowFunc()
entry := ct.getOrCreate(provider)
// 24h failure window reset: if no failure in failureWindow, reset counters.
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
}
entry.ErrorCount++
entry.FailureCounts[reason]++
entry.LastFailure = now
if reason == FailoverBilling {
billingCount := entry.FailureCounts[FailoverBilling]
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
entry.DisabledReason = FailoverBilling
} else {
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
}
}
// MarkSuccess resets all counters and cooldowns for a provider.
func (ct *CooldownTracker) MarkSuccess(provider string) {
ct.mu.Lock()
defer ct.mu.Unlock()
entry := ct.entries[provider]
if entry == nil {
return
}
entry.ErrorCount = 0
entry.FailureCounts = make(map[FailoverReason]int)
entry.CooldownEnd = time.Time{}
entry.DisabledUntil = time.Time{}
entry.DisabledReason = ""
}
// IsAvailable returns true if the provider is not in cooldown or disabled.
func (ct *CooldownTracker) IsAvailable(provider string) bool {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return true
}
now := ct.nowFunc()
// Billing disable takes precedence (longer cooldown).
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
return false
}
// Standard cooldown.
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
return false
}
return true
}
// CooldownRemaining returns how long until the provider becomes available.
// Returns 0 if already available.
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
now := ct.nowFunc()
var remaining time.Duration
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
d := entry.DisabledUntil.Sub(now)
if d > remaining {
remaining = d
}
}
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
d := entry.CooldownEnd.Sub(now)
if d > remaining {
remaining = d
}
}
return remaining
}
// ErrorCount returns the current error count for a provider.
func (ct *CooldownTracker) ErrorCount(provider string) int {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.ErrorCount
}
// FailureCount returns the failure count for a specific reason.
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
ct.mu.RLock()
defer ct.mu.RUnlock()
entry := ct.entries[provider]
if entry == nil {
return 0
}
return entry.FailureCounts[reason]
}
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
entry := ct.entries[provider]
if entry == nil {
entry = &cooldownEntry{
FailureCounts: make(map[FailoverReason]int),
}
ct.entries[provider] = entry
}
return entry
}
// calculateStandardCooldown computes standard exponential backoff.
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
//
// 1 error → 1 min
// 2 errors → 5 min
// 3 errors → 25 min
// 4+ errors → 1 hour (cap)
func calculateStandardCooldown(errorCount int) time.Duration {
n := max(1, errorCount)
exp := min(n-1, 3)
ms := 60_000 * int(math.Pow(5, float64(exp)))
ms = min(3_600_000, ms) // cap at 1 hour
return time.Duration(ms) * time.Millisecond
}
// calculateBillingCooldown computes billing-specific exponential backoff.
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
//
// 1 error → 5 hours
// 2 errors → 10 hours
// 3 errors → 20 hours
// 4+ errors → 24 hours (cap)
func calculateBillingCooldown(billingErrorCount int) time.Duration {
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
const maxMs = 24 * 60 * 60 * 1000 // 24 hours
n := max(1, billingErrorCount)
exp := min(n-1, 10)
raw := float64(baseMs) * math.Pow(2, float64(exp))
ms := int(math.Min(float64(maxMs), raw))
return time.Duration(ms) * time.Millisecond
}
+269
View File
@@ -0,0 +1,269 @@
package providers
import (
"sync"
"testing"
"time"
)
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
current := now
ct := NewCooldownTracker()
ct.nowFunc = func() time.Time { return current }
return ct, &current
}
func TestCooldown_InitiallyAvailable(t *testing.T) {
ct := NewCooldownTracker()
if !ct.IsAvailable("openai") {
t.Error("new provider should be available")
}
if ct.ErrorCount("openai") != 0 {
t.Error("new provider should have 0 errors")
}
}
func TestCooldown_StandardEscalation(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 1st error → 1 min cooldown
ct.MarkFailure("openai", FailoverRateLimit)
if ct.IsAvailable("openai") {
t.Error("should be in cooldown after 1st error")
}
// Advance 61 seconds → available
*current = now.Add(61 * time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after 1 min cooldown")
}
// 2nd error → 5 min cooldown
ct.MarkFailure("openai", FailoverRateLimit)
*current = now.Add(61*time.Second + 4*time.Minute)
if ct.IsAvailable("openai") {
t.Error("should be in cooldown (5 min) after 2nd error")
}
*current = now.Add(61*time.Second + 6*time.Minute)
if !ct.IsAvailable("openai") {
t.Error("should be available after 5 min cooldown")
}
}
func TestCooldown_StandardCap(t *testing.T) {
// Verify formula: 1m, 5m, 25m, 1h, 1h, 1h...
expected := []time.Duration{
1 * time.Minute,
5 * time.Minute,
25 * time.Minute,
1 * time.Hour,
1 * time.Hour,
}
for i, want := range expected {
got := calculateStandardCooldown(i + 1)
if got != want {
t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want)
}
}
}
func TestCooldown_BillingEscalation(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 1st billing error → 5h cooldown
ct.MarkFailure("openai", FailoverBilling)
if ct.IsAvailable("openai") {
t.Error("should be disabled after billing error")
}
// Advance 4h → still disabled
*current = now.Add(4 * time.Hour)
if ct.IsAvailable("openai") {
t.Error("should still be disabled (5h cooldown)")
}
// Advance 5h + 1s → available
*current = now.Add(5*time.Hour + 1*time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after 5h billing cooldown")
}
}
func TestCooldown_BillingCap(t *testing.T) {
expected := []time.Duration{
5 * time.Hour,
10 * time.Hour,
20 * time.Hour,
24 * time.Hour,
24 * time.Hour,
}
for i, want := range expected {
got := calculateBillingCooldown(i + 1)
if got != want {
t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want)
}
}
}
func TestCooldown_SuccessReset(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverBilling)
if ct.ErrorCount("openai") != 2 {
t.Errorf("error count = %d, want 2", ct.ErrorCount("openai"))
}
ct.MarkSuccess("openai")
if ct.ErrorCount("openai") != 0 {
t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai"))
}
if !ct.IsAvailable("openai") {
t.Error("should be available after success")
}
if ct.FailureCount("openai", FailoverRateLimit) != 0 {
t.Error("failure counts should be reset after success")
}
if ct.FailureCount("openai", FailoverBilling) != 0 {
t.Error("billing failure count should be reset after success")
}
}
func TestCooldown_FailureWindowReset(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// 4 errors → 1h cooldown
for i := 0; i < 4; i++ {
ct.MarkFailure("openai", FailoverRateLimit)
*current = current.Add(2 * time.Second) // small advance between errors
}
if ct.ErrorCount("openai") != 4 {
t.Errorf("error count = %d, want 4", ct.ErrorCount("openai"))
}
// Advance 25 hours (past 24h failure window)
*current = now.Add(25 * time.Hour)
// Next error should reset counters first, then increment to 1
ct.MarkFailure("openai", FailoverRateLimit)
if ct.ErrorCount("openai") != 1 {
t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai"))
}
}
func TestCooldown_PerReasonTracking(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("openai", FailoverBilling)
ct.MarkFailure("openai", FailoverAuth)
if ct.FailureCount("openai", FailoverRateLimit) != 2 {
t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit))
}
if ct.FailureCount("openai", FailoverBilling) != 1 {
t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling))
}
if ct.FailureCount("openai", FailoverAuth) != 1 {
t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth))
}
if ct.ErrorCount("openai") != 4 {
t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai"))
}
}
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// Standard cooldown (1 min) + billing disable (5h)
ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown
ct.MarkFailure("openai", FailoverBilling) // 5h disable
// After 2 min: standard cooldown expired but billing still active
*current = now.Add(2 * time.Minute)
if ct.IsAvailable("openai") {
t.Error("billing disable should take precedence over standard cooldown")
}
// After 5h + 1s: both expired
*current = now.Add(5*time.Hour + 1*time.Second)
if !ct.IsAvailable("openai") {
t.Error("should be available after all cooldowns expire")
}
}
func TestCooldown_CooldownRemaining(t *testing.T) {
now := time.Now()
ct, current := newTestTracker(now)
// No failures → 0 remaining
if ct.CooldownRemaining("openai") != 0 {
t.Error("expected 0 remaining for new provider")
}
ct.MarkFailure("openai", FailoverRateLimit)
*current = now.Add(30 * time.Second)
remaining := ct.CooldownRemaining("openai")
if remaining <= 0 || remaining > 1*time.Minute {
t.Errorf("remaining = %v, expected ~30s", remaining)
}
}
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
ct := NewCooldownTracker()
// Should not panic
ct.MarkSuccess("nonexistent")
if !ct.IsAvailable("nonexistent") {
t.Error("nonexistent provider should be available")
}
}
func TestCooldown_ConcurrentAccess(t *testing.T) {
ct := NewCooldownTracker()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(3)
go func() {
defer wg.Done()
ct.MarkFailure("openai", FailoverRateLimit)
}()
go func() {
defer wg.Done()
ct.IsAvailable("openai")
}()
go func() {
defer wg.Done()
ct.MarkSuccess("openai")
}()
}
wg.Wait()
// If we got here without panic, concurrent access is safe
}
func TestCooldown_MultipleProviders(t *testing.T) {
ct := NewCooldownTracker()
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("anthropic", FailoverBilling)
if ct.IsAvailable("openai") {
t.Error("openai should be in cooldown")
}
if ct.IsAvailable("anthropic") {
t.Error("anthropic should be in cooldown")
}
// groq was never touched
if !ct.IsAvailable("groq") {
t.Error("groq should be available")
}
}
+253
View File
@@ -0,0 +1,253 @@
package providers
import (
"context"
"regexp"
"strings"
)
// errorPattern defines a single pattern (string or regex) for error classification.
type errorPattern struct {
substring string
regex *regexp.Regexp
}
func substr(s string) errorPattern { return errorPattern{substring: s} }
func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} }
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
var (
rateLimitPatterns = []errorPattern{
rxp(`rate[_ ]limit`),
substr("too many requests"),
substr("429"),
substr("exceeded your current quota"),
rxp(`exceeded.*quota`),
rxp(`resource has been exhausted`),
rxp(`resource.*exhausted`),
substr("resource_exhausted"),
substr("quota exceeded"),
substr("usage limit"),
}
overloadedPatterns = []errorPattern{
rxp(`overloaded_error`),
rxp(`"type"\s*:\s*"overloaded_error"`),
substr("overloaded"),
}
timeoutPatterns = []errorPattern{
substr("timeout"),
substr("timed out"),
substr("deadline exceeded"),
substr("context deadline exceeded"),
}
billingPatterns = []errorPattern{
rxp(`\b402\b`),
substr("payment required"),
substr("insufficient credits"),
substr("credit balance"),
substr("plans & billing"),
substr("insufficient balance"),
}
authPatterns = []errorPattern{
rxp(`invalid[_ ]?api[_ ]?key`),
substr("incorrect api key"),
substr("invalid token"),
substr("authentication"),
substr("re-authenticate"),
substr("oauth token refresh failed"),
substr("unauthorized"),
substr("forbidden"),
substr("access denied"),
substr("expired"),
substr("token has expired"),
rxp(`\b401\b`),
rxp(`\b403\b`),
substr("no credentials found"),
substr("no api key found"),
}
formatPatterns = []errorPattern{
substr("string should match pattern"),
substr("tool_use.id"),
substr("tool_use_id"),
substr("messages.1.content.1.tool_use.id"),
substr("invalid request format"),
}
imageDimensionPatterns = []errorPattern{
rxp(`image dimensions exceed max`),
}
imageSizePatterns = []errorPattern{
rxp(`image exceeds.*mb`),
}
// Transient HTTP status codes that map to timeout (server-side failures).
transientStatusCodes = map[int]bool{
500: true, 502: true, 503: true,
521: true, 522: true, 523: true, 524: true,
529: true,
}
)
// ClassifyError classifies an error into a FailoverError with reason.
// Returns nil if the error is not classifiable (unknown errors should not trigger fallback).
func ClassifyError(err error, provider, model string) *FailoverError {
if err == nil {
return nil
}
// Context cancellation: user abort, never fallback.
if err == context.Canceled {
return nil
}
// Context deadline exceeded: treat as timeout, always fallback.
if err == context.DeadlineExceeded {
return &FailoverError{
Reason: FailoverTimeout,
Provider: provider,
Model: model,
Wrapped: err,
}
}
msg := strings.ToLower(err.Error())
// Image dimension/size errors: non-retriable, non-fallback.
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
return &FailoverError{
Reason: FailoverFormat,
Provider: provider,
Model: model,
Wrapped: err,
}
}
// Try HTTP status code extraction first.
if status := extractHTTPStatus(msg); status > 0 {
if reason := classifyByStatus(status); reason != "" {
return &FailoverError{
Reason: reason,
Provider: provider,
Model: model,
Status: status,
Wrapped: err,
}
}
}
// Message pattern matching (priority order from OpenClaw).
if reason := classifyByMessage(msg); reason != "" {
return &FailoverError{
Reason: reason,
Provider: provider,
Model: model,
Wrapped: err,
}
}
return nil
}
// classifyByStatus maps HTTP status codes to FailoverReason.
func classifyByStatus(status int) FailoverReason {
switch {
case status == 401 || status == 403:
return FailoverAuth
case status == 402:
return FailoverBilling
case status == 408:
return FailoverTimeout
case status == 429:
return FailoverRateLimit
case status == 400:
return FailoverFormat
case transientStatusCodes[status]:
return FailoverTimeout
}
return ""
}
// classifyByMessage matches error messages against patterns.
// Priority order matters (from OpenClaw classifyFailoverReason).
func classifyByMessage(msg string) FailoverReason {
if matchesAny(msg, rateLimitPatterns) {
return FailoverRateLimit
}
if matchesAny(msg, overloadedPatterns) {
return FailoverRateLimit // Overloaded treated as rate_limit
}
if matchesAny(msg, billingPatterns) {
return FailoverBilling
}
if matchesAny(msg, timeoutPatterns) {
return FailoverTimeout
}
if matchesAny(msg, authPatterns) {
return FailoverAuth
}
if matchesAny(msg, formatPatterns) {
return FailoverFormat
}
return ""
}
// extractHTTPStatus extracts an HTTP status code from an error message.
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
func extractHTTPStatus(msg string) int {
// Common patterns in Go HTTP error messages
patterns := []*regexp.Regexp{
regexp.MustCompile(`status[:\s]+(\d{3})`),
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
}
for _, p := range patterns {
if m := p.FindStringSubmatch(msg); len(m) > 1 {
return parseDigits(m[1])
}
}
return 0
}
// IsImageDimensionError returns true if the message indicates an image dimension error.
func IsImageDimensionError(msg string) bool {
return matchesAny(msg, imageDimensionPatterns)
}
// IsImageSizeError returns true if the message indicates an image file size error.
func IsImageSizeError(msg string) bool {
return matchesAny(msg, imageSizePatterns)
}
// matchesAny checks if msg matches any of the patterns.
func matchesAny(msg string, patterns []errorPattern) bool {
for _, p := range patterns {
if p.regex != nil {
if p.regex.MatchString(msg) {
return true
}
} else if p.substring != "" {
if strings.Contains(msg, p.substring) {
return true
}
}
}
return false
}
// parseDigits converts a string of digits to an int.
func parseDigits(s string) int {
n := 0
for _, c := range s {
if c >= '0' && c <= '9' {
n = n*10 + int(c-'0')
}
}
return n
}
+337
View File
@@ -0,0 +1,337 @@
package providers
import (
"context"
"errors"
"fmt"
"testing"
)
func TestClassifyError_Nil(t *testing.T) {
result := ClassifyError(nil, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for nil error, got %+v", result)
}
}
func TestClassifyError_ContextCanceled(t *testing.T) {
result := ClassifyError(context.Canceled, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for context.Canceled (user abort), got %+v", result)
}
}
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
if result == nil {
t.Fatal("expected non-nil for deadline exceeded")
}
if result.Reason != FailoverTimeout {
t.Errorf("reason = %q, want timeout", result.Reason)
}
}
func TestClassifyError_StatusCodes(t *testing.T) {
tests := []struct {
status int
reason FailoverReason
}{
{401, FailoverAuth},
{403, FailoverAuth},
{402, FailoverBilling},
{408, FailoverTimeout},
{429, FailoverRateLimit},
{400, FailoverFormat},
{500, FailoverTimeout},
{502, FailoverTimeout},
{503, FailoverTimeout},
{521, FailoverTimeout},
{522, FailoverTimeout},
{523, FailoverTimeout},
{524, FailoverTimeout},
{529, FailoverTimeout},
}
for _, tt := range tests {
err := fmt.Errorf("API error: status: %d something went wrong", tt.status)
result := ClassifyError(err, "test", "model")
if result == nil {
t.Errorf("status %d: expected non-nil", tt.status)
continue
}
if result.Reason != tt.reason {
t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason)
}
}
}
func TestClassifyError_RateLimitPatterns(t *testing.T) {
patterns := []string{
"rate limit exceeded",
"rate_limit reached",
"too many requests",
"exceeded your current quota",
"resource has been exhausted",
"resource_exhausted",
"quota exceeded",
"usage limit reached",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverRateLimit {
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
}
}
}
func TestClassifyError_OverloadedPatterns(t *testing.T) {
patterns := []string{
"overloaded_error",
`{"type": "overloaded_error"}`,
"server is overloaded",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "anthropic", "claude")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
// Overloaded is treated as rate_limit
if result.Reason != FailoverRateLimit {
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
}
}
}
func TestClassifyError_BillingPatterns(t *testing.T) {
patterns := []string{
"payment required",
"insufficient credits",
"credit balance too low",
"plans & billing page",
"insufficient balance",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverBilling {
t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason)
}
}
}
func TestClassifyError_TimeoutPatterns(t *testing.T) {
patterns := []string{
"request timeout",
"connection timed out",
"deadline exceeded",
"context deadline exceeded",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverTimeout {
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
}
}
}
func TestClassifyError_AuthPatterns(t *testing.T) {
patterns := []string{
"invalid api key",
"invalid_api_key",
"incorrect api key",
"invalid token",
"authentication failed",
"re-authenticate",
"oauth token refresh failed",
"unauthorized access",
"forbidden",
"access denied",
"expired",
"token has expired",
"no credentials found",
"no api key found",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "openai", "gpt-4")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverAuth {
t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason)
}
}
}
func TestClassifyError_FormatPatterns(t *testing.T) {
patterns := []string{
"string should match pattern",
"tool_use.id is required",
"invalid tool_use_id",
"messages.1.content.1.tool_use.id must be valid",
"invalid request format",
}
for _, msg := range patterns {
err := errors.New(msg)
result := ClassifyError(err, "anthropic", "claude")
if result == nil {
t.Errorf("pattern %q: expected non-nil", msg)
continue
}
if result.Reason != FailoverFormat {
t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason)
}
}
}
func TestClassifyError_ImageDimensionError(t *testing.T) {
err := errors.New("image dimensions exceed max allowed 2048x2048")
result := ClassifyError(err, "openai", "gpt-4o")
if result == nil {
t.Fatal("expected non-nil for image dimension error")
}
if result.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", result.Reason)
}
if result.IsRetriable() {
t.Error("image dimension error should not be retriable")
}
}
func TestClassifyError_ImageSizeError(t *testing.T) {
err := errors.New("image exceeds 20 mb limit")
result := ClassifyError(err, "openai", "gpt-4o")
if result == nil {
t.Fatal("expected non-nil for image size error")
}
if result.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", result.Reason)
}
}
func TestClassifyError_UnknownError(t *testing.T) {
err := errors.New("some completely random error")
result := ClassifyError(err, "openai", "gpt-4")
if result != nil {
t.Errorf("expected nil for unknown error, got %+v", result)
}
}
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
err := errors.New("rate limit exceeded")
result := ClassifyError(err, "my-provider", "my-model")
if result == nil {
t.Fatal("expected non-nil")
}
if result.Provider != "my-provider" {
t.Errorf("provider = %q, want my-provider", result.Provider)
}
if result.Model != "my-model" {
t.Errorf("model = %q, want my-model", result.Model)
}
}
func TestFailoverError_IsRetriable(t *testing.T) {
tests := []struct {
reason FailoverReason
retriable bool
}{
{FailoverAuth, true},
{FailoverRateLimit, true},
{FailoverBilling, true},
{FailoverTimeout, true},
{FailoverOverloaded, true},
{FailoverFormat, false},
{FailoverUnknown, true},
}
for _, tt := range tests {
fe := &FailoverError{Reason: tt.reason}
if fe.IsRetriable() != tt.retriable {
t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable)
}
}
}
func TestFailoverError_ErrorString(t *testing.T) {
fe := &FailoverError{
Reason: FailoverRateLimit,
Provider: "openai",
Model: "gpt-4",
Status: 429,
Wrapped: errors.New("too many requests"),
}
s := fe.Error()
if s == "" {
t.Error("expected non-empty error string")
}
}
func TestFailoverError_Unwrap(t *testing.T) {
inner := errors.New("inner error")
fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner}
if fe.Unwrap() != inner {
t.Error("Unwrap should return wrapped error")
}
}
func TestExtractHTTPStatus(t *testing.T) {
tests := []struct {
msg string
want int
}{
{"status: 429 rate limited", 429},
{"status 401 unauthorized", 401},
{"HTTP/1.1 502 Bad Gateway", 502},
{"no status code here", 0},
{"random number 12345", 0},
}
for _, tt := range tests {
got := extractHTTPStatus(tt.msg)
if got != tt.want {
t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want)
}
}
}
func TestIsImageDimensionError(t *testing.T) {
if !IsImageDimensionError("image dimensions exceed max 4096x4096") {
t.Error("should match image dimensions exceed max")
}
if IsImageDimensionError("normal error message") {
t.Error("should not match normal error")
}
}
func TestIsImageSizeError(t *testing.T) {
if !IsImageSizeError("image exceeds 20 mb") {
t.Error("should match image exceeds mb")
}
if IsImageSizeError("normal error message") {
t.Error("should not match normal error")
}
}
+360
View File
@@ -0,0 +1,360 @@
package providers
import (
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
var getCredential = auth.GetCredential
type providerType int
const (
providerTypeHTTPCompat providerType = iota
providerTypeClaudeAuth
providerTypeCodexAuth
providerTypeCodexCLIToken
providerTypeClaudeCLI
providerTypeCodexCLI
providerTypeGitHubCopilot
)
type providerSelection struct {
providerType providerType
apiKey string
apiBase string
proxy string
model string
workspace string
connectMode string
enableWebSearch bool
}
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
if apiBase == "" {
apiBase = defaultAnthropicAPIBase
}
cred, err := getCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
}
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
cred, err := getCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
p.enableWebSearch = enableWebSearch
return p, nil
}
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
lowerModel := strings.ToLower(model)
sel := providerSelection{
providerType: providerTypeHTTPCompat,
model: model,
}
// First, prefer explicit provider configuration.
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "nvidia":
if cfg.Providers.Nvidia.APIKey != "" {
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
}
case "claude-cli", "claude-code", "claudecode":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeClaudeCLI
sel.workspace = workspace
return sel, nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
sel.providerType = providerTypeCodexCLI
sel.workspace = workspace
return sel, nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
sel.apiKey = cfg.Providers.DeepSeek.APIKey
sel.apiBase = cfg.Providers.DeepSeek.APIBase
sel.proxy = cfg.Providers.DeepSeek.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
sel.model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
sel.apiBase = "localhost:4321"
}
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
return sel, nil
}
}
// Fallback: infer provider from model and configured keys.
if sel.apiKey == "" && sel.apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
sel.apiKey = cfg.Providers.Moonshot.APIKey
sel.apiBase = cfg.Providers.Moonshot.APIBase
sel.proxy = cfg.Providers.Moonshot.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") ||
strings.HasPrefix(model, "anthropic/") ||
strings.HasPrefix(model, "openai/") ||
strings.HasPrefix(model, "meta-llama/") ||
strings.HasPrefix(model, "deepseek/") ||
strings.HasPrefix(model, "google/"):
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
sel.apiBase = cfg.Providers.Anthropic.APIBase
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
sel.providerType = providerTypeClaudeAuth
return sel, nil
}
sel.apiKey = cfg.Providers.Anthropic.APIKey
sel.apiBase = cfg.Providers.Anthropic.APIBase
sel.proxy = cfg.Providers.Anthropic.Proxy
if sel.apiBase == "" {
sel.apiBase = defaultAnthropicAPIBase
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
sel.providerType = providerTypeCodexCLIToken
return sel, nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
sel.providerType = providerTypeCodexAuth
return sel, nil
}
sel.apiKey = cfg.Providers.OpenAI.APIKey
sel.apiBase = cfg.Providers.OpenAI.APIBase
sel.proxy = cfg.Providers.OpenAI.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
sel.apiKey = cfg.Providers.Gemini.APIKey
sel.apiBase = cfg.Providers.Gemini.APIBase
sel.proxy = cfg.Providers.Gemini.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
sel.apiKey = cfg.Providers.Zhipu.APIKey
sel.apiBase = cfg.Providers.Zhipu.APIBase
sel.proxy = cfg.Providers.Zhipu.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
sel.apiKey = cfg.Providers.Groq.APIKey
sel.apiBase = cfg.Providers.Groq.APIBase
sel.proxy = cfg.Providers.Groq.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
sel.apiKey = cfg.Providers.Nvidia.APIKey
sel.apiBase = cfg.Providers.Nvidia.APIBase
sel.proxy = cfg.Providers.Nvidia.Proxy
if sel.apiBase == "" {
sel.apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
sel.apiKey = cfg.Providers.Ollama.APIKey
sel.apiBase = cfg.Providers.Ollama.APIBase
sel.proxy = cfg.Providers.Ollama.Proxy
if sel.apiBase == "" {
sel.apiBase = "http://localhost:11434/v1"
}
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
sel.proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
sel.apiKey = cfg.Providers.OpenRouter.APIKey
sel.proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
sel.apiBase = cfg.Providers.OpenRouter.APIBase
} else {
sel.apiBase = "https://openrouter.ai/api/v1"
}
} else {
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if sel.providerType == providerTypeHTTPCompat {
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if sel.apiBase == "" {
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
}
return sel, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
sel, err := resolveProviderSelection(cfg)
if err != nil {
return nil, err
}
switch sel.providerType {
case providerTypeClaudeAuth:
return createClaudeAuthProvider(sel.apiBase)
case providerTypeCodexAuth:
return createCodexAuthProvider(sel.enableWebSearch)
case providerTypeCodexCLIToken:
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
c.enableWebSearch = sel.enableWebSearch
return c, nil
case providerTypeClaudeCLI:
return NewClaudeCliProvider(sel.workspace), nil
case providerTypeCodexCLI:
return NewCodexCliProvider(sel.workspace), nil
case providerTypeGitHubCopilot:
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
default:
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
}
}
+299
View File
@@ -0,0 +1,299 @@
package providers
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestResolveProviderSelection(t *testing.T) {
tests := []struct {
name string
setup func(*config.Config)
wantType providerType
wantAPIBase string
wantProxy string
wantErrSubstr string
}{
{
name: "explicit claude-cli provider routes to cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeClaudeCLI,
},
{
name: "explicit copilot provider routes to github copilot type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "copilot"
},
wantType: providerTypeGitHubCopilot,
wantAPIBase: "localhost:4321",
},
{
name: "explicit deepseek provider uses deepseek defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "deepseek"
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.deepseek.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit shengsuanyun provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "shengsuanyun"
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "explicit nvidia provider uses defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "nvidia"
cfg.Providers.Nvidia.APIKey = "nvapi-test"
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://integrate.api.nvidia.com/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "openrouter model uses openrouter defaults",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://openrouter.ai/api/v1",
},
{
name: "anthropic oauth routes to claude auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
cfg.Providers.Anthropic.AuthMethod = "oauth"
},
wantType: providerTypeClaudeAuth,
},
{
name: "openai oauth routes to codex auth provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "oauth"
},
wantType: providerTypeCodexAuth,
},
{
name: "openai codex-cli auth routes to codex cli token provider",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "gpt-4o"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
},
wantType: providerTypeCodexCLIToken,
},
{
name: "explicit codex-code provider routes to codex cli provider type",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Provider = "codex-code"
cfg.Agents.Defaults.Workspace = "/tmp/ws"
},
wantType: providerTypeCodexCLI,
},
{
name: "zhipu model uses zhipu base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "glm-4.7"
cfg.Providers.Zhipu.APIKey = "zhipu-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
},
{
name: "groq model uses groq base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
cfg.Providers.Groq.APIKey = "gsk-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.groq.com/openai/v1",
},
{
name: "ollama model uses ollama base default",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
cfg.Providers.Ollama.APIKey = "ollama-key"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "http://localhost:11434/v1",
},
{
name: "moonshot model keeps proxy and default base",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
cfg.Providers.Moonshot.APIKey = "moonshot-key"
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
},
wantType: providerTypeHTTPCompat,
wantAPIBase: "https://api.moonshot.cn/v1",
wantProxy: "http://127.0.0.1:7890",
},
{
name: "missing keys returns model config error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "custom-model"
},
wantErrSubstr: "no API key configured for model",
},
{
name: "openrouter prefix without key returns provider key error",
setup: func(cfg *config.Config) {
cfg.Agents.Defaults.Model = "openrouter/auto"
},
wantErrSubstr: "no API key configured for provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := config.DefaultConfig()
tt.setup(cfg)
got, err := resolveProviderSelection(cfg)
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
}
return
}
if err != nil {
t.Fatalf("resolveProviderSelection() error = %v", err)
}
if got.providerType != tt.wantType {
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
}
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
}
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
}
})
}
}
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Model = "openrouter/auto"
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*HTTPProvider); !ok {
t.Fatalf("provider type = %T, want *HTTPProvider", provider)
}
}
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "codex-code"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexCliProvider); !ok {
t.Fatalf("provider type = %T, want *CodexCliProvider", provider)
}
}
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "anthropic" {
t.Fatalf("provider = %q, want anthropic", provider)
}
return &auth.AuthCredential{
AccessToken: "anthropic-token",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "anthropic"
cfg.Providers.Anthropic.AuthMethod = "oauth"
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
claudeProvider, ok := provider.(*ClaudeProvider)
if !ok {
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
}
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
}
}
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
originalGetCredential := getCredential
t.Cleanup(func() { getCredential = originalGetCredential })
getCredential = func(provider string) (*auth.AuthCredential, error) {
if provider != "openai" {
t.Fatalf("provider = %q, want openai", provider)
}
return &auth.AuthCredential{
AccessToken: "openai-token",
AccountID: "acct_123",
}, nil
}
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "openai"
cfg.Providers.OpenAI.AuthMethod = "oauth"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
if _, ok := provider.(*CodexProvider); !ok {
t.Fatalf("provider type = %T, want *CodexProvider", provider)
}
}
+283
View File
@@ -0,0 +1,283 @@
package providers
import (
"context"
"fmt"
"strings"
"time"
)
// FallbackChain orchestrates model fallback across multiple candidates.
type FallbackChain struct {
cooldown *CooldownTracker
}
// FallbackCandidate represents one model/provider to try.
type FallbackCandidate struct {
Provider string
Model string
}
// FallbackResult contains the successful response and metadata about all attempts.
type FallbackResult struct {
Response *LLMResponse
Provider string
Model string
Attempts []FallbackAttempt
}
// FallbackAttempt records one attempt in the fallback chain.
type FallbackAttempt struct {
Provider string
Model string
Error error
Reason FailoverReason
Duration time.Duration
Skipped bool // true if skipped due to cooldown
}
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
return &FallbackChain{cooldown: cooldown}
}
// ResolveCandidates parses model config into a deduplicated candidate list.
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
seen := make(map[string]bool)
var candidates []FallbackCandidate
addCandidate := func(raw string) {
ref := ParseModelRef(raw, defaultProvider)
if ref == nil {
return
}
key := ModelKey(ref.Provider, ref.Model)
if seen[key] {
return
}
seen[key] = true
candidates = append(candidates, FallbackCandidate{
Provider: ref.Provider,
Model: ref.Model,
})
}
// Primary first.
addCandidate(cfg.Primary)
// Then fallbacks.
for _, fb := range cfg.Fallbacks {
addCandidate(fb)
}
return candidates
}
// Execute runs the fallback chain for text/chat requests.
// It tries each candidate in order, respecting cooldowns and error classification.
//
// Behavior:
// - Candidates in cooldown are skipped (logged as skipped attempt).
// - context.Canceled aborts immediately (user abort, no fallback).
// - Non-retriable errors (format) abort immediately.
// - Retriable errors trigger fallback to next candidate.
// - Success marks provider as good (resets cooldown).
// - If all fail, returns aggregate error with all attempts.
func (fc *FallbackChain) Execute(
ctx context.Context,
candidates []FallbackCandidate,
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
if len(candidates) == 0 {
return nil, fmt.Errorf("fallback: no candidates configured")
}
result := &FallbackResult{
Attempts: make([]FallbackAttempt, 0, len(candidates)),
}
for i, candidate := range candidates {
// Check context before each attempt.
if ctx.Err() == context.Canceled {
return nil, context.Canceled
}
// Check cooldown.
if !fc.cooldown.IsAvailable(candidate.Provider) {
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Skipped: true,
Reason: FailoverRateLimit,
Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
})
continue
}
// Execute the run function.
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
elapsed := time.Since(start)
if err == nil {
// Success.
fc.cooldown.MarkSuccess(candidate.Provider)
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
return result, nil
}
// Context cancellation: abort immediately, no fallback.
if ctx.Err() == context.Canceled {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, context.Canceled
}
// Classify the error.
failErr := ClassifyError(err, candidate.Provider, candidate.Model)
if failErr == nil {
// Unclassifiable error: do not fallback, return immediately.
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
candidate.Provider, candidate.Model, err)
}
// Non-retriable error: abort immediately.
if !failErr.IsRetriable() {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: failErr,
Reason: failErr.Reason,
Duration: elapsed,
})
return nil, failErr
}
// Retriable error: mark failure and continue to next candidate.
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: failErr,
Reason: failErr.Reason,
Duration: elapsed,
})
// If this was the last candidate, return aggregate error.
if i == len(candidates)-1 {
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
}
// All candidates were skipped (all in cooldown).
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
// ExecuteImage runs the fallback chain for image/vision requests.
// Simpler than Execute: no cooldown checks (image endpoints have different rate limits).
// Image dimension/size errors abort immediately (non-retriable).
func (fc *FallbackChain) ExecuteImage(
ctx context.Context,
candidates []FallbackCandidate,
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
) (*FallbackResult, error) {
if len(candidates) == 0 {
return nil, fmt.Errorf("image fallback: no candidates configured")
}
result := &FallbackResult{
Attempts: make([]FallbackAttempt, 0, len(candidates)),
}
for i, candidate := range candidates {
if ctx.Err() == context.Canceled {
return nil, context.Canceled
}
start := time.Now()
resp, err := run(ctx, candidate.Provider, candidate.Model)
elapsed := time.Since(start)
if err == nil {
result.Response = resp
result.Provider = candidate.Provider
result.Model = candidate.Model
return result, nil
}
if ctx.Err() == context.Canceled {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
return nil, context.Canceled
}
// Image dimension/size errors are non-retriable.
errMsg := strings.ToLower(err.Error())
if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) {
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Reason: FailoverFormat,
Duration: elapsed,
})
return nil, &FailoverError{
Reason: FailoverFormat,
Provider: candidate.Provider,
Model: candidate.Model,
Wrapped: err,
}
}
// Any other error: record and try next.
result.Attempts = append(result.Attempts, FallbackAttempt{
Provider: candidate.Provider,
Model: candidate.Model,
Error: err,
Duration: elapsed,
})
if i == len(candidates)-1 {
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
}
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
}
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
type FallbackExhaustedError struct {
Attempts []FallbackAttempt
}
func (e *FallbackExhaustedError) Error() string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts)))
for i, a := range e.Attempts {
if a.Skipped {
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model))
} else {
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)",
i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond)))
}
}
return sb.String()
}
+473
View File
@@ -0,0 +1,473 @@
package providers
import (
"context"
"errors"
"testing"
"time"
)
func makeCandidate(provider, model string) FallbackCandidate {
return FallbackCandidate{Provider: provider, Model: model}
}
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return &LLMResponse{Content: content, FinishReason: "stop"}, nil
}
}
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return nil, err
}
}
func TestFallback_SingleCandidate_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "hello" {
t.Errorf("content = %q, want hello", result.Response.Content)
}
if result.Provider != "openai" || result.Model != "gpt-4" {
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
}
}
func TestFallback_SecondCandidateSuccess(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude-opus"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
if result.Response.Content != "from claude" {
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
}
if len(result.Attempts) != 1 {
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
}
}
func TestFallback_AllFail(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
makeCandidate("groq", "llama"),
}
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
return nil, errors.New("rate limit exceeded")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error when all candidates fail")
}
var exhausted *FallbackExhaustedError
if !errors.As(err, &exhausted) {
t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err)
}
if len(exhausted.Attempts) != 3 {
t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
}
}
func TestFallback_ContextCanceled(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
ctx, cancel := context.WithCancel(context.Background())
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
cancel() // cancel context
return nil, context.Canceled
}
t.Error("should not reach second candidate after cancel")
return nil, nil
}
_, err := fc.Execute(ctx, candidates, run)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %v", err)
}
}
func TestFallback_NonRetriableError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("string should match pattern")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for non-retriable")
}
var fe *FailoverError
if !errors.As(err, &fe) {
t.Fatalf("expected FailoverError, got %T", err)
}
if fe.Reason != FailoverFormat {
t.Errorf("reason = %q, want format", fe.Reason)
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
}
}
func TestFallback_CooldownSkip(t *testing.T) {
now := time.Now()
ct, _ := newTestTracker(now)
fc := NewFallbackChain(ct)
// Put openai in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
if provider == "openai" {
t.Error("should not call openai (in cooldown)")
}
return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
}
result, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
// Should have 1 skipped attempt
skipped := 0
for _, a := range result.Attempts {
if a.Skipped {
skipped++
}
}
if skipped != 1 {
t.Errorf("skipped = %d, want 1", skipped)
}
}
func TestFallback_AllInCooldown(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
// Put all providers in cooldown
ct.MarkFailure("openai", FailoverRateLimit)
ct.MarkFailure("anthropic", FailoverBilling)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
_, err := fc.Execute(context.Background(), candidates,
func(ctx context.Context, provider, model string) (*LLMResponse, error) {
t.Error("should not call any provider (all in cooldown)")
return nil, nil
})
if err == nil {
t.Fatal("expected error when all in cooldown")
}
var exhausted *FallbackExhaustedError
if !errors.As(err, &exhausted) {
t.Fatalf("expected FallbackExhaustedError, got %T", err)
}
}
func TestFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
func TestFallback_EmptyFallbacks(t *testing.T) {
// Single primary, no fallbacks: should work like direct call
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "ok" {
t.Error("expected success with single candidate")
}
}
func TestFallback_UnclassifiedError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("completely unknown internal error")
}
_, err := fc.Execute(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for unclassified error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
}
}
func TestFallback_SuccessResetsCooldown(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
}
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
}
_, err := fc.Execute(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ct.IsAvailable("openai") {
t.Error("success should reset cooldown")
}
}
// --- Image Fallback Tests ---
func TestImageFallback_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Response.Content != "image result" {
t.Error("expected image result")
}
}
func TestImageFallback_DimensionError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image dimensions exceed max 4096x4096")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image dimension error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
}
}
func TestImageFallback_SizeError(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
return nil, errors.New("image exceeds 20 mb")
}
_, err := fc.ExecuteImage(context.Background(), candidates, run)
if err == nil {
t.Fatal("expected error for image size error")
}
if attempt != 1 {
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
}
}
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
candidates := []FallbackCandidate{
makeCandidate("openai", "gpt-4o"),
makeCandidate("anthropic", "claude-sonnet"),
}
attempt := 0
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
attempt++
if attempt == 1 {
return nil, errors.New("rate limit exceeded")
}
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
}
result, err := fc.ExecuteImage(context.Background(), candidates, run)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", result.Provider)
}
}
func TestImageFallback_NoCandidates(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
if err == nil {
t.Error("expected error for empty candidates")
}
}
// --- ResolveCandidates Tests ---
func TestResolveCandidates_Simple(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 3 {
t.Fatalf("candidates = %d, want 3", len(candidates))
}
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
}
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
}
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
}
}
func TestResolveCandidates_Deduplication(t *testing.T) {
cfg := ModelConfig{
Primary: "openai/gpt-4",
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "default")
if len(candidates) != 2 {
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
}
}
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
cfg := ModelConfig{
Primary: "gpt-4",
Fallbacks: nil,
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
cfg := ModelConfig{
Primary: "",
Fallbacks: []string{"anthropic/claude"},
}
candidates := ResolveCandidates(cfg, "openai")
if len(candidates) != 1 {
t.Errorf("candidates = %d, want 1", len(candidates))
}
}
func TestFallbackExhaustedError_Message(t *testing.T) {
e := &FallbackExhaustedError{
Attempts: []FallbackAttempt{
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
{Provider: "anthropic", Model: "claude", Skipped: true},
},
}
msg := e.Error()
if msg == "" {
t.Error("expected non-empty error message")
}
}
+4 -423
View File
@@ -7,444 +7,25 @@
package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
)
type HTTPProvider struct {
apiKey string
apiBase string
httpClient *http.Client
delegate *openai_compat.Provider
}
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
}
return &HTTPProvider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
}
}
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := options["max_tokens"].(int); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := options["temperature"].(float64); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return p.parseResponse(body)
}
func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
// Handle OpenAI format with nested function object
if tc.Type == "function" && tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
}
}
} else if tc.Function != nil {
// Legacy format without type field
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
arguments["raw"] = tc.Function.Arguments
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
return p.delegate.Chat(ctx, messages, tools, model, options)
}
func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
apiBase = cfg.Providers.DeepSeek.APIBase
if apiBase == "" {
apiBase = "https://api.deepseek.com/v1"
}
if model != "deepseek-chat" && model != "deepseek-reasoner" {
model = "deepseek-chat"
}
}
case "github_copilot", "copilot":
if cfg.Providers.GitHubCopilot.APIBase != "" {
apiBase = cfg.Providers.GitHubCopilot.APIBase
} else {
apiBase = "localhost:4321"
}
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
}
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
}
if apiBase == "" {
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase, proxy), nil
}
+64
View File
@@ -0,0 +1,64 @@
package providers
import "strings"
// ModelRef represents a parsed model reference with provider and model name.
type ModelRef struct {
Provider string
Model string
}
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
// If no slash present, uses defaultProvider.
// Returns nil for empty input.
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
if idx := strings.Index(raw, "/"); idx > 0 {
provider := NormalizeProvider(raw[:idx])
model := strings.TrimSpace(raw[idx+1:])
if model == "" {
return nil
}
return &ModelRef{Provider: provider, Model: model}
}
return &ModelRef{
Provider: NormalizeProvider(defaultProvider),
Model: raw,
}
}
// NormalizeProvider normalizes provider identifiers to canonical form.
func NormalizeProvider(provider string) string {
p := strings.ToLower(strings.TrimSpace(provider))
switch p {
case "z.ai", "z-ai":
return "zai"
case "opencode-zen":
return "opencode"
case "qwen":
return "qwen-portal"
case "kimi-code":
return "kimi-coding"
case "gpt":
return "openai"
case "claude":
return "anthropic"
case "glm":
return "zhipu"
case "google":
return "gemini"
}
return p
}
// ModelKey returns a canonical "provider/model" key for deduplication.
func ModelKey(provider, model string) string {
return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
}
+125
View File
@@ -0,0 +1,125 @@
package providers
import "testing"
func TestParseModelRef_WithSlash(t *testing.T) {
ref := ParseModelRef("anthropic/claude-opus", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", ref.Provider)
}
if ref.Model != "claude-opus" {
t.Errorf("model = %q, want claude-opus", ref.Model)
}
}
func TestParseModelRef_WithoutSlash(t *testing.T) {
ref := ParseModelRef("gpt-4", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "openai" {
t.Errorf("provider = %q, want openai", ref.Provider)
}
if ref.Model != "gpt-4" {
t.Errorf("model = %q, want gpt-4", ref.Model)
}
}
func TestParseModelRef_Empty(t *testing.T) {
ref := ParseModelRef("", "openai")
if ref != nil {
t.Errorf("expected nil for empty string, got %+v", ref)
}
}
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
ref := ParseModelRef("openai/", "default")
if ref != nil {
t.Errorf("expected nil for empty model, got %+v", ref)
}
}
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
ref := ParseModelRef(" anthropic / claude-opus ", "openai")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "anthropic" {
t.Errorf("provider = %q, want anthropic", ref.Provider)
}
if ref.Model != "claude-opus" {
t.Errorf("model = %q, want claude-opus", ref.Model)
}
}
func TestNormalizeProvider(t *testing.T) {
tests := []struct {
input string
want string
}{
{"OpenAI", "openai"},
{"ANTHROPIC", "anthropic"},
{"z.ai", "zai"},
{"z-ai", "zai"},
{"Z.AI", "zai"},
{"opencode-zen", "opencode"},
{"qwen", "qwen-portal"},
{"kimi-code", "kimi-coding"},
{"gpt", "openai"},
{"claude", "anthropic"},
{"glm", "zhipu"},
{"google", "gemini"},
{"groq", "groq"},
{"", ""},
}
for _, tt := range tests {
got := NormalizeProvider(tt.input)
if got != tt.want {
t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestModelKey(t *testing.T) {
tests := []struct {
provider string
model string
want string
}{
{"openai", "gpt-4", "openai/gpt-4"},
{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
{"claude", "sonnet", "anthropic/sonnet"},
{"z.ai", "Model-X", "zai/model-x"},
}
for _, tt := range tests {
got := ModelKey(tt.provider, tt.model)
if got != tt.want {
t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want)
}
}
}
func TestParseModelRef_ProviderNormalization(t *testing.T) {
ref := ParseModelRef("Z.AI/model-x", "default")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "zai" {
t.Errorf("provider = %q, want zai", ref.Provider)
}
}
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
ref := ParseModelRef("gpt-4o", "GPT")
if ref == nil {
t.Fatal("expected non-nil ref")
}
if ref.Provider != "openai" {
t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider)
}
}
+232
View File
@@ -0,0 +1,232 @@
package openai_compat
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
}
func NewProvider(apiKey, apiBase, proxy string) *Provider {
client := &http.Client{
Timeout: 120 * time.Second,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
return &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
}
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeModel(model, p.apiBase)
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := asFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return parseResponse(body)
}
func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.Unmarshal(body, &apiResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]interface{})
name := ""
if tc.Function != nil {
name = tc.Function.Name
if tc.Function.Arguments != "" {
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = tc.Function.Arguments
}
}
}
toolCalls = append(toolCalls, ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
})
}
return &LLMResponse{
Content: choice.Message.Content,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
return model
}
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
return model
}
prefix := strings.ToLower(model[:idx])
switch prefix {
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
return model[idx+1:]
default:
return model
}
}
func asInt(v interface{}) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v interface{}) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
@@ -0,0 +1,277 @@
package openai_compat
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, ok := requestBody["max_completion_tokens"]; !ok {
t.Fatalf("expected max_completion_tokens in request body")
}
if _, ok := requestBody["max_tokens"]; ok {
t.Fatalf("did not expect max_tokens key for glm model")
}
}
func TestProviderChat_ParsesToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{
"content": "",
"tool_calls": []map[string]interface{}{
{
"id": "call_1",
"type": "function",
"function": map[string]interface{}{
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
},
},
},
},
"finish_reason": "tool_calls",
},
},
"usage": map[string]interface{}{
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
}
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"moonshot/kimi-k2.5",
map[string]interface{}{"temperature": 0.3},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != "kimi-k2.5" {
t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
}
if requestBody["temperature"] != 1.0 {
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
}
}
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
tests := []struct {
name string
input string
wantModel string
}{
{
name: "strips groq prefix and keeps nested model",
input: "groq/openai/gpt-oss-120b",
wantModel: "openai/gpt-oss-120b",
},
{
name: "strips ollama prefix",
input: "ollama/qwen2.5:14b",
wantModel: "qwen2.5:14b",
},
{
name: "strips deepseek prefix",
input: "deepseek/deepseek-chat",
wantModel: "deepseek-chat",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["model"] != tt.wantModel {
t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
}
})
}
}
func TestProvider_ProxyConfigured(t *testing.T) {
proxyURL := "http://127.0.0.1:8080"
p := NewProvider("key", "https://example.com", proxyURL)
transport, ok := p.httpClient.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function returned error: %v", err)
}
if gotProxy == nil || gotProxy.String() != proxyURL {
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
}
}
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
var requestBody map[string]interface{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"choices": []map[string]interface{}{
{
"message": map[string]interface{}{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if requestBody["max_tokens"] != float64(512) {
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
}
if requestBody["temperature"] != float64(1) {
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
}
}
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
}
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
}
}
+45
View File
@@ -0,0 +1,45 @@
package protocoltypes
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
+51 -39
View File
@@ -1,52 +1,64 @@
package providers
import "context"
import (
"context"
"fmt"
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type,omitempty"`
Function *FunctionCall `json:"function,omitempty"`
Name string `json:"name,omitempty"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type LLMResponse struct {
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
FinishReason string `json:"finish_reason"`
Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
type LLMProvider interface {
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
GetDefaultModel() string
}
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDefinition `json:"function"`
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
const (
FailoverAuth FailoverReason = "auth"
FailoverRateLimit FailoverReason = "rate_limit"
FailoverBilling FailoverReason = "billing"
FailoverTimeout FailoverReason = "timeout"
FailoverFormat FailoverReason = "format"
FailoverOverloaded FailoverReason = "overloaded"
FailoverUnknown FailoverReason = "unknown"
)
// FailoverError wraps an LLM provider error with classification metadata.
type FailoverError struct {
Reason FailoverReason
Provider string
Model string
Status int
Wrapped error
}
type ToolFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
func (e *FailoverError) Error() string {
return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v",
e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
}
func (e *FailoverError) Unwrap() error {
return e.Wrapped
}
// IsRetriable returns true if this error should trigger fallback to next candidate.
// Non-retriable: Format errors (bad request structure, image dimension/size).
func (e *FailoverError) IsRetriable() bool {
return e.Reason != FailoverFormat
}
// ModelConfig holds primary model and fallback list.
type ModelConfig struct {
Primary string
Fallbacks []string
}
+66
View File
@@ -0,0 +1,66 @@
package routing
import (
"regexp"
"strings"
)
const (
DefaultAgentID = "main"
DefaultMainKey = "main"
DefaultAccountID = "default"
MaxAgentIDLength = 64
)
var (
validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`)
invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`)
leadingDashRe = regexp.MustCompile(`^-+`)
trailingDashRe = regexp.MustCompile(`-+$`)
)
// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}.
// Invalid characters are collapsed to "-". Leading/trailing dashes stripped.
// Empty input returns DefaultAgentID ("main").
func NormalizeAgentID(id string) string {
trimmed := strings.TrimSpace(id)
if trimmed == "" {
return DefaultAgentID
}
lower := strings.ToLower(trimmed)
if validIDRe.MatchString(lower) {
return lower
}
result := invalidCharsRe.ReplaceAllString(lower, "-")
result = leadingDashRe.ReplaceAllString(result, "")
result = trailingDashRe.ReplaceAllString(result, "")
if len(result) > MaxAgentIDLength {
result = result[:MaxAgentIDLength]
}
if result == "" {
return DefaultAgentID
}
return result
}
// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID.
func NormalizeAccountID(id string) string {
trimmed := strings.TrimSpace(id)
if trimmed == "" {
return DefaultAccountID
}
lower := strings.ToLower(trimmed)
if validIDRe.MatchString(lower) {
return lower
}
result := invalidCharsRe.ReplaceAllString(lower, "-")
result = leadingDashRe.ReplaceAllString(result, "")
result = trailingDashRe.ReplaceAllString(result, "")
if len(result) > MaxAgentIDLength {
result = result[:MaxAgentIDLength]
}
if result == "" {
return DefaultAccountID
}
return result
}
+86
View File
@@ -0,0 +1,86 @@
package routing
import "testing"
func TestNormalizeAgentID_Empty(t *testing.T) {
if got := NormalizeAgentID(""); got != DefaultAgentID {
t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_Whitespace(t *testing.T) {
if got := NormalizeAgentID(" "); got != DefaultAgentID {
t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_Valid(t *testing.T) {
tests := []struct {
input, want string
}{
{"main", "main"},
{"Main", "main"},
{"SALES", "sales"},
{"support-bot", "support-bot"},
{"agent_1", "agent_1"},
{"a", "a"},
{"0test", "0test"},
}
for _, tt := range tests {
if got := NormalizeAgentID(tt.input); got != tt.want {
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestNormalizeAgentID_InvalidChars(t *testing.T) {
tests := []struct {
input, want string
}{
{"Hello World", "hello-world"},
{"agent@123", "agent-123"},
{"foo.bar.baz", "foo-bar-baz"},
{"--leading", "leading"},
{"--both--", "both"},
}
for _, tt := range tests {
if got := NormalizeAgentID(tt.input); got != tt.want {
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestNormalizeAgentID_AllInvalid(t *testing.T) {
if got := NormalizeAgentID("@@@"); got != DefaultAgentID {
t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID)
}
}
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
long := ""
for i := 0; i < 100; i++ {
long += "a"
}
got := NormalizeAgentID(long)
if len(got) > MaxAgentIDLength {
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
}
}
func TestNormalizeAccountID_Empty(t *testing.T) {
if got := NormalizeAccountID(""); got != DefaultAccountID {
t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID)
}
}
func TestNormalizeAccountID_Valid(t *testing.T) {
if got := NormalizeAccountID("MyBot"); got != "mybot" {
t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got)
}
}
func TestNormalizeAccountID_InvalidChars(t *testing.T) {
if got := NormalizeAccountID("bot@home"); got != "bot-home" {
t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got)
}
}
+252
View File
@@ -0,0 +1,252 @@
package routing
import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
// RouteInput contains the routing context from an inbound message.
type RouteInput struct {
Channel string
AccountID string
Peer *RoutePeer
ParentPeer *RoutePeer
GuildID string
TeamID string
}
// ResolvedRoute is the result of agent routing.
type ResolvedRoute struct {
AgentID string
Channel string
AccountID string
SessionKey string
MainSessionKey string
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
}
// RouteResolver determines which agent handles a message based on config bindings.
type RouteResolver struct {
cfg *config.Config
}
// NewRouteResolver creates a new route resolver.
func NewRouteResolver(cfg *config.Config) *RouteResolver {
return &RouteResolver{cfg: cfg}
}
// ResolveRoute determines which agent handles the message and constructs session keys.
// Implements the 7-level priority cascade:
// peer > parent_peer > guild > team > account > channel_wildcard > default
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
channel := strings.ToLower(strings.TrimSpace(input.Channel))
accountID := NormalizeAccountID(input.AccountID)
peer := input.Peer
dmScope := DMScope(r.cfg.Session.DMScope)
if dmScope == "" {
dmScope = DMScopeMain
}
identityLinks := r.cfg.Session.IdentityLinks
bindings := r.filterBindings(channel, accountID)
choose := func(agentID string, matchedBy string) ResolvedRoute {
resolvedAgentID := r.pickAgentID(agentID)
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: resolvedAgentID,
Channel: channel,
AccountID: accountID,
Peer: peer,
DMScope: dmScope,
IdentityLinks: identityLinks,
}))
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
return ResolvedRoute{
AgentID: resolvedAgentID,
Channel: channel,
AccountID: accountID,
SessionKey: sessionKey,
MainSessionKey: mainSessionKey,
MatchedBy: matchedBy,
}
}
// Priority 1: Peer binding
if peer != nil && strings.TrimSpace(peer.ID) != "" {
if match := r.findPeerMatch(bindings, peer); match != nil {
return choose(match.AgentID, "binding.peer")
}
}
// Priority 2: Parent peer binding
parentPeer := input.ParentPeer
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
return choose(match.AgentID, "binding.peer.parent")
}
}
// Priority 3: Guild binding
guildID := strings.TrimSpace(input.GuildID)
if guildID != "" {
if match := r.findGuildMatch(bindings, guildID); match != nil {
return choose(match.AgentID, "binding.guild")
}
}
// Priority 4: Team binding
teamID := strings.TrimSpace(input.TeamID)
if teamID != "" {
if match := r.findTeamMatch(bindings, teamID); match != nil {
return choose(match.AgentID, "binding.team")
}
}
// Priority 5: Account binding
if match := r.findAccountMatch(bindings); match != nil {
return choose(match.AgentID, "binding.account")
}
// Priority 6: Channel wildcard binding
if match := r.findChannelWildcardMatch(bindings); match != nil {
return choose(match.AgentID, "binding.channel")
}
// Priority 7: Default agent
return choose(r.resolveDefaultAgentID(), "default")
}
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
var filtered []config.AgentBinding
for _, b := range r.cfg.Bindings {
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
if matchChannel == "" || matchChannel != channel {
continue
}
if !matchesAccountID(b.Match.AccountID, accountID) {
continue
}
filtered = append(filtered, b)
}
return filtered
}
func matchesAccountID(matchAccountID, actual string) bool {
trimmed := strings.TrimSpace(matchAccountID)
if trimmed == "" {
return actual == DefaultAccountID
}
if trimmed == "*" {
return true
}
return strings.ToLower(trimmed) == strings.ToLower(actual)
}
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
if b.Match.Peer == nil {
continue
}
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
peerID := strings.TrimSpace(b.Match.Peer.ID)
if peerKind == "" || peerID == "" {
continue
}
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
return b
}
}
return nil
}
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchGuild := strings.TrimSpace(b.Match.GuildID)
if matchGuild != "" && matchGuild == guildID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
matchTeam := strings.TrimSpace(b.Match.TeamID)
if matchTeam != "" && matchTeam == teamID {
return &bindings[i]
}
}
return nil
}
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID == "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
for i := range bindings {
b := &bindings[i]
accountID := strings.TrimSpace(b.Match.AccountID)
if accountID != "*" {
continue
}
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
continue
}
return &bindings[i]
}
return nil
}
func (r *RouteResolver) pickAgentID(agentID string) string {
trimmed := strings.TrimSpace(agentID)
if trimmed == "" {
return NormalizeAgentID(r.resolveDefaultAgentID())
}
normalized := NormalizeAgentID(trimmed)
agents := r.cfg.Agents.List
if len(agents) == 0 {
return normalized
}
for _, a := range agents {
if NormalizeAgentID(a.ID) == normalized {
return normalized
}
}
return NormalizeAgentID(r.resolveDefaultAgentID())
}
func (r *RouteResolver) resolveDefaultAgentID() string {
agents := r.cfg.Agents.List
if len(agents) == 0 {
return DefaultAgentID
}
for _, a := range agents {
if a.Default {
id := strings.TrimSpace(a.ID)
if id != "" {
return NormalizeAgentID(id)
}
}
}
if id := strings.TrimSpace(agents[0].ID); id != "" {
return NormalizeAgentID(id)
}
return DefaultAgentID
}
+297
View File
@@ -0,0 +1,297 @@
package routing
import (
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
return &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: "/tmp/picoclaw-test",
Model: "gpt-4",
},
List: agents,
},
Bindings: bindings,
Session: config.SessionConfig{
DMScope: "per-peer",
},
}
}
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
cfg := testConfig(nil, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != DefaultAgentID {
t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID)
}
if route.MatchedBy != "default" {
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
}
}
func TestResolveRoute_PeerBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "sales", Default: true},
{ID: "support"},
}
bindings := []config.AgentBinding{
{
AgentID: "support",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
})
if route.AgentID != "support" {
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
}
}
func TestResolveRoute_GuildBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-abc",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-abc",
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
})
if route.AgentID != "gaming" {
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
}
if route.MatchedBy != "binding.guild" {
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
}
}
func TestResolveRoute_TeamBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "work"},
}
bindings := []config.AgentBinding{
{
AgentID: "work",
Match: config.BindingMatch{
Channel: "slack",
AccountID: "*",
TeamID: "T12345",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "slack",
TeamID: "T12345",
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
})
if route.AgentID != "work" {
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
}
if route.MatchedBy != "binding.team" {
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
}
}
func TestResolveRoute_AccountBinding(t *testing.T) {
agents := []config.AgentConfig{
{ID: "default-agent", Default: true},
{ID: "premium"},
}
bindings := []config.AgentBinding{
{
AgentID: "premium",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "bot2",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
AccountID: "bot2",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != "premium" {
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
}
if route.MatchedBy != "binding.account" {
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
}
}
func TestResolveRoute_ChannelWildcard(t *testing.T) {
agents := []config.AgentConfig{
{ID: "main", Default: true},
{ID: "telegram-bot"},
}
bindings := []config.AgentBinding{
{
AgentID: "telegram-bot",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
})
if route.AgentID != "telegram-bot" {
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
}
if route.MatchedBy != "binding.channel" {
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
}
}
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
agents := []config.AgentConfig{
{ID: "general", Default: true},
{ID: "vip"},
{ID: "gaming"},
}
bindings := []config.AgentBinding{
{
AgentID: "vip",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
},
},
{
AgentID: "gaming",
Match: config.BindingMatch{
Channel: "discord",
AccountID: "*",
GuildID: "guild-1",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "discord",
GuildID: "guild-1",
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
})
if route.AgentID != "vip" {
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
}
if route.MatchedBy != "binding.peer" {
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
}
}
func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
agents := []config.AgentConfig{
{ID: "main", Default: true},
}
bindings := []config.AgentBinding{
{
AgentID: "nonexistent",
Match: config.BindingMatch{
Channel: "telegram",
AccountID: "*",
},
},
}
cfg := testConfig(agents, bindings)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "telegram",
})
if route.AgentID != "main" {
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
}
}
func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
agents := []config.AgentConfig{
{ID: "alpha"},
{ID: "beta", Default: true},
{ID: "gamma"},
}
cfg := testConfig(agents, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
if route.AgentID != "beta" {
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
}
}
func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
agents := []config.AgentConfig{
{ID: "alpha"},
{ID: "beta"},
}
cfg := testConfig(agents, nil)
r := NewRouteResolver(cfg)
route := r.ResolveRoute(RouteInput{
Channel: "cli",
})
if route.AgentID != "alpha" {
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
}
}
+183
View File
@@ -0,0 +1,183 @@
package routing
import (
"fmt"
"strings"
)
// DMScope controls DM session isolation granularity.
type DMScope string
const (
DMScopeMain DMScope = "main"
DMScopePerPeer DMScope = "per-peer"
DMScopePerChannelPeer DMScope = "per-channel-peer"
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
)
// RoutePeer represents a chat peer with kind and ID.
type RoutePeer struct {
Kind string // "direct", "group", "channel"
ID string
}
// SessionKeyParams holds all inputs for session key construction.
type SessionKeyParams struct {
AgentID string
Channel string
AccountID string
Peer *RoutePeer
DMScope DMScope
IdentityLinks map[string][]string
}
// ParsedSessionKey is the result of parsing an agent-scoped session key.
type ParsedSessionKey struct {
AgentID string
Rest string
}
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
func BuildAgentMainSessionKey(agentID string) string {
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
}
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
agentID := NormalizeAgentID(params.AgentID)
peer := params.Peer
if peer == nil {
peer = &RoutePeer{Kind: "direct"}
}
peerKind := strings.TrimSpace(peer.Kind)
if peerKind == "" {
peerKind = "direct"
}
if peerKind == "direct" {
dmScope := params.DMScope
if dmScope == "" {
dmScope = DMScopeMain
}
peerID := strings.TrimSpace(peer.ID)
// Resolve identity links (cross-platform collapse)
if dmScope != DMScopeMain && peerID != "" {
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
peerID = linked
}
}
peerID = strings.ToLower(peerID)
switch dmScope {
case DMScopePerAccountChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
accountID := NormalizeAccountID(params.AccountID)
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
}
case DMScopePerChannelPeer:
if peerID != "" {
channel := normalizeChannel(params.Channel)
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
}
case DMScopePerPeer:
if peerID != "" {
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
}
}
return BuildAgentMainSessionKey(agentID)
}
// Group/channel peers always get per-peer sessions
channel := normalizeChannel(params.Channel)
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
if peerID == "" {
peerID = "unknown"
}
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
}
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return nil
}
parts := strings.SplitN(raw, ":", 3)
if len(parts) < 3 {
return nil
}
if parts[0] != "agent" {
return nil
}
agentID := strings.TrimSpace(parts[1])
rest := parts[2]
if agentID == "" || rest == "" {
return nil
}
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
}
// IsSubagentSessionKey returns true if the session key represents a subagent.
func IsSubagentSessionKey(sessionKey string) bool {
raw := strings.TrimSpace(sessionKey)
if raw == "" {
return false
}
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
return true
}
parsed := ParseAgentSessionKey(raw)
if parsed == nil {
return false
}
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
}
func normalizeChannel(channel string) string {
c := strings.TrimSpace(strings.ToLower(channel))
if c == "" {
return "unknown"
}
return c
}
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
if len(identityLinks) == 0 {
return ""
}
peerID = strings.TrimSpace(peerID)
if peerID == "" {
return ""
}
candidates := make(map[string]bool)
rawCandidate := strings.ToLower(peerID)
if rawCandidate != "" {
candidates[rawCandidate] = true
}
channel = strings.ToLower(strings.TrimSpace(channel))
if channel != "" {
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
candidates[scopedCandidate] = true
}
if len(candidates) == 0 {
return ""
}
for canonical, ids := range identityLinks {
canonicalName := strings.TrimSpace(canonical)
if canonicalName == "" {
continue
}
for _, id := range ids {
normalized := strings.ToLower(strings.TrimSpace(id))
if normalized != "" && candidates[normalized] {
return canonicalName
}
}
}
return ""
}
+162
View File
@@ -0,0 +1,162 @@
package routing
import "testing"
func TestBuildAgentMainSessionKey(t *testing.T) {
got := BuildAgentMainSessionKey("sales")
want := "agent:sales:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
}
}
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
got := BuildAgentMainSessionKey("Sales Bot")
want := "agent:sales-bot:main"
if got != want {
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopeMain,
})
want := "agent:main:main"
if got != want {
t.Errorf("DMScopeMain = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
})
want := "agent:main:direct:user123"
if got != want {
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerChannelPeer,
})
want := "agent:main:telegram:direct:user123"
if got != want {
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
AccountID: "bot1",
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
DMScope: DMScopePerAccountChannelPeer,
})
want := "agent:main:telegram:bot1:direct:user123"
if got != want {
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
DMScope: DMScopePerPeer,
})
want := "agent:main:telegram:group:chat456"
if got != want {
t.Errorf("GroupPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: nil,
DMScope: DMScopePerPeer,
})
// nil peer defaults to direct with empty ID, falls to main
want := "agent:main:main"
if got != want {
t.Errorf("NilPeer = %q, want %q", got, want)
}
}
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
links := map[string][]string{
"john": {"telegram:user123", "discord:john#1234"},
}
got := BuildAgentPeerSessionKey(SessionKeyParams{
AgentID: "main",
Channel: "telegram",
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
DMScope: DMScopePerPeer,
IdentityLinks: links,
})
want := "agent:main:direct:john"
if got != want {
t.Errorf("IdentityLink = %q, want %q", got, want)
}
}
func TestParseAgentSessionKey_Valid(t *testing.T) {
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
if parsed == nil {
t.Fatal("expected non-nil result")
}
if parsed.AgentID != "sales" {
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
}
if parsed.Rest != "telegram:direct:user123" {
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
}
}
func TestParseAgentSessionKey_Invalid(t *testing.T) {
tests := []string{
"",
"foo:bar",
"notprefix:sales:main",
"agent::main",
"agent:sales:",
}
for _, input := range tests {
if got := ParseAgentSessionKey(input); got != nil {
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
}
}
}
func TestIsSubagentSessionKey(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"subagent:task-1", true},
{"agent:main:subagent:task-1", true},
{"agent:main:main", false},
{"agent:main:telegram:direct:user123", false},
{"", false},
}
for _, tt := range tests {
if got := IsSubagentSessionKey(tt.input); got != tt.want {
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
-53
View File
@@ -8,7 +8,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
@@ -24,12 +23,6 @@ type AvailableSkill struct {
Tags []string `json:"tags"`
}
type BuiltinSkill struct {
Name string `json:"name"`
Path string `json:"path"`
Enabled bool `json:"enabled"`
}
func NewSkillInstaller(workspace string) *SkillInstaller {
return &SkillInstaller{
workspace: workspace,
@@ -123,49 +116,3 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS
return skills, nil
}
func (si *SkillInstaller) ListBuiltinSkills() []BuiltinSkill {
builtinSkillsDir := filepath.Join(filepath.Dir(si.workspace), "picoclaw", "skills")
entries, err := os.ReadDir(builtinSkillsDir)
if err != nil {
return nil
}
var skills []BuiltinSkill
for _, entry := range entries {
if entry.IsDir() {
_ = entry
skillName := entry.Name()
skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md")
data, err := os.ReadFile(skillFile)
description := ""
if err == nil {
content := string(data)
if idx := strings.Index(content, "\n"); idx > 0 {
firstLine := content[:idx]
if strings.Contains(firstLine, "description:") {
descLine := strings.Index(content[idx:], "\n")
if descLine > 0 {
description = strings.TrimSpace(content[idx+descLine : idx+descLine])
}
}
}
}
// skill := BuiltinSkill{
// Name: skillName,
// Path: description,
// Enabled: true,
// }
status := "✓"
fmt.Printf(" %s %s\n", status, entry.Name())
if description != "" {
fmt.Printf(" %s\n", description)
}
}
}
return skills
}
+22 -5
View File
@@ -9,6 +9,8 @@ import (
"path/filepath"
"regexp"
"strings"
"github.com/sipeed/picoclaw/pkg/logger"
)
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
@@ -251,6 +253,11 @@ func (sl *SkillsLoader) BuildSkillsSummary() string {
func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
content, err := os.ReadFile(skillPath)
if err != nil {
logger.WarnCF("skills", "Failed to read skill metadata",
map[string]interface{}{
"skill_path": skillPath,
"error": err.Error(),
})
return nil
}
@@ -283,10 +290,15 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
// parseSimpleYAML parses simple key: value YAML format
// Example: name: github\n description: "..."
// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac)
func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
result := make(map[string]string)
for _, line := range strings.Split(content, "\n") {
// Normalize line endings: convert \r\n and \r to \n
normalized := strings.ReplaceAll(content, "\r\n", "\n")
normalized = strings.ReplaceAll(normalized, "\r", "\n")
for _, line := range strings.Split(normalized, "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
@@ -306,9 +318,10 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
}
func (sl *SkillsLoader) extractFrontmatter(content string) string {
// (?s) enables DOTALL mode so . matches newlines
// Match first ---, capture everything until next --- on its own line
re := regexp.MustCompile(`(?s)^---\n(.*)\n---`)
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
// (?s) enables DOTALL so . matches newlines;
// ^--- at start, then ... --- at start of line, honoring all three line ending types
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`)
match := re.FindStringSubmatch(content)
if len(match) > 1 {
return match[1]
@@ -317,7 +330,11 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string {
}
func (sl *SkillsLoader) stripFrontmatter(content string) string {
re := regexp.MustCompile(`^---\n.*?\n---\n`)
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
// (?s) enables DOTALL so . matches newlines;
// ^--- at start, then ... --- at start of line, honoring all three line ending types
// Match zero or more trailing line endings after closing --- (handles both with and without blank lines)
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
return re.ReplaceAllString(content, "")
}
+102
View File
@@ -75,3 +75,105 @@ func TestSkillsInfoValidate(t *testing.T) {
})
}
}
func TestExtractFrontmatter(t *testing.T) {
sl := &SkillsLoader{}
testcases := []struct {
name string
content string
expectedName string
expectedDesc string
lineEndingType string
}{
{
name: "unix-line-endings",
lineEndingType: "Unix (\\n)",
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
{
name: "windows-line-endings",
lineEndingType: "Windows (\\r\\n)",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
{
name: "classic-mac-line-endings",
lineEndingType: "Classic Mac (\\r)",
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
expectedName: "test-skill",
expectedDesc: "A test skill",
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
// Extract frontmatter
frontmatter := sl.extractFrontmatter(tc.content)
assert.NotEmpty(t, frontmatter, "Frontmatter should be extracted for %s line endings", tc.lineEndingType)
// Parse YAML to get name and description (parseSimpleYAML now handles all line ending types)
yamlMeta := sl.parseSimpleYAML(frontmatter)
assert.Equal(t, tc.expectedName, yamlMeta["name"], "Name should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
assert.Equal(t, tc.expectedDesc, yamlMeta["description"], "Description should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
})
}
}
func TestStripFrontmatter(t *testing.T) {
sl := &SkillsLoader{}
testcases := []struct {
name string
content string
expectedContent string
lineEndingType string
}{
{
name: "unix-line-endings",
lineEndingType: "Unix (\\n)",
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "windows-line-endings",
lineEndingType: "Windows (\\r\\n)",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "classic-mac-line-endings",
lineEndingType: "Classic Mac (\\r)",
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "unix-line-endings-without-trailing-newline",
lineEndingType: "Unix (\\n) without trailing newline",
content: "---\nname: test-skill\ndescription: A test skill\n---\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "windows-line-endings-without-trailing-newline",
lineEndingType: "Windows (\\r\\n) without trailing newline",
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n# Skill Content",
expectedContent: "# Skill Content",
},
{
name: "no-frontmatter",
lineEndingType: "No frontmatter",
content: "# Skill Content\n\nSome content here.",
expectedContent: "# Skill Content\n\nSome content here.",
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := sl.stripFrontmatter(tc.content)
assert.Equal(t, tc.expectedContent, result, "Frontmatter should be stripped correctly for %s", tc.lineEndingType)
})
}
}
+6 -2
View File
@@ -7,6 +7,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -28,12 +29,15 @@ type CronTool struct {
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
// execTimeout: 0 means no timeout, >0 sets the timeout duration
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
execTool := NewExecToolWithConfig(workspace, restrict, config)
execTool.SetTimeout(execTimeout)
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: NewExecTool(workspace, restrict),
execTool: execTool,
}
}
+85 -10
View File
@@ -11,6 +11,8 @@ import (
"runtime"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
type ExecTool struct {
@@ -21,16 +23,82 @@ type ExecTool struct {
restrictToWorkspace bool
}
var defaultDenyPatterns = []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
regexp.MustCompile(`\$\([^)]+\)`),
regexp.MustCompile(`\$\{[^}]+\}`),
regexp.MustCompile("`[^`]+`"),
regexp.MustCompile(`\|\s*sh\b`),
regexp.MustCompile(`\|\s*bash\b`),
regexp.MustCompile(`;\s*rm\s+-[rf]`),
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
regexp.MustCompile(`<<\s*EOF`),
regexp.MustCompile(`\$\(\s*cat\s+`),
regexp.MustCompile(`\$\(\s*curl\s+`),
regexp.MustCompile(`\$\(\s*wget\s+`),
regexp.MustCompile(`\$\(\s*which\s+`),
regexp.MustCompile(`\bsudo\b`),
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
regexp.MustCompile(`\bchown\b`),
regexp.MustCompile(`\bpkill\b`),
regexp.MustCompile(`\bkillall\b`),
regexp.MustCompile(`\bkill\s+-[9]\b`),
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
regexp.MustCompile(`\byum\s+(install|remove)\b`),
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
regexp.MustCompile(`\bdocker\s+run\b`),
regexp.MustCompile(`\bdocker\s+exec\b`),
regexp.MustCompile(`\bgit\s+push\b`),
regexp.MustCompile(`\bgit\s+force\b`),
regexp.MustCompile(`\bssh\b.*@`),
regexp.MustCompile(`\beval\b`),
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
}
func NewExecTool(workingDir string, restrict bool) *ExecTool {
denyPatterns := []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
return NewExecToolWithConfig(workingDir, restrict, nil)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
denyPatterns := make([]*regexp.Regexp, 0)
enableDenyPatterns := true
if config != nil {
execConfig := config.Tools.Exec
enableDenyPatterns = execConfig.EnableDenyPatterns
if enableDenyPatterns {
if len(execConfig.CustomDenyPatterns) > 0 {
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
for _, pattern := range execConfig.CustomDenyPatterns {
re, err := regexp.Compile(pattern)
if err != nil {
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
continue
}
denyPatterns = append(denyPatterns, re)
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
} else {
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
}
} else {
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
return &ExecTool{
@@ -89,7 +157,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
return ErrorResult(guardError)
}
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
// timeout == 0 means no timeout
var cmdCtx context.Context
var cancel context.CancelFunc
if t.timeout > 0 {
cmdCtx, cancel = context.WithTimeout(ctx, t.timeout)
} else {
cmdCtx, cancel = context.WithCancel(ctx)
}
defer cancel()
var cmd *exec.Cmd
+22 -5
View File
@@ -6,10 +6,11 @@ import (
)
type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
callback AsyncCallback // For async completion notification
manager *SubagentManager
originChannel string
originChatID string
allowlistCheck func(targetAgentID string) bool
callback AsyncCallback // For async completion notification
}
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} {
"type": "string",
"description": "Optional short label for the task (for display)",
},
"agent_id": map[string]interface{}{
"type": "string",
"description": "Optional target agent ID to delegate the task to",
},
},
"required": []string{"task"},
}
@@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
@@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T
}
label, _ := args["label"].(string)
agentID, _ := args["agent_id"].(string)
// Check allowlist if targeting a specific agent
if agentID != "" && t.allowlistCheck != nil {
if !t.allowlistCheck(agentID) {
return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID))
}
}
if t.manager == nil {
return ErrorResult("Subagent manager not configured")
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
+3 -1
View File
@@ -14,6 +14,7 @@ type SubagentTask struct {
ID string
Task string
Label string
AgentID string
OriginChannel string
OriginChatID string
Status string
@@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
ID: taskID,
Task: task,
Label: label,
AgentID: agentID,
OriginChannel: originChannel,
OriginChatID: originChatID,
Status: "running",
+79 -4
View File
@@ -176,6 +176,71 @@ func stripTags(content string) string {
return re.ReplaceAllString(content, "")
}
type PerplexitySearchProvider struct {
apiKey string
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := "https://api.perplexity.ai/chat/completions"
payload := map[string]interface{}{
"model": "sonar",
"messages": []map[string]string{
{"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
{"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
},
"max_tokens": 1000,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("Perplexity API error: %s", string(body))
}
var searchResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
if len(searchResp.Choices) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
}
type WebSearchTool struct {
provider SearchProvider
maxResults int
@@ -187,14 +252,22 @@ type WebSearchToolOptions struct {
BraveEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
PerplexityAPIKey string
PerplexityMaxResults int
PerplexityEnabled bool
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
var provider SearchProvider
maxResults := 5
// Priority: Brave > DuckDuckGo
if opts.BraveEnabled && opts.BraveAPIKey != "" {
// Priority: Perplexity > Brave > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
@@ -419,8 +492,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
result = strings.TrimSpace(result)
re = regexp.MustCompile(`\s+`)
result = re.ReplaceAllLiteralString(result, " ")
re = regexp.MustCompile(`[^\S\n]+`)
result = re.ReplaceAllString(result, " ")
re = regexp.MustCompile(`\n{3,}`)
result = re.ReplaceAllString(result, "\n\n")
lines := strings.Split(result, "\n")
var cleanLines []string
+74
View File
@@ -234,6 +234,80 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
}
}
// TestWebFetchTool_extractText verifies text extraction preserves newlines
func TestWebFetchTool_extractText(t *testing.T) {
tool := &WebFetchTool{}
tests := []struct {
name string
input string
wantFunc func(t *testing.T, got string)
}{
{
name: "preserves newlines between block elements",
input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
wantFunc: func(t *testing.T, got string) {
lines := strings.Split(got, "\n")
if len(lines) < 2 {
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
}
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
t.Errorf("Missing expected text: %q", got)
}
},
},
{
name: "removes script and style tags",
input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
t.Errorf("Expected script/style content removed, got: %q", got)
}
if !strings.Contains(got, "Keep this") {
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
}
},
},
{
name: "collapses excessive blank lines",
input: "<p>A</p>\n\n\n\n\n<p>B</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, "\n\n\n") {
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
}
},
},
{
name: "collapses horizontal whitespace",
input: "<p>hello world</p>",
wantFunc: func(t *testing.T, got string) {
if strings.Contains(got, " ") {
t.Errorf("Expected spaces collapsed, got: %q", got)
}
if !strings.Contains(got, "hello world") {
t.Errorf("Expected 'hello world', got: %q", got)
}
},
},
{
name: "empty input",
input: "",
wantFunc: func(t *testing.T, got string) {
if got != "" {
t.Errorf("Expected empty string, got: %q", got)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tool.extractText(tt.input)
tt.wantFunc(t, got)
})
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
+179
View File
@@ -0,0 +1,179 @@
package utils
import (
"strings"
)
// SplitMessage splits long messages into chunks, preserving code block integrity.
// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks,
// but may extend to maxLen when needed.
// Call SplitMessage with the full text content and the maximum allowed length of a single message;
// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks.
func SplitMessage(content string, maxLen int) []string {
var messages []string
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
codeBlockBuffer := maxLen / 10
if codeBlockBuffer < 50 {
codeBlockBuffer = 50
}
if codeBlockBuffer > maxLen/2 {
codeBlockBuffer = maxLen / 2
}
for len(content) > 0 {
if len(content) <= maxLen {
messages = append(messages, content)
break
}
// Effective split point: maxLen minus buffer, to leave room for code blocks
effectiveLimit := maxLen - codeBlockBuffer
if effectiveLimit < maxLen/2 {
effectiveLimit = maxLen / 2
}
// Find natural split point within the effective limit
msgEnd := findLastNewline(content[:effectiveLimit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:effectiveLimit], 100)
}
if msgEnd <= 0 {
msgEnd = effectiveLimit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend up to maxLen to include the closing ```
if len(content) > msgEnd {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= maxLen {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Code block is too long to fit in one chunk or missing closing fence.
// Try to split inside by injecting closing and reopening fences.
headerEnd := strings.Index(content[unclosedIdx:], "\n")
if headerEnd == -1 {
headerEnd = unclosedIdx + 3
} else {
headerEnd += unclosedIdx
}
header := strings.TrimSpace(content[unclosedIdx:headerEnd])
// If we have a reasonable amount of content after the header, split inside
if msgEnd > headerEnd+20 {
// Find a better split point closer to maxLen
innerLimit := maxLen - 5 // Leave room for "\n```"
betterEnd := findLastNewline(content[:innerLimit], 200)
if betterEnd > headerEnd {
msgEnd = betterEnd
} else {
msgEnd = innerLimit
}
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
continue
}
// Otherwise, try to split before the code block starts
newEnd := findLastNewline(content[:unclosedIdx], 200)
if newEnd <= 0 {
newEnd = findLastSpace(content[:unclosedIdx], 100)
}
if newEnd > 0 {
msgEnd = newEnd
} else {
// If we can't split before, we MUST split inside (last resort)
if unclosedIdx > 20 {
msgEnd = unclosedIdx
} else {
msgEnd = maxLen - 5
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
continue
}
}
}
}
}
if msgEnd <= 0 {
msgEnd = effectiveLimit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
inCodeBlock := false
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
// Toggle code block state on each fence
if !inCodeBlock {
// Entering a code block: record this opening fence
lastOpenIdx = i
}
inCodeBlock = !inCodeBlock
i += 2
}
}
if inCodeBlock {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
+151
View File
@@ -0,0 +1,151 @@
package utils
import (
"strings"
"testing"
)
func TestSplitMessage(t *testing.T) {
longText := strings.Repeat("a", 2500)
longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars
tests := []struct {
name string
content string
maxLen int
expectChunks int // Check number of chunks
checkContent func(t *testing.T, chunks []string) // Custom validation
}{
{
name: "Empty message",
content: "",
maxLen: 2000,
expectChunks: 0,
},
{
name: "Short message fits in one chunk",
content: "Hello world",
maxLen: 2000,
expectChunks: 1,
},
{
name: "Simple split regular text",
content: longText,
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
if len(chunks[0]) > 2000 {
t.Errorf("Chunk 0 too large: %d", len(chunks[0]))
}
if len(chunks[0])+len(chunks[1]) != len(longText) {
t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText))
}
},
},
{
name: "Split at newline",
// 1750 chars then newline, then more chars.
// Dynamic buffer: 2000 / 10 = 200.
// Effective limit: 2000 - 200 = 1800.
// Split should happen at newline because it's at 1750 (< 1800).
// Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051.
content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300),
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
if len(chunks[0]) != 1750 {
t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0]))
}
if chunks[1] != strings.Repeat("b", 300) {
t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1]))
}
},
},
{
name: "Long code block split",
content: "Prefix\n" + longCode,
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
// Check that first chunk ends with closing fence
if !strings.HasSuffix(chunks[0], "\n```") {
t.Error("First chunk should end with injected closing fence")
}
// Check that second chunk starts with execution header
if !strings.HasPrefix(chunks[1], "```go") {
t.Error("Second chunk should start with injected code block header")
}
},
},
{
name: "Preserve Unicode characters",
content: strings.Repeat("\u4e16", 1000), // 3000 bytes
maxLen: 2000,
expectChunks: 2,
checkContent: func(t *testing.T, chunks []string) {
// Just verify we didn't panic and got valid strings.
// Go strings are UTF-8, if we split mid-rune it would be bad,
// but standard slicing might do that.
// Let's assume standard behavior is acceptable or check if it produces invalid rune?
if !strings.Contains(chunks[0], "\u4e16") {
t.Error("Chunk should contain unicode characters")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := SplitMessage(tc.content, tc.maxLen)
if tc.expectChunks == 0 {
if len(got) != 0 {
t.Errorf("Expected 0 chunks, got %d", len(got))
}
return
}
if len(got) != tc.expectChunks {
t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got))
// Log sizes for debugging
for i, c := range got {
t.Logf("Chunk %d length: %d", i, len(c))
}
return // Stop further checks if count assumes specific split
}
if tc.checkContent != nil {
tc.checkContent(t, got)
}
})
}
}
func TestSplitMessage_CodeBlockIntegrity(t *testing.T) {
// Focused test for the core requirement: splitting inside a code block preserves syntax highlighting
// 60 chars total approximately
content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```"
maxLen := 40
chunks := SplitMessage(content, maxLen)
if len(chunks) != 2 {
t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks)
}
// First chunk must end with "\n```"
if !strings.HasSuffix(chunks[0], "\n```") {
t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0])
}
// Second chunk must start with the header "```go"
if !strings.HasPrefix(chunks[1], "```go") {
t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1])
}
// First chunk should contain meaningful content
if len(chunks[0]) > 40 {
t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0]))
}
}