diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7910cb1e2..c96b7da12 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,4 +1,7 @@ ## 📝 Description + + + ## 🗣️ Type of Change - [ ] 🐞 Bug fix (non-breaking change which fixes an issue) - [ ] ✨ New feature (non-breaking change which adds functionality) @@ -11,25 +14,28 @@ - [ ] 👨‍💻 Mostly Human-written (Human lead, AI assisted or none) -## 🔗 Linked Issue +## 🔗 Related Issue + + + ## 📚 Technical Context (Skip for Docs) -* **Reference:** [URL] -* **Reasoning:** ... +- **Reference URL:** +- **Reasoning:** + +## 🧪 Test Environment +- **Hardware:** +- **OS:** +- **Model/Provider:** +- **Channels:** -## 🧪 Test Environment & Hardware -- **Hardware:** [e.g. Raspberry Pi 5, Orange Pi, PC] -- **OS:** [e.g. Debian 12, Ubuntu 22.04] -- **Model/Provider:** [e.g. OpenAI GPT-4o, Kimi k2, DeepSeek-V3] -- **Channels:** [e.g. Discord, Telegram, Feishu, ...] - - -## 📸 Proof of Work (Optional for Docs) +## 📸 Evidence (Optional)
Click to view Logs/Screenshots -
+ + ## ☑️ Checklist - [ ] My code/docs follow the style of this project. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0f075b0bb..499613625 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,10 +9,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: go.mod diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 2d1aa9ffc..dadbed212 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -25,7 +25,7 @@ jobs: steps: # ── Checkout ────────────────────────────── - name: 📥 Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ inputs.tag }} diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index fac7597ea..55bf77e00 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -1,17 +1,39 @@ -name: pr-check +name: PR on: - pull_request: + pull_request: { } jobs: - fmt-check: + lint: + name: Linter runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 + with: + go-version-file: go.mod + + - name: Run go generate + run: go generate ./... + + - name: Golangci Lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.10.1 + + # TODO: Remove once linter is properly configured + fmt-check: + name: Formatting + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup Go + uses: actions/setup-go@v6 with: go-version-file: go.mod @@ -20,15 +42,17 @@ jobs: make fmt git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." 
&& exit 1) + # TODO: Remove once linter is properly configured vet: + name: Vet runs-on: ubuntu-latest needs: fmt-check steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: go.mod @@ -39,14 +63,15 @@ jobs: run: go vet ./... test: + name: Tests runs-on: ubuntu-latest needs: fmt-check steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: go.mod @@ -55,4 +80,3 @@ jobs: - name: Run go test run: go test ./... - diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9fe3a684e..786c893ef 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,7 +26,7 @@ jobs: contents: write steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -49,13 +49,14 @@ jobs: packages: write steps: - name: Checkout tag - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ inputs.tag }} - name: Setup Go from go.mod - uses: actions/setup-go@v5 + id: setup-go + uses: actions/setup-go@v6 with: go-version-file: go.mod @@ -89,6 +90,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }} DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }} + GOVERSION: ${{ steps.setup-go.outputs.go-version }} - name: Apply release flags shell: bash diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 000000000..80e54ac1c --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,184 @@ +version: "2" + +linters: + default: all + disable: + # TODO: Tweak for current project needs + - containedctx + - cyclop + - depguard + - dupl + - dupword + - err113 + - exhaustruct + - funcorder + - gochecknoglobals + - godot + - intrange + - ireturn + - nlreturn + - noctx + - noinlineerr + - 
nonamedreturns + - tagliatelle + - testpackage + - varnamelen + - wrapcheck + - wsl + - wsl_v5 + + # TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step) + - bodyclose + - contextcheck + - dogsled + - embeddedstructfieldcheck + - errcheck + - errchkjson + - errorlint + - exhaustive + - forbidigo + - forcetypeassert + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - maintidx + - misspell + - mnd + - modernize + - nakedret + - nestif + - nilnil + - paralleltest + - perfsprint + - prealloc + - predeclared + - revive + - staticcheck + - tagalign + - testifylint + - thelper + - unparam + - unused + - usestdlibvars + - usetesting + - wastedassign + - whitespace + settings: + errcheck: + check-type-assertions: true + check-blank: true + exhaustive: + default-signifies-exhaustive: true + funlen: + lines: 120 + statements: 40 + gocognit: + min-complexity: 25 + gocyclo: + min-complexity: 20 + govet: + enable-all: true + disable: + - fieldalignment + lll: + line-length: 120 + tab-width: 4 + misspell: + locale: US + mnd: + checks: + - argument + - assign + - case + - condition + - operation + - return + nakedret: + max-func-lines: 3 + revive: + enable-all-rules: true + rules: + - name: add-constant + disabled: true + - name: argument-limit + arguments: + - 7 + severity: warning + - name: banned-characters + disabled: true + - name: cognitive-complexity + disabled: true + - name: comment-spacings + arguments: + - nolint + severity: warning + - name: cyclomatic + disabled: true + - name: file-header + disabled: true + - name: function-result-limit + arguments: + - 3 + severity: warning + - name: function-length + disabled: true + - name: line-length-limit + disabled: true + - name: max-public-structs + disabled: true + - name: modifies-value-receiver + disabled: true + - name: package-comments + disabled: true + - name: 
unused-receiver + disabled: true + exclusions: + generated: lax + rules: + - linters: + - lll + source: '^//go:generate ' + - linters: + - funlen + - maintidx + - gocognit + - gocyclo + path: _test\.go$ + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 + +formatters: + enable: + - goimports + # TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step) + # - gci + # - gofmt + # - gofumpt + # - golines + settings: + gci: + sections: + - standard + - default + - localmodule + custom-order: true + gofmt: + simplify: true + rewrite-rules: + - pattern: "interface{}" + replacement: "any" + - pattern: "a[b:len(a)]" + replacement: "a[b:]" + golines: + max-len: 120 diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 368a0f06b..2c47f7d86 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -11,6 +11,14 @@ builds: - id: picoclaw env: - CGO_ENABLED=0 + tags: + - stdjson + ldflags: + - -s -w + - -X main.version={{ .Version }} + - -X main.gitCommit={{ .ShortCommit }} + - -X main.buildTime={{ .Date }} + - -X main.goVersion={{ .Env.GOVERSION }} goos: - linux - windows diff --git a/Dockerfile b/Dockerfile index dd98ec0bd..0360cfda6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,14 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ # Copy binary COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw -# Create picoclaw home directory +# Create non-root user and group +RUN addgroup -g 1000 picoclaw && \ + adduser -D -u 1000 -G picoclaw picoclaw + +# Switch to non-root user +USER picoclaw + +# Run onboard to create initial directories and config RUN /usr/local/bin/picoclaw onboard ENTRYPOINT ["picoclaw"] diff --git a/Makefile b/Makefile index bb31243dd..ff280e3e4 100644 --- a/Makefile +++ b/Makefile @@ -11,11 +11,11 @@ VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev") BUILD_TIME=$(shell 
date +%FT%T%z) GO_VERSION=$(shell $(GO) version | awk '{print $$3}') -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" +LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w" # Go variables GO?=go -GOFLAGS?=-v +GOFLAGS?=-v -tags stdjson # Installation INSTALL_PREFIX?=$(HOME)/.local diff --git a/README.fr.md b/README.fr.md new file mode 100644 index 000000000..ab8faf468 --- /dev/null +++ b/README.fr.md @@ -0,0 +1,881 @@ +
+ PicoClaw + +

PicoClaw : Assistant IA Ultra-Efficace en Go

+ +

Matériel à 10$ · 10 Mo de RAM · Démarrage en 1s · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + [中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md) | **Français** +
+ +--- + +🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [nanobot](https://github.com/HKUDS/nanobot), entièrement réécrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code. + +⚡️ **Extrêmement léger :** Fonctionne sur du matériel à seulement **10$** avec **<10 Mo** de RAM. C'est 99% de mémoire en moins qu'OpenClaw et 98% moins cher qu'un Mac mini ! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 SÉCURITÉ & CANAUX OFFICIELS** +> +> * **PAS DE CRYPTO :** PicoClaw n'a **AUCUN** token/jeton officiel. Toute annonce sur `pump.fun` ou d'autres plateformes de trading est une **ARNAQUE**. +> * **DOMAINE OFFICIEL :** Le **SEUL** site officiel est **[picoclaw.io](https://picoclaw.io)**, et le site de l'entreprise est **[sipeed.com](https://sipeed.com)**. +> * **Attention :** De nombreux domaines `.ai/.org/.com/.net/...` sont enregistrés par des tiers et ne nous appartiennent pas. +> * **Attention :** PicoClaw est en phase de développement précoce et peut présenter des problèmes de sécurité réseau non résolus. Ne déployez pas en environnement de production avant la version v1.0. +> * **Note :** PicoClaw a récemment fusionné de nombreuses PR, ce qui peut entraîner une empreinte mémoire plus importante (10–20 Mo) dans les dernières versions. Nous prévoyons de prioriser l'optimisation des ressources dès que l'ensemble des fonctionnalités sera stabilisé. + + +## 📢 Actualités + +2026-02-16 🎉 PicoClaw a atteint 12K étoiles en une semaine ! Merci à tous pour votre soutien ! PicoClaw grandit plus vite que nous ne l'avions jamais imaginé. Vu le volume élevé de PR, nous avons un besoin urgent de mainteneurs communautaires. Nos rôles de bénévoles et notre feuille de route sont officiellement publiés [ici](docs/picoclaw_community_roadmap_260216.md) — nous avons hâte de vous accueillir ! + +2026-02-13 🎉 PicoClaw a atteint 5000 étoiles en 4 jours ! Merci à la communauté ! Nous finalisons la **Feuille de Route du Projet** et mettons en place le **Groupe de Développeurs** pour accélérer le développement de PicoClaw. +🚀 **Appel à l'action :** Soumettez vos demandes de fonctionnalités dans les GitHub Discussions. Nous les examinerons et les prioriserons lors de notre prochaine réunion hebdomadaire. + +2026-02-09 🎉 PicoClaw est lancé ! Construit en 1 jour pour apporter les Agents IA au matériel à 10$ avec <10 Mo de RAM. 🦐 PicoClaw, c'est parti ! 
+ +## ✨ Fonctionnalités + +🪶 **Ultra-Léger** : Empreinte mémoire <10 Mo — 99% plus petit que Clawdbot pour les fonctionnalités essentielles. + +💰 **Coût Minimal** : Suffisamment efficace pour fonctionner sur du matériel à 10$ — 98% moins cher qu'un Mac mini. + +⚡️ **Démarrage Éclair** : Temps de démarrage 400X plus rapide, boot en 1 seconde même sur un cœur unique à 0,6 GHz. + +🌍 **Véritable Portabilité** : Un seul binaire autonome pour RISC-V, ARM et x86. Un clic et c'est parti ! + +🤖 **Auto-Construit par l'IA** : Implémentation native en Go de manière autonome — 95% du cœur généré par l'Agent avec affinement humain dans la boucle. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Langage** | TypeScript | Python | **Go** | +| **RAM** | >1 Go | >100 Mo | **< 10 Mo** | +| **Démarrage**
(cœur 0,8 GHz) | >500s | >30s | **<1s** | +| **Coût** | Mac Mini 599$ | La plupart des SBC Linux
~50$ | **N'importe quelle carte Linux**
**À partir de 10$** | + +PicoClaw + +## 🦾 Démonstration + +### 🛠️ Flux de Travail Standard de l'Assistant + + + + + + + + + + + + + + + + + +

🧩 Ingénieur Full-Stack

🗂️ Gestion des Logs & Planification

🔎 Recherche Web & Apprentissage

Développer • Déployer • Mettre à l'échellePlanifier • Automatiser • MémoriserDécouvrir • Analyser • Tendances
+ +### 📱 Utiliser sur d'anciens téléphones Android + +Donnez une seconde vie à votre téléphone d'il y a dix ans ! Transformez-le en assistant IA intelligent avec PicoClaw. Démarrage rapide : + +1. **Installez Termux** (disponible sur F-Droid ou Google Play). +2. **Exécutez les commandes** + +```bash +# Note : Remplacez v0.1.1 par la dernière version depuis la page des Releases +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` + +Puis suivez les instructions de la section « Démarrage Rapide » pour terminer la configuration ! + +PicoClaw + +### 🐜 Déploiement Innovant à Faible Empreinte + +PicoClaw peut être déployé sur pratiquement n'importe quel appareil Linux ! + +- 9,9$ [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) version E (Ethernet) ou W (WiFi6), pour un Assistant Domotique Minimaliste +- 30~50$ [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou 100$ [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) pour la Maintenance Automatisée de Serveurs +- 50$ [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou 100$ [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) pour la Surveillance Intelligente + + + +🌟 Encore plus de scénarios de déploiement vous attendent ! + +## 📦 Installation + +### Installer avec un binaire précompilé + +Téléchargez le binaire pour votre plateforme depuis la page des [releases](https://github.com/sipeed/picoclaw/releases). 
+ +### Installer depuis les sources (dernières fonctionnalités, recommandé pour le développement) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# Compiler, pas besoin d'installer +make build + +# Compiler pour plusieurs plateformes +make build-all + +# Compiler et Installer +make install +``` + +## 🐳 Docker Compose + +Vous pouvez également exécuter PicoClaw avec Docker Compose sans rien installer localement. + +```bash +# 1. Clonez ce dépôt +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Configurez vos clés API +cp config/config.example.json config/config.json +vim config/config.json # Configurez DISCORD_BOT_TOKEN, clés API, etc. + +# 3. Compiler & Démarrer +docker compose --profile gateway up -d + +# 4. Voir les logs +docker compose logs -f picoclaw-gateway + +# 5. Arrêter +docker compose --profile gateway down +``` + +### Mode Agent (exécution unique) + +```bash +# Poser une question +docker compose run --rm picoclaw-agent -m "Combien font 2+2 ?" + +# Mode interactif +docker compose run --rm picoclaw-agent +``` + +### Recompiler + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 Démarrage Rapide + +> [!TIP] +> Configurez votre clé API dans `~/.picoclaw/config.json`. +> Obtenir des clés API : [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> La recherche web est **optionnelle** — obtenez gratuitement l'[API Brave Search](https://brave.com/search/api) (2000 requêtes gratuites/mois) ou utilisez le repli automatique intégré. + +**1. Initialiser** + +```bash +picoclaw onboard +``` + +**2. 
Configurer** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "VOTRE_CLE_API_BRAVE", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +**3. Obtenir des Clés API** + +* **Fournisseur LLM** : [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Recherche Web** (optionnel) : [Brave Search](https://brave.com/search/api) - Offre gratuite disponible (2000 requêtes/mois) + +> **Note** : Consultez `config.example.json` pour un modèle de configuration complet. + +**4. Discuter** + +```bash +picoclaw agent -m "Combien font 2+2 ?" +``` + +Et voilà ! Vous avez un assistant IA fonctionnel en 2 minutes. + +--- + +## 💬 Applications de Chat + +Discutez avec votre PicoClaw via Telegram, Discord, QQ, DingTalk ou LINE + +| Canal | Configuration | +| ------------ | -------------------------------------- | +| **Telegram** | Facile (juste un token) | +| **Discord** | Facile (token bot + intents) | +| **QQ** | Facile (AppID + AppSecret) | +| **DingTalk** | Moyen (identifiants de l'application) | +| **LINE** | Moyen (identifiants + URL de webhook) | + +
+Telegram (Recommandé) + +**1. Créer un bot** + +* Ouvrez Telegram, recherchez `@BotFather` +* Envoyez `/newbot`, suivez les instructions +* Copiez le token + +**2. Configurer** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "VOTRE_TOKEN_BOT", + "allow_from": ["VOTRE_USER_ID"] + } + } +} +``` + +> Obtenez votre User ID via `@userinfobot` sur Telegram. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Créer un bot** + +* Rendez-vous sur https://discord.com/developers/applications +* Créez une application → Bot → Add Bot +* Copiez le token du bot + +**2. Activer les intents** + +* Dans les paramètres du Bot, activez **MESSAGE CONTENT INTENT** +* (Optionnel) Activez **SERVER MEMBERS INTENT** si vous souhaitez utiliser des listes d'autorisation basées sur les données des membres + +**3. Obtenir votre User ID** + +* Paramètres Discord → Avancé → activez le **Mode Développeur** +* Clic droit sur votre avatar → **Copier l'identifiant** + +**4. Configurer** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "VOTRE_TOKEN_BOT", + "allow_from": ["VOTRE_USER_ID"] + } + } +} +``` + +**5. Inviter le bot** + +* OAuth2 → URL Generator +* Scopes : `bot` +* Permissions du Bot : `Send Messages`, `Read Message History` +* Ouvrez l'URL d'invitation générée et ajoutez le bot à votre serveur + +**6. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Créer un bot** + +- Rendez-vous sur la [QQ Open Platform](https://q.qq.com/#) +- Créez une application → Obtenez l'**AppID** et l'**AppSecret** + +**2. Configurer** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "VOTRE_APP_ID", + "app_secret": "VOTRE_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Laissez `allow_from` vide pour autoriser tous les utilisateurs, ou spécifiez des numéros QQ pour restreindre l'accès. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Créer un bot** + +* Rendez-vous sur la [Open Platform](https://open.dingtalk.com/) +* Créez une application interne +* Copiez le Client ID et le Client Secret + +**2. Configurer** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "VOTRE_CLIENT_ID", + "client_secret": "VOTRE_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Laissez `allow_from` vide pour autoriser tous les utilisateurs, ou spécifiez des identifiants pour restreindre l'accès. + +**3. Lancer** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Créer un Compte Officiel LINE** + +- Rendez-vous sur la [LINE Developers Console](https://developers.line.biz/) +- Créez un provider → Créez un canal Messaging API +- Copiez le **Channel Secret** et le **Channel Access Token** + +**2. Configurer** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "VOTRE_CHANNEL_SECRET", + "channel_access_token": "VOTRE_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Configurer l'URL du Webhook** + +LINE exige HTTPS pour les webhooks. Utilisez un reverse proxy ou un tunnel : + +```bash +# Exemple avec ngrok +ngrok http 18791 +``` + +Puis configurez l'URL du Webhook dans la LINE Developers Console sur `https://votre-domaine/webhook/line` et activez **Use webhook**. + +**4. Lancer** + +```bash +picoclaw gateway +``` + +> Dans les discussions de groupe, le bot répond uniquement lorsqu'il est mentionné avec @. Les réponses citent le message original. + +> **Docker Compose** : Ajoutez `ports: ["18791:18791"]` au service `picoclaw-gateway` pour exposer le port du webhook. + +
+ +## ClawdChat Rejoignez le Réseau Social d'Agents + +Connectez PicoClaw au Réseau Social d'Agents simplement en envoyant un seul message via le CLI ou n'importe quelle application de chat intégrée. + +**Lisez `https://clawdchat.ai/skill.md` et suivez les instructions pour rejoindre [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Configuration + +Fichier de configuration : `~/.picoclaw/config.json` + +### Structure du Workspace + +PicoClaw stocke les données dans votre workspace configuré (par défaut : `~/.picoclaw/workspace`) : + +``` +~/.picoclaw/workspace/ +├── sessions/ # Sessions de conversation et historique +├── memory/ # Mémoire à long terme (MEMORY.md) +├── state/ # État persistant (dernier canal, etc.) +├── cron/ # Base de données des tâches planifiées +├── skills/ # Compétences personnalisées +├── AGENTS.md # Guide de comportement de l'Agent +├── HEARTBEAT.md # Invites de tâches périodiques (vérifiées toutes les 30 min) +├── IDENTITY.md # Identité de l'Agent +├── SOUL.md # Âme de l'Agent +├── TOOLS.md # Description des outils +└── USER.md # Préférences utilisateur +``` + +### 🔒 Bac à Sable de Sécurité + +PicoClaw s'exécute dans un environnement sandboxé par défaut. L'agent ne peut accéder aux fichiers et exécuter des commandes qu'au sein du workspace configuré. 
+ +#### Configuration par Défaut + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Option | Par défaut | Description | +|--------|------------|-------------| +| `workspace` | `~/.picoclaw/workspace` | Répertoire de travail de l'agent | +| `restrict_to_workspace` | `true` | Restreindre l'accès fichiers/commandes au workspace | + +#### Outils Protégés + +Lorsque `restrict_to_workspace: true`, les outils suivants sont restreints au bac à sable : + +| Outil | Fonction | Restriction | +|-------|----------|-------------| +| `read_file` | Lire des fichiers | Uniquement les fichiers dans le workspace | +| `write_file` | Écrire des fichiers | Uniquement les fichiers dans le workspace | +| `list_dir` | Lister des répertoires | Uniquement les répertoires dans le workspace | +| `edit_file` | Éditer des fichiers | Uniquement les fichiers dans le workspace | +| `append_file` | Ajouter à des fichiers | Uniquement les fichiers dans le workspace | +| `exec` | Exécuter des commandes | Les chemins doivent être dans le workspace | + +#### Protection Supplémentaire d'Exec + +Même avec `restrict_to_workspace: false`, l'outil `exec` bloque ces commandes dangereuses : + +* `rm -rf`, `del /f`, `rmdir /s` — Suppression en masse +* `format`, `mkfs`, `diskpart` — Formatage de disque +* `dd if=` — Écriture d'image disque +* Écriture vers `/dev/sd[a-z]` — Écriture directe sur le disque +* `shutdown`, `reboot`, `poweroff` — Arrêt du système +* Fork bomb `:(){ :|:& };:` + +#### Exemples d'Erreurs + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Désactiver les Restrictions (Risque de Sécurité) + +Si vous avez besoin que l'agent accède à des chemins en dehors du workspace : + +**Méthode 1 : Fichier 
de configuration** + +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Méthode 2 : Variable d'environnement** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Attention** : Désactiver cette restriction permet à l'agent d'accéder à n'importe quel chemin sur votre système. À utiliser avec précaution uniquement dans des environnements contrôlés. + +#### Cohérence du Périmètre de Sécurité + +Le paramètre `restrict_to_workspace` s'applique de manière cohérente sur tous les chemins d'exécution : + +| Chemin d'Exécution | Périmètre de Sécurité | +|--------------------|----------------------| +| Agent Principal | `restrict_to_workspace` ✅ | +| Sous-agent / Spawn | Hérite de la même restriction ✅ | +| Tâches Heartbeat | Hérite de la même restriction ✅ | + +Tous les chemins partagent la même restriction de workspace — il est impossible de contourner le périmètre de sécurité via des sous-agents ou des tâches planifiées. + +### Heartbeat (Tâches Périodiques) + +PicoClaw peut exécuter des tâches périodiques automatiquement. Créez un fichier `HEARTBEAT.md` dans votre workspace : + +```markdown +# Tâches Périodiques + +- Vérifier mes e-mails pour les messages importants +- Consulter mon agenda pour les événements à venir +- Vérifier les prévisions météo +``` + +L'agent lira ce fichier toutes les 30 minutes (configurable) et exécutera les tâches à l'aide des outils disponibles. 
+ +#### Tâches Asynchrones avec Spawn + +Pour les tâches de longue durée (recherche web, appels API), utilisez l'outil `spawn` pour créer un **sous-agent** : + +```markdown +# Tâches Périodiques + +## Tâches Rapides (réponse directe) +- Indiquer l'heure actuelle + +## Tâches Longues (utiliser spawn pour l'asynchrone) +- Rechercher les actualités IA sur le web et les résumer +- Vérifier les e-mails et signaler les messages importants +``` + +**Comportements clés :** + +| Fonctionnalité | Description | +|----------------|-------------| +| **spawn** | Crée un sous-agent asynchrone, ne bloque pas le heartbeat | +| **Contexte indépendant** | Le sous-agent a son propre contexte, sans historique de session | +| **Outil message** | Le sous-agent communique directement avec l'utilisateur via l'outil message | +| **Non-bloquant** | Après le spawn, le heartbeat continue vers la tâche suivante | + +#### Fonctionnement de la Communication du Sous-agent + +``` +Le Heartbeat se déclenche + ↓ +L'Agent lit HEARTBEAT.md + ↓ +Pour une tâche longue : spawn d'un sous-agent + ↓ ↓ +Continue la tâche suivante Le sous-agent travaille indépendamment + ↓ ↓ +Toutes les tâches terminées Le sous-agent utilise l'outil "message" + ↓ ↓ +Répond HEARTBEAT_OK L'utilisateur reçoit le résultat directement +``` + +Le sous-agent a accès aux outils (message, web_search, etc.) et peut communiquer avec l'utilisateur indépendamment sans passer par l'agent principal. 
+ +**Configuration :** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Option | Par défaut | Description | +|--------|------------|-------------| +| `enabled` | `true` | Activer/désactiver le heartbeat | +| `interval` | `30` | Intervalle de vérification en minutes (min : 5) | + +**Variables d'environnement :** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` pour désactiver +* `PICOCLAW_HEARTBEAT_INTERVAL=60` pour modifier l'intervalle + +### Fournisseurs + +> [!NOTE] +> Groq fournit la transcription vocale gratuite via Whisper. Si configuré, les messages vocaux Telegram seront automatiquement transcrits. + +| Fournisseur | Utilisation | Obtenir une Clé API | +| ------------------------ | ---------------------------------------- | ------------------------------------------------------ | +| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](https://bigmodel.cn) | +| `openrouter` (À tester) | LLM (recommandé, accès à tous les modèles) | [openrouter.ai](https://openrouter.ai) | +| `anthropic` (À tester) | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `openai` (À tester) | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `deepseek` (À tester) | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Transcription vocale** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+Configuration Zhipu + +**1. Obtenir la clé API** + +* Obtenez la [clé API](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Configurer** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Votre Clé API", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Lancer** + +```bash +picoclaw agent -m "Bonjour, comment ça va ?" +``` + +
+ +
+Exemple de configuration complète + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +## Référence CLI + +| Commande | Description | +| ------------------------- | ------------------------------------- | +| `picoclaw onboard` | Initialiser la configuration & le workspace | +| `picoclaw agent -m "..."` | Discuter avec l'agent | +| `picoclaw agent` | Mode de discussion interactif | +| `picoclaw gateway` | Démarrer la passerelle | +| `picoclaw status` | Afficher le statut | +| `picoclaw cron list` | Lister toutes les tâches planifiées | +| `picoclaw cron add ...` | Ajouter une tâche planifiée | + +### Tâches Planifiées / Rappels + +PicoClaw prend en charge les rappels planifiés et les tâches récurrentes via l'outil `cron` : + +* **Rappels ponctuels** : « Rappelle-moi dans 10 minutes » → se déclenche une fois après 10 min +* **Tâches récurrentes** : « Rappelle-moi toutes les 2 heures » → se déclenche toutes les 2 heures +* **Expressions Cron** : « Rappelle-moi à 9h tous les jours » → utilise une expression cron + +Les tâches sont stockées dans `~/.picoclaw/workspace/cron/` et traitées automatiquement. + +## 🤝 Contribuer & Feuille de Route + +Les PR sont les bienvenues ! Le code source est volontairement petit et lisible. 🤗 + +Feuille de route à venir... + +Groupe de développeurs en construction. Condition d'entrée : au moins 1 PR fusionnée. + +Groupes d'utilisateurs : + +Discord : https://discord.gg/V4sAZ9XWpN + +PicoClaw + +## 🐛 Dépannage + +### La recherche web affiche « API 配置问题 » + +C'est normal si vous n'avez pas encore configuré de clé API de recherche. PicoClaw fournira des liens utiles pour la recherche manuelle. + +Pour activer la recherche web : + +1. **Option 1 (Recommandée)** : Obtenez une clé API gratuite sur [https://brave.com/search/api](https://brave.com/search/api) (2000 requêtes gratuites/mois) pour les meilleurs résultats. +2. **Option 2 (Sans carte bancaire)** : Si vous n'avez pas de clé, le système bascule automatiquement sur **DuckDuckGo** (aucune clé requise).
+ +Ajoutez la clé dans `~/.picoclaw/config.json` si vous utilisez Brave : + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": true, + "api_key": "VOTRE_CLE_API_BRAVE", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Erreurs de filtrage de contenu + +Certains fournisseurs (comme Zhipu) disposent d'un filtrage de contenu. Essayez de reformuler votre requête ou utilisez un modèle différent. + +### Le bot Telegram affiche « Conflict: terminated by other getUpdates » + +Cela se produit lorsqu'une autre instance du bot est en cours d'exécution. Assurez-vous qu'un seul `picoclaw gateway` fonctionne à la fois. + +--- + +## 📝 Comparaison des Clés API + +| Service | Offre Gratuite | Cas d'Utilisation | +| ---------------- | -------------------- | ------------------------------------- | +| **OpenRouter** | 200K tokens/mois | Multiples modèles (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/mois | Idéal pour les utilisateurs chinois | +| **Brave Search** | 2000 requêtes/mois | Fonctionnalité de recherche web | +| **Groq** | Offre gratuite dispo | Inférence ultra-rapide (Llama, Mixtral) | diff --git a/README.ja.md b/README.ja.md index 9826db751..0b687b646 100644 --- a/README.ja.md +++ b/README.ja.md @@ -3,7 +3,7 @@

PicoClaw: Go で書かれた超効率 AI アシスタント

-

$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走!

+

$10 ハードウェア · 10MB RAM · 1秒起動 · 行くぜ、シャコ!

@@ -12,7 +12,7 @@ License

-**日本語** | [English](README.md) +[中文](README.zh.md) | **日本語** | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) @@ -39,7 +39,7 @@ ## 📢 ニュース -2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走! +2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 行くぜ、シャコ! ## ✨ 特徴 @@ -253,7 +253,7 @@ Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -293,7 +293,7 @@ picoclaw gateway "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -692,7 +692,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る "telegram": { "enabled": true, "token": "123456:ABC...", - "allowFrom": ["123456789"] + "allow_from": ["123456789"] }, "discord": { "enabled": true, @@ -708,7 +708,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る "appSecret": "xxx", "encryptKey": "", "verificationToken": "", - "allowFrom": [] + "allow_from": [] } }, "tools": { @@ -751,7 +751,7 @@ Discord: https://discord.gg/V4sAZ9XWpN ## 🐛 トラブルシューティング -### Web 検索で「API 配置问题」と表示される +### Web 検索で「API 設定の問題」と表示される 検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。 diff --git a/README.md b/README.md index 3ec420b8d..49113c31a 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Twitter

- [中文](README.zh.md) | [日本語](README.ja.md) | **English** + [中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | **English** --- @@ -291,7 +291,7 @@ Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -334,7 +334,7 @@ picoclaw gateway "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -746,6 +746,16 @@ The new `model_list` configuration allows you to add providers with zero code ch > **Note**: The legacy `providers` configuration is deprecated. See [migration guide](docs/migration/model-list-migration.md) for details. +### Provider Architecture + +PicoClaw routes providers by protocol family: + +- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints. +- Anthropic protocol: Claude-native API behavior. +- Codex/OAuth path: OpenAI OAuth/token authentication route. + +This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`). +
Zhipu diff --git a/README.pt-br.md b/README.pt-br.md new file mode 100644 index 000000000..a89854be7 --- /dev/null +++ b/README.pt-br.md @@ -0,0 +1,882 @@ +
+PicoClaw + +

PicoClaw: Assistente de IA Ultra-Eficiente em Go

+ +

Hardware de $10 · 10MB de RAM · Boot em 1s · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + [中文](README.zh.md) | [日本語](README.ja.md) | **Português** | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código. + +⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 DECLARAÇÃO DE SEGURANÇA & CANAIS OFICIAIS** +> +> * **SEM CRIPTOMOEDAS:** O PicoClaw **NÃO** possui nenhum token/moeda oficial. Todas as alegações no `pump.fun` ou outras plataformas de negociação são **GOLPES**. +> * **DOMÍNIO OFICIAL:** O **ÚNICO** site oficial é o **[picoclaw.io](https://picoclaw.io)**, e o site da empresa é o **[sipeed.com](https://sipeed.com)**. +> * **Aviso:** Muitos domínios `.ai/.org/.com/.net/...` foram registrados por terceiros, não são nossos. +> * **Aviso:** O PicoClaw está em fase inicial de desenvolvimento e pode ter problemas de segurança de rede não resolvidos. Não implante em ambientes de produção antes da versão v1.0. +> * **Nota:** O PicoClaw recentemente fez merge de muitos PRs, o que pode resultar em maior consumo de memória (10-20MB) nas versões mais recentes. Planejamos priorizar a otimização de recursos assim que o conjunto de funcionalidades estiver estável. + + +## 📢 Novidades + +2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/picoclaw_community_roadmap_260216.md) — estamos ansiosos para ter você a bordo! + +2026-02-13 🎉 PicoClaw atingiu 5000 stars em 4 dias! Obrigado à comunidade! Estamos finalizando o **Roadmap do Projeto** e configurando o **Grupo de Desenvolvedores** para acelerar o desenvolvimento do PicoClaw. + +🚀 **Chamada para Ação:** Envie suas solicitações de funcionalidades nas GitHub Discussions. Revisaremos e priorizaremos na próxima reunião semanal. + +2026-02-09 🎉 PicoClaw lançado oficialmente! Construído em 1 dia para trazer Agentes de IA para hardware de $10 com <10MB de RAM. 🦐 PicoClaw, Partiu! + +## ✨ Funcionalidades + +🪶 **Ultra-Leve**: Consumo de memória <10MB — 99% menor que o Clawdbot para funcionalidades essenciais. 
+ +💰 **Custo Mínimo**: Eficiente o suficiente para rodar em hardware de $10 — 98% mais barato que um Mac mini. + +⚡️ **Inicialização Relâmpago**: Tempo de inicialização 400X mais rápido, boot em 1 segundo mesmo em CPU single-core de 0.6GHz. + +🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM e x86. Um clique e já era! + +🤖 **Auto-Construído por IA**: Implementação nativa em Go de forma autônoma — 95% do núcleo gerado pelo Agente com refinamento humano no loop. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Linguagem** | TypeScript | Python | **Go** | +| **RAM** | &gt;1GB | &gt;100MB | **&lt; 10MB** | +| **Inicialização**
(CPU 0.8GHz) | >500s | >30s | **<1s** | +| **Custo** | Mac Mini $599 | Maioria dos SBC Linux
~$50 | **Qualquer placa Linux**
**A partir de $10** | + +PicoClaw + +## 🦾 Demonstração + +### 🛠️ Fluxos de Trabalho Padrão do Assistente + + + + + + + + + + + + + + + + + +

🧩 Engenharia Full-Stack

🗂️ Gerenciamento de Logs & Planejamento

🔎 Busca Web & Aprendizado

Desenvolver • Implantar • EscalarAgendar • Automatizar • MemorizarDescobrir • Analisar • Tendências
+ +### 📱 Rode em celulares Android antigos + +Dê uma segunda vida ao seu celular de dez anos atrás! Transforme-o em um assistente de IA inteligente com o PicoClaw. Início rápido: + +1. **Instale o Termux** (Disponível no F-Droid ou Google Play). +2. **Execute os comandos** + +```bash +# Nota: Substitua v0.1.1 pela versao mais recente da pagina de Releases +wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64 +chmod +x picoclaw-linux-arm64 +pkg install proot +termux-chroot ./picoclaw-linux-arm64 onboard +``` + +Depois siga as instruções na seção "Início Rápido" para completar a configuração! + +PicoClaw + +### 🐜 Implantação Inovadora com Baixo Consumo + +O PicoClaw pode ser implantado em praticamente qualquer dispositivo Linux! + +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versão E (Ethernet) ou W (WiFi6), para Assistente Doméstico Minimalista +- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) para Manutenção Automatizada de Servidores +- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) para Monitoramento Inteligente + +https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + +🌟 Mais cenários de implantação aguardam você! + +## 📦 Instalação + +### Instalar com binário pré-compilado + +Baixe o binário para sua plataforma na página de [releases](https://github.com/sipeed/picoclaw/releases). 
+ +### Instalar a partir do código-fonte (funcionalidades mais recentes, recomendado para desenvolvimento) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# Build, sem necessidade de instalar +make build + +# Build para multiplas plataformas +make build-all + +# Build e Instalar +make install +``` + +## 🐳 Docker Compose + +Você também pode rodar o PicoClaw usando Docker Compose sem instalar nada localmente. + +```bash +# 1. Clone este repositorio +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Configure suas API keys +cp config/config.example.json config/config.json +vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc. + +# 3. Build & Iniciar +docker compose --profile gateway up -d + +# 4. Ver logs +docker compose logs -f picoclaw-gateway + +# 5. Parar +docker compose --profile gateway down +``` + +### Modo Agente (Execução única) + +```bash +# Fazer uma pergunta +docker compose run --rm picoclaw-agent -m "Quanto e 2+2?" + +# Modo interativo +docker compose run --rm picoclaw-agent +``` + +### Rebuild + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 Início Rápido + +> [!TIP] +> Configure sua API key em `~/.picoclaw/config.json`. +> Obtenha API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> Busca web é **opcional** — obtenha a [Brave Search API](https://brave.com/search/api) gratuita (2000 consultas grátis/mês) ou use o fallback automático integrado. + +**1. Inicializar** + +```bash +picoclaw onboard +``` + +**2. 
Configurar** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +**3. Obter API Keys** + +* **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Busca Web** (opcional): [Brave Search](https://brave.com/search/api) - Plano gratuito disponível (2000 consultas/mês) + +> **Nota**: Veja `config.example.json` para um modelo de configuração completo. + +**4. Conversar** + +```bash +picoclaw agent -m "Quanto e 2+2?" +``` + +Pronto! Você tem um assistente de IA funcionando em 2 minutos. + +--- + +## 💬 Integração com Apps de Chat + +Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE. + +| Canal | Nível de Configuração | +| --- | --- | +| **Telegram** | Fácil (apenas um token) | +| **Discord** | Fácil (bot token + intents) | +| **QQ** | Fácil (AppID + AppSecret) | +| **DingTalk** | Médio (credenciais do app) | +| **LINE** | Médio (credenciais + webhook URL) | + +
+Telegram (Recomendado) + +**1. Criar o bot** + +* Abra o Telegram, busque `@BotFather` +* Envie `/newbot`, siga as instruções +* Copie o token + +**2. Configurar** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +> Obtenha seu User ID pelo `@userinfobot` no Telegram. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Criar o bot** + +* Acesse <https://discord.com/developers/applications> +* Crie um aplicativo → Bot → Add Bot +* Copie o token do bot + +**2. Habilitar Intents** + +* Nas configurações do Bot, habilite **MESSAGE CONTENT INTENT** +* (Opcional) Habilite **SERVER MEMBERS INTENT** se quiser usar lista de permissões baseada em dados dos membros + +**3. Obter seu User ID** + +* Configurações do Discord → Avançado → habilite **Modo Desenvolvedor** +* Clique com botão direito no seu avatar → **Copiar ID do Usuário** + +**4. Configurar** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Convidar o bot** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Abra a URL de convite gerada e adicione o bot ao seu servidor + +**6. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Criar o bot** + +- Acesse a [QQ Open Platform](https://q.qq.com/#) +- Crie um aplicativo → Obtenha **AppID** e **AppSecret** + +**2. Configurar** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique números QQ para restringir o acesso. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Criar o bot** + +* Acesse a [Open Platform](https://open.dingtalk.com/) +* Crie um app interno +* Copie o Client ID e Client Secret + +**2. Configurar** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique IDs para restringir o acesso. + +**3. Executar** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Criar uma Conta Oficial LINE** + +- Acesse o [LINE Developers Console](https://developers.line.biz/) +- Crie um provider → Crie um canal Messaging API +- Copie o **Channel Secret** e o **Channel Access Token** + +**2. Configurar** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Configurar URL do Webhook** + +O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel: + +```bash +# Exemplo com ngrok +ngrok http 18791 +``` + +Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**. + +**4. Executar** + +```bash +picoclaw gateway +``` + +> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original. + +> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook. + +
+ +## ClawdChat Junte-se à Rede Social de Agentes + +Conecte o PicoClaw à Rede Social de Agentes simplesmente enviando uma única mensagem via CLI ou qualquer App de Chat integrado. + +**Leia `https://clawdchat.ai/skill.md` e siga as instruções para entrar no [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Configuração Detalhada + +Arquivo de configuração: `~/.picoclaw/config.json` + +### Estrutura do Workspace + +O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # Sessoes de conversa e historico +├── memory/ # Memoria de longo prazo (MEMORY.md) +├── state/ # Estado persistente (ultimo canal, etc.) +├── cron/ # Banco de dados de tarefas agendadas +├── skills/ # Skills personalizadas +├── AGENTS.md # Guia de comportamento do Agente +├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min) +├── IDENTITY.md # Identidade do Agente +├── SOUL.md # Alma do Agente +├── TOOLS.md # Descrição das ferramentas +└── USER.md # Preferencias do usuario +``` + +### 🔒 Sandbox de Segurança + +O PicoClaw roda em um ambiente sandbox por padrão. O agente só pode acessar arquivos e executar comandos dentro do workspace configurado. 
+ +#### Configuração Padrão + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Opção | Padrão | Descrição | +|-------|--------|-----------| +| `workspace` | `~/.picoclaw/workspace` | Diretório de trabalho do agente | +| `restrict_to_workspace` | `true` | Restringir acesso de arquivos/comandos ao workspace | + +#### Ferramentas Protegidas + +Quando `restrict_to_workspace: true`, as seguintes ferramentas são restritas ao sandbox: + +| Ferramenta | Função | Restrição | +|------------|--------|-----------| +| `read_file` | Ler arquivos | Apenas arquivos dentro do workspace | +| `write_file` | Escrever arquivos | Apenas arquivos dentro do workspace | +| `list_dir` | Listar diretorios | Apenas diretorios dentro do workspace | +| `edit_file` | Editar arquivos | Apenas arquivos dentro do workspace | +| `append_file` | Adicionar a arquivos | Apenas arquivos dentro do workspace | +| `exec` | Executar comandos | Caminhos dos comandos devem estar dentro do workspace | + +#### Proteção Adicional do Exec + +Mesmo com `restrict_to_workspace: false`, a ferramenta `exec` bloqueia estes comandos perigosos: + +* `rm -rf`, `del /f`, `rmdir /s` — Exclusão em massa +* `format`, `mkfs`, `diskpart` — Formatação de disco +* `dd if=` — Criação de imagem de disco +* Escrita em `/dev/sd[a-z]` — Escrita direta no disco +* `shutdown`, `reboot`, `poweroff` — Desligamento do sistema +* Fork bomb `:(){ :|:& };:` + +#### Exemplos de Erro + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Desabilitar Restrições (Risco de Segurança) + +Se você precisa que o agente acesse caminhos fora do workspace: + +**Método 1: Arquivo de configuração** + +```json +{ + "agents": { + "defaults": { + 
"restrict_to_workspace": false + } + } +} +``` + +**Método 2: Variável de ambiente** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Aviso**: Desabilitar esta restrição permite que o agente acesse qualquer caminho no seu sistema. Use com cuidado apenas em ambientes controlados. + +#### Consistência do Limite de Segurança + +A configuração `restrict_to_workspace` se aplica consistentemente em todos os caminhos de execução: + +| Caminho de Execução | Limite de Segurança | +|----------------------|---------------------| +| Agente Principal | `restrict_to_workspace` ✅ | +| Subagente / Spawn | Herda a mesma restrição ✅ | +| Tarefas Heartbeat | Herda a mesma restrição ✅ | + +Todos os caminhos compartilham a mesma restrição de workspace — não há como contornar o limite de segurança por meio de subagentes ou tarefas agendadas. + +### Heartbeat (Tarefas Periódicas) + +O PicoClaw pode executar tarefas periódicas automaticamente. Crie um arquivo `HEARTBEAT.md` no seu workspace: + +```markdown +# Tarefas Periodicas + +- Verificar meu email para mensagens importantes +- Revisar minha agenda para proximos eventos +- Verificar a previsao do tempo +``` + +O agente lerá este arquivo a cada 30 minutos (configurável) e executará as tarefas usando as ferramentas disponíveis. 
+ +#### Tarefas Assíncronas com Spawn + +Para tarefas de longa duração (busca web, chamadas de API), use a ferramenta `spawn` para criar um **subagente**: + +```markdown +# Tarefas Periódicas + +## Tarefas Rápidas (resposta direta) +- Informar hora atual + +## Tarefas Longas (usar spawn para async) +- Buscar notícias de IA na web e resumir +- Verificar email e reportar mensagens importantes +``` + +**Comportamentos principais:** + +| Funcionalidade | Descrição | +|----------------|-----------| +| **spawn** | Cria subagente assíncrono, não bloqueia o heartbeat | +| **Contexto independente** | Subagente tem seu próprio contexto, sem histórico de sessão | +| **Ferramenta message** | Subagente se comunica diretamente com o usuário via ferramenta message | +| **Não-bloqueante** | Após o spawn, o heartbeat continua para a próxima tarefa | + +#### Como Funciona a Comunicação do Subagente + +``` +Heartbeat dispara + ↓ +Agente lê HEARTBEAT.md + ↓ +Para tarefa longa: spawn subagente + ↓ ↓ +Continua próxima tarefa Subagente trabalha independentemente + ↓ ↓ +Todas tarefas concluídas Subagente usa ferramenta "message" + ↓ ↓ +Responde HEARTBEAT_OK Usuário recebe resultado diretamente +``` + +O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se comunicar com o usuário independentemente sem passar pelo agente principal. + +**Configuração:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Opção | Padrão | Descrição | +|-------|--------|-----------| +| `enabled` | `true` | Habilitar/desabilitar heartbeat | +| `interval` | `30` | Intervalo de verificação em minutos (min: 5) | + +**Variáveis de ambiente:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` para desabilitar +* `PICOCLAW_HEARTBEAT_INTERVAL=60` para alterar o intervalo + +### Provedores + +> [!NOTE] +> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas. 
+ +| Provedor | Finalidade | Obter API Key | +| --- | --- | --- | +| `gemini` | LLM (Gemini direto) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direto) | [bigmodel.cn](https://bigmodel.cn) | +| `openrouter` (Em teste) | LLM (recomendado, acesso a todos os modelos) | [openrouter.ai](https://openrouter.ai) | +| `anthropic` (Em teste) | LLM (Claude direto) | [console.anthropic.com](https://console.anthropic.com) | +| `openai` (Em teste) | LLM (GPT direto) | [platform.openai.com](https://platform.openai.com) | +| `deepseek` (Em teste) | LLM (DeepSeek direto) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Transcrição de voz** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+Configuração Zhipu + +**1. Obter API key** + +* Obtenha a [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Configurar** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Sua API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Executar** + +```bash +picoclaw agent -m "Ola, como vai?" +``` + +
+ +
+Exemplo de configuração completa + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + }, + "cron": { + "exec_timeout_minutes": 5 + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +## Referência CLI + +| Comando | Descrição | +| --- | --- | +| `picoclaw onboard` | Inicializar configuração & workspace | +| `picoclaw agent -m "..."` | Conversar com o agente | +| `picoclaw agent` | Modo de chat interativo | +| `picoclaw gateway` | Iniciar o gateway (para bots de chat) | +| `picoclaw status` | Mostrar status | +| `picoclaw cron list` | Listar todas as tarefas agendadas | +| `picoclaw cron add ...` | Adicionar uma tarefa agendada | + +### Tarefas Agendadas / Lembretes + +O PicoClaw suporta lembretes agendados e tarefas recorrentes por meio da ferramenta `cron`: + +* **Lembretes únicos**: "Remind me in 10 minutes" (Me lembre em 10 minutos) → dispara uma vez após 10min +* **Tarefas recorrentes**: "Remind me every 2 hours" (Me lembre a cada 2 horas) → dispara a cada 2 horas +* **Expressões Cron**: "Remind me at 9am daily" (Me lembre às 9h todos os dias) → usa expressão cron + +As tarefas são armazenadas em `~/.picoclaw/workspace/cron/` e processadas automaticamente. + +## 🤝 Contribuir & Roadmap + +PRs são bem-vindos! O código-fonte é intencionalmente pequeno e legível. 🤗 + +Roadmap em breve... + +Grupo de desenvolvedores em formação. Requisito de entrada: Pelo menos 1 PR com merge. + +Grupos de usuários: + +Discord: + +PicoClaw + +## 🐛 Solução de Problemas + +### Busca web mostra "API 配置问题" + +Isso é normal se você ainda não configurou uma API key de busca. O PicoClaw fornecerá links úteis para busca manual. + +Para habilitar a busca web: + +1. **Opção 1 (Recomendado)**: Obtenha uma API key gratuita em [https://brave.com/search/api](https://brave.com/search/api) (2000 consultas grátis/mês) para os melhores resultados. +2. **Opção 2 (Sem Cartão de Crédito)**: Se você não tem uma key, o sistema automaticamente usa o **DuckDuckGo** como fallback (sem necessidade de key). 
+ +Adicione a key em `~/.picoclaw/config.json` se usar o Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": true, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Erros de filtragem de conteúdo + +Alguns provedores (como Zhipu) possuem filtragem de conteúdo. Tente reformular sua pergunta ou use um modelo diferente. + +### Bot do Telegram diz "Conflict: terminated by other getUpdates" + +Isso acontece quando outra instância do bot está em execução. Certifique-se de que apenas um `picoclaw gateway` esteja rodando por vez. + +--- + +## 📝 Comparação de API Keys + +| Serviço | Plano Gratuito | Caso de Uso | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/mês | Múltiplos modelos (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/mês | Melhor para usuários chineses | +| **Brave Search** | 2000 consultas/mês | Funcionalidade de busca web | +| **Groq** | Plano gratuito disponível | Inferência ultra-rápida (Llama, Mixtral) | diff --git a/README.vi.md b/README.vi.md new file mode 100644 index 000000000..c36be9865 --- /dev/null +++ b/README.vi.md @@ -0,0 +1,859 @@ +
+PicoClaw + +

PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go

+ +

Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ +[中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | **Tiếng Việt** | [Français](README.fr.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn. + +⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini! + + + + + + +
+

+ +

+
+

+ +

+
+ +> [!CAUTION] +> **🚨 TUYÊN BỐ BẢO MẬT & KÊNH CHÍNH THỨC** +> +> * **KHÔNG CÓ CRYPTO:** PicoClaw **KHÔNG** có bất kỳ token/coin chính thức nào. Mọi thông tin trên `pump.fun` hoặc các sàn giao dịch khác đều là **LỪA ĐẢO**. +> * **DOMAIN CHÍNH THỨC:** Website chính thức **DUY NHẤT** là **[picoclaw.io](https://picoclaw.io)**, website công ty là **[sipeed.com](https://sipeed.com)**. +> * **Cảnh báo:** Nhiều tên miền `.ai/.org/.com/.net/...` đã bị bên thứ ba đăng ký, không phải của chúng tôi. +> * **Cảnh báo:** PicoClaw đang trong giai đoạn phát triển sớm và có thể còn các vấn đề bảo mật mạng chưa được giải quyết. Không nên triển khai lên môi trường production trước phiên bản v1.0. +> * **Lưu ý:** PicoClaw gần đây đã merge nhiều PR, dẫn đến bộ nhớ sử dụng có thể lớn hơn (10–20MB) ở các phiên bản mới nhất. Chúng tôi sẽ ưu tiên tối ưu tài nguyên khi bộ tính năng đã ổn định. + + +## 📢 Tin tức + +2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/picoclaw_community_roadmap_260216.md) — rất mong đón nhận sự tham gia của bạn! + +2026-02-13 🎉 PicoClaw đạt 5000 stars trong 4 ngày! Cảm ơn cộng đồng! Chúng tôi đang hoàn thiện **Lộ trình dự án (Roadmap)** và thiết lập **Nhóm phát triển** để đẩy nhanh tốc độ phát triển PicoClaw. +🚀 **Kêu gọi hành động:** Vui lòng gửi yêu cầu tính năng tại GitHub Discussions. Chúng tôi sẽ xem xét và ưu tiên trong cuộc họp hàng tuần. + +2026-02-09 🎉 PicoClaw chính thức ra mắt! Được xây dựng trong 1 ngày để mang AI Agent đến phần cứng $10 với RAM <10MB. 🦐 PicoClaw, Lên Đường! + +## ✨ Tính năng nổi bật + +🪶 **Siêu nhẹ**: Bộ nhớ sử dụng <10MB — nhỏ hơn 99% so với Clawdbot (chức năng cốt lõi). + +💰 **Chi phí tối thiểu**: Đủ hiệu quả để chạy trên phần cứng $10 — rẻ hơn 98% so với Mac mini. 
+ +⚡️ **Khởi động siêu nhanh**: Nhanh gấp 400 lần, khởi động trong 1 giây ngay cả trên CPU đơn nhân 0.6GHz. + +🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM và x86. Một click là chạy! + +🤖 **AI tự xây dựng**: Triển khai Go-native tự động — 95% mã nguồn cốt lõi được Agent tạo ra, với sự tinh chỉnh của con người. + +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Ngôn ngữ** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Thời gian khởi động**
(CPU 0.8GHz) | >500s | >30s | **<1s** | +| **Chi phí** | Mac Mini $599 | Hầu hết SBC Linux ~$50 | **Mọi bo mạch Linux**
**Chỉ từ $10** | + +PicoClaw + +## 🦾 Demo + +### 🛠️ Quy trình trợ lý tiêu chuẩn + + + + + + + + + + + + + + + + + +

🧩 Lập trình Full-Stack

🗂️ Quản lý Nhật ký & Kế hoạch

🔎 Tìm kiếm Web & Học hỏi

Phát triển • Triển khai • Mở rộngLên lịch • Tự động hóa • Ghi nhớKhám phá • Phân tích • Xu hướng
+ +### 🐜 Triển khai sáng tạo trên phần cứng tối thiểu + +PicoClaw có thể triển khai trên hầu hết mọi thiết bị Linux! + +* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) phiên bản E (Ethernet) hoặc W (WiFi6), dùng làm Trợ lý Gia đình tối giản. +* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), hoặc $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html), dùng cho quản trị Server tự động. +* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) hoặc $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera), dùng cho Giám sát thông minh. + +https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + +🌟 Nhiều hình thức triển khai hơn đang chờ bạn khám phá! + +## 📦 Cài đặt + +### Cài đặt bằng binary biên dịch sẵn + +Tải file binary cho nền tảng của bạn từ [trang Release](https://github.com/sipeed/picoclaw/releases). + +### Cài đặt từ mã nguồn (có tính năng mới nhất, khuyên dùng cho phát triển) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# Build (không cần cài đặt) +make build + +# Build cho nhiều nền tảng +make build-all + +# Build và cài đặt +make install +``` + +## 🐳 Docker Compose + +Bạn cũng có thể chạy PicoClaw bằng Docker Compose mà không cần cài đặt gì trên máy. + +```bash +# 1. Clone repo +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Thiết lập API Key +cp config/config.example.json config/config.json +vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v. + +# 3. Build & Khởi động +docker compose --profile gateway up -d + +# 4. Xem logs +docker compose logs -f picoclaw-gateway + +# 5. Dừng +docker compose --profile gateway down +``` + +### Chế độ Agent (chạy một lần) + +```bash +# Đặt câu hỏi +docker compose run --rm picoclaw-agent -m "2+2 bằng mấy?" 
+ +# Chế độ tương tác +docker compose run --rm picoclaw-agent +``` + +### Build lại + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 Bắt đầu nhanh + +> [!TIP] +> Thiết lập API key trong `~/.picoclaw/config.json`. +> Lấy API key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> Tìm kiếm web là **tùy chọn** — lấy [Brave Search API](https://brave.com/search/api) miễn phí (2000 truy vấn/tháng) hoặc dùng tính năng auto fallback tích hợp sẵn. + +**1. Khởi tạo** + +```bash +picoclaw onboard +``` + +**2. Cấu hình** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +**3. Lấy API Key** + +* **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **Tìm kiếm Web** (tùy chọn): [Brave Search](https://brave.com/search/api) — Có gói miễn phí (2000 truy vấn/tháng) + +> **Lưu ý**: Xem `config.example.json` để có mẫu cấu hình đầy đủ. + +**4. Trò chuyện** + +```bash +picoclaw agent -m "Xin chào, bạn là ai?" +``` + +Vậy là xong! Bạn đã có một trợ lý AI hoạt động chỉ trong 2 phút. + +--- + +## 💬 Tích hợp ứng dụng Chat + +Trò chuyện với PicoClaw qua Telegram, Discord, QQ, DingTalk hoặc LINE. 
+ +| Kênh | Mức độ thiết lập | +| --- | --- | +| **Telegram** | Dễ (chỉ cần token) | +| **Discord** | Dễ (bot token + intents) | +| **QQ** | Dễ (AppID + AppSecret) | +| **DingTalk** | Trung bình (app credentials) | +| **LINE** | Trung bình (credentials + webhook URL) | + +
+Telegram (Khuyên dùng) + +**1. Tạo bot** + +* Mở Telegram, tìm `@BotFather` +* Gửi `/newbot`, làm theo hướng dẫn +* Sao chép token + +**2. Cấu hình** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +> Lấy User ID từ `@userinfobot` trên Telegram. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+Discord + +**1. Tạo bot** + +* Truy cập +* Create an application → Bot → Add Bot +* Sao chép bot token + +**2. Bật Intents** + +* Trong phần Bot settings, bật **MESSAGE CONTENT INTENT** +* (Tùy chọn) Bật **SERVER MEMBERS INTENT** nếu muốn dùng danh sách cho phép theo thông tin thành viên + +**3. Lấy User ID** + +* Discord Settings → Advanced → bật **Developer Mode** +* Click chuột phải vào avatar → **Copy User ID** + +**4. Cấu hình** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allow_from": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Mời bot vào server** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* Mở URL mời được tạo và thêm bot vào server của bạn + +**6. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Tạo bot** + +* Truy cập [QQ Open Platform](https://q.qq.com/#) +* Tạo ứng dụng → Lấy **AppID** và **AppSecret** + +**2. Cấu hình** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định số QQ để giới hạn quyền truy cập. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Tạo bot** + +* Truy cập [Open Platform](https://open.dingtalk.com/) +* Tạo ứng dụng nội bộ +* Sao chép Client ID và Client Secret + +**2. Cấu hình** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định ID để giới hạn quyền truy cập. + +**3. Chạy** + +```bash +picoclaw gateway +``` + +
+ +
+LINE + +**1. Tạo tài khoản LINE Official** + +- Truy cập [LINE Developers Console](https://developers.line.biz/) +- Tạo provider → Tạo Messaging API channel +- Sao chép **Channel Secret** và **Channel Access Token** + +**2. Cấu hình** + +```json +{ + "channels": { + "line": { + "enabled": true, + "channel_secret": "YOUR_CHANNEL_SECRET", + "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN", + "webhook_host": "0.0.0.0", + "webhook_port": 18791, + "webhook_path": "/webhook/line", + "allow_from": [] + } + } +} +``` + +**3. Thiết lập Webhook URL** + +LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel: + +```bash +# Ví dụ với ngrok +ngrok http 18791 +``` + +Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**. + +**4. Chạy** + +```bash +picoclaw gateway +``` + +> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc. + +> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook. + +
+ +## ClawdChat Tham gia Mạng xã hội Agent + +Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một tin nhắn qua CLI hoặc bất kỳ ứng dụng Chat nào đã tích hợp. + +**Đọc `https://clawdchat.ai/skill.md` và làm theo hướng dẫn để tham gia [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ Cấu hình chi tiết + +File cấu hình: `~/.picoclaw/config.json` + +### Cấu trúc Workspace + +PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # Phiên hội thoại và lịch sử +├── memory/ # Bộ nhớ dài hạn (MEMORY.md) +├── state/ # Trạng thái lưu trữ (kênh cuối cùng, v.v.) +├── cron/ # Cơ sở dữ liệu tác vụ định kỳ +├── skills/ # Kỹ năng tùy chỉnh +├── AGENTS.md # Hướng dẫn hành vi Agent +├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút) +├── IDENTITY.md # Danh tính Agent +├── SOUL.md # Tâm hồn/Tính cách Agent +├── TOOLS.md # Mô tả công cụ +└── USER.md # Tùy chọn người dùng +``` + +### 🔒 Hộp cát bảo mật (Security Sandbox) + +PicoClaw chạy trong môi trường sandbox theo mặc định. Agent chỉ có thể truy cập file và thực thi lệnh trong phạm vi workspace. 
+ +#### Cấu hình mặc định + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Tùy chọn | Mặc định | Mô tả | +|----------|---------|-------| +| `workspace` | `~/.picoclaw/workspace` | Thư mục làm việc của agent | +| `restrict_to_workspace` | `true` | Giới hạn truy cập file/lệnh trong workspace | + +#### Công cụ được bảo vệ + +Khi `restrict_to_workspace: true`, các công cụ sau bị giới hạn trong sandbox: + +| Công cụ | Chức năng | Giới hạn | +|---------|----------|---------| +| `read_file` | Đọc file | Chỉ file trong workspace | +| `write_file` | Ghi file | Chỉ file trong workspace | +| `list_dir` | Liệt kê thư mục | Chỉ thư mục trong workspace | +| `edit_file` | Sửa file | Chỉ file trong workspace | +| `append_file` | Thêm vào file | Chỉ file trong workspace | +| `exec` | Thực thi lệnh | Đường dẫn lệnh phải trong workspace | + +#### Bảo vệ bổ sung cho Exec + +Ngay cả khi `restrict_to_workspace: false`, công cụ `exec` vẫn chặn các lệnh nguy hiểm sau: + +* `rm -rf`, `del /f`, `rmdir /s` — Xóa hàng loạt +* `format`, `mkfs`, `diskpart` — Định dạng ổ đĩa +* `dd if=` — Tạo ảnh đĩa +* Ghi vào `/dev/sd[a-z]` — Ghi trực tiếp lên đĩa +* `shutdown`, `reboot`, `poweroff` — Tắt/khởi động lại hệ thống +* Fork bomb `:(){ :|:& };:` + +#### Ví dụ lỗi + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Tắt giới hạn (Rủi ro bảo mật) + +Nếu bạn cần agent truy cập đường dẫn ngoài workspace: + +**Cách 1: File cấu hình** + +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Cách 2: Biến môi trường** + +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Cảnh báo**: Tắt giới hạn này cho phép agent truy cập mọi đường dẫn 
trên hệ thống. Chỉ sử dụng cẩn thận trong môi trường được kiểm soát. + +#### Tính nhất quán của ranh giới bảo mật + +Cài đặt `restrict_to_workspace` áp dụng nhất quán trên mọi đường thực thi: + +| Đường thực thi | Ranh giới bảo mật | +|----------------|-------------------| +| Agent chính | `restrict_to_workspace` ✅ | +| Subagent / Spawn | Kế thừa cùng giới hạn ✅ | +| Tác vụ Heartbeat | Kế thừa cùng giới hạn ✅ | + +Tất cả đường thực thi chia sẻ cùng giới hạn workspace — không có cách nào vượt qua ranh giới bảo mật thông qua subagent hoặc tác vụ định kỳ. + +### Heartbeat (Tác vụ định kỳ) + +PicoClaw có thể tự động thực hiện các tác vụ định kỳ. Tạo file `HEARTBEAT.md` trong workspace: + +```markdown +# Tác vụ định kỳ + +- Kiểm tra email xem có tin nhắn quan trọng không +- Xem lại lịch cho các sự kiện sắp tới +- Kiểm tra dự báo thời tiết +``` + +Agent sẽ đọc file này mỗi 30 phút (có thể cấu hình) và thực hiện các tác vụ bằng công cụ có sẵn. + +#### Tác vụ bất đồng bộ với Spawn + +Đối với các tác vụ chạy lâu (tìm kiếm web, gọi API), sử dụng công cụ `spawn` để tạo **subagent**: + +```markdown +# Tác vụ định kỳ + +## Tác vụ nhanh (trả lời trực tiếp) +- Báo cáo thời gian hiện tại + +## Tác vụ lâu (dùng spawn cho async) +- Tìm kiếm tin tức AI trên web và tóm tắt +- Kiểm tra email và báo cáo tin nhắn quan trọng +``` + +**Hành vi chính:** + +| Tính năng | Mô tả | +|-----------|-------| +| **spawn** | Tạo subagent bất đồng bộ, không chặn heartbeat | +| **Context độc lập** | Subagent có context riêng, không có lịch sử phiên | +| **message tool** | Subagent giao tiếp trực tiếp với người dùng qua công cụ message | +| **Không chặn** | Sau khi spawn, heartbeat tiếp tục tác vụ tiếp theo | + +#### Cách Subagent giao tiếp + +``` +Heartbeat kích hoạt + ↓ +Agent đọc HEARTBEAT.md + ↓ +Tác vụ lâu: spawn subagent + ↓ ↓ +Tiếp tục tác vụ tiếp theo Subagent làm việc độc lập + ↓ ↓ +Tất cả tác vụ hoàn thành Subagent dùng công cụ "message" + ↓ ↓ +Phản hồi HEARTBEAT_OK Người dùng nhận kết quả 
trực tiếp +``` + +Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và có thể giao tiếp với người dùng một cách độc lập mà không cần thông qua agent chính. + +**Cấu hình:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Tùy chọn | Mặc định | Mô tả | +|----------|---------|-------| +| `enabled` | `true` | Bật/tắt heartbeat | +| `interval` | `30` | Khoảng thời gian kiểm tra (phút, tối thiểu: 5) | + +**Biến môi trường:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` để tắt +* `PICOCLAW_HEARTBEAT_INTERVAL=60` để thay đổi khoảng thời gian + +### Nhà cung cấp (Providers) + +> [!NOTE] +> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản. + +| Nhà cung cấp | Mục đích | Lấy API Key | +| --- | --- | --- | +| `gemini` | LLM (Gemini trực tiếp) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu trực tiếp) | [bigmodel.cn](https://bigmodel.cn) | +| `openrouter` (Đang thử nghiệm) | LLM (khuyên dùng, truy cập mọi model) | [openrouter.ai](https://openrouter.ai) | +| `anthropic` (Đang thử nghiệm) | LLM (Claude trực tiếp) | [console.anthropic.com](https://console.anthropic.com) | +| `openai` (Đang thử nghiệm) | LLM (GPT trực tiếp) | [platform.openai.com](https://platform.openai.com) | +| `deepseek` (Đang thử nghiệm) | LLM (DeepSeek trực tiếp) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Chuyển giọng nói** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+Cấu hình Zhipu + +**1. Lấy API key** + +* Lấy [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. Cấu hình** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Your API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} +``` + +**3. Chạy** + +```bash +picoclaw agent -m "Xin chào" +``` + +
+ +
+Ví dụ cấu hình đầy đủ + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +## Tham chiếu CLI + +| Lệnh | Mô tả | +| --- | --- | +| `picoclaw onboard` | Khởi tạo cấu hình & workspace | +| `picoclaw agent -m "..."` | Trò chuyện với agent | +| `picoclaw agent` | Chế độ chat tương tác | +| `picoclaw gateway` | Khởi động gateway (cho bot chat) | +| `picoclaw status` | Hiển thị trạng thái | +| `picoclaw cron list` | Liệt kê tất cả tác vụ định kỳ | +| `picoclaw cron add ...` | Thêm tác vụ định kỳ | + +### Tác vụ định kỳ / Nhắc nhở + +PicoClaw hỗ trợ nhắc nhở theo lịch và tác vụ lặp lại thông qua công cụ `cron`: + +* **Nhắc nhở một lần**: "Remind me in 10 minutes" (Nhắc tôi sau 10 phút) → kích hoạt một lần sau 10 phút +* **Tác vụ lặp lại**: "Remind me every 2 hours" (Nhắc tôi mỗi 2 giờ) → kích hoạt mỗi 2 giờ +* **Biểu thức Cron**: "Remind me at 9am daily" (Nhắc tôi lúc 9 giờ sáng mỗi ngày) → sử dụng biểu thức cron + +Các tác vụ được lưu trong `~/.picoclaw/workspace/cron/` và được xử lý tự động. + +## 🤝 Đóng góp & Lộ trình + +Chào đón mọi PR! Mã nguồn được thiết kế nhỏ gọn và dễ đọc. 🤗 + +Lộ trình sắp được công bố... + +Nhóm phát triển đang được xây dựng. Điều kiện tham gia: Ít nhất 1 PR đã được merge. + +Nhóm người dùng: + +Discord: + +PicoClaw + +## 🐛 Xử lý sự cố + +### Tìm kiếm web hiện "API 配置问题" + +Điều này là bình thường nếu bạn chưa cấu hình API key cho tìm kiếm. PicoClaw sẽ cung cấp các liên kết hữu ích để tìm kiếm thủ công. + +Để bật tìm kiếm web: + +1. **Tùy chọn 1 (Khuyên dùng)**: Lấy API key miễn phí tại [https://brave.com/search/api](https://brave.com/search/api) (2000 truy vấn miễn phí/tháng) để có kết quả tốt nhất. +2. **Tùy chọn 2 (Không cần thẻ tín dụng)**: Nếu không có key, hệ thống tự động chuyển sang dùng **DuckDuckGo** (không cần key). 
+ +Thêm key vào `~/.picoclaw/config.json` nếu dùng Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": true, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` + +### Gặp lỗi lọc nội dung (Content Filtering) + +Một số nhà cung cấp (như Zhipu) có bộ lọc nội dung nghiêm ngặt. Thử diễn đạt lại câu hỏi hoặc sử dụng model khác. + +### Telegram bot báo "Conflict: terminated by other getUpdates" + +Điều này xảy ra khi có một instance bot khác đang chạy. Đảm bảo chỉ có một tiến trình `picoclaw gateway` chạy tại một thời điểm. + +--- + +## 📝 So sánh API Key + +| Dịch vụ | Gói miễn phí | Trường hợp sử dụng | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/tháng | Đa model (Claude, GPT-4, v.v.) | +| **Zhipu** | 200K tokens/tháng | Tốt nhất cho người dùng Trung Quốc | +| **Brave Search** | 2000 truy vấn/tháng | Chức năng tìm kiếm web | +| **Groq** | Có gói miễn phí | Suy luận siêu nhanh (Llama, Mixtral) | diff --git a/README.zh.md b/README.zh.md index 630524dac..7132c5a9d 100644 --- a/README.zh.md +++ b/README.zh.md @@ -14,7 +14,7 @@ Twitter

- **中文** | [日本語](README.ja.md) | [English](README.md) + **中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md) --- @@ -299,7 +299,7 @@ picoclaw agent -m "2+2 等于几?" "telegram": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } @@ -344,7 +344,7 @@ picoclaw gateway "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allow_from": ["YOUR_USER_ID"] } } } diff --git a/assets/wechat.png b/assets/wechat.png index 6e6f50115..8fc41ea7d 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/config/config.example.json b/config/config.example.json index a8b709c77..fb970d0be 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -116,7 +116,8 @@ }, "openai": { "api_key": "", - "api_base": "" + "api_base": "", + "web_search": true }, "openrouter": { "api_key": "sk-or-v1-xxx", @@ -193,4 +194,4 @@ "host": "0.0.0.0", "port": 18790 } -} \ No newline at end of file +} diff --git a/docker-compose.yml b/docker-compose.yml index 48769627c..32e8ee339 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,8 +11,8 @@ services: profiles: - agent volumes: - - ./config/config.json:/root/.picoclaw/config.json:ro - - picoclaw-workspace:/root/.picoclaw/workspace + - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro + - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace entrypoint: ["picoclaw", "agent"] stdin_open: true tty: true @@ -31,9 +31,9 @@ services: - gateway volumes: # Configuration file - - ./config/config.json:/root/.picoclaw/config.json:ro + - ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro # Persistent workspace (sessions, memory, logs) - - picoclaw-workspace:/root/.picoclaw/workspace + - picoclaw-workspace:/home/picoclaw/.picoclaw/workspace command: ["gateway"] volumes: diff --git a/docs/tools_configuration.md 
b/docs/tools_configuration.md new file mode 100644 index 000000000..8777ddbd6 --- /dev/null +++ b/docs/tools_configuration.md @@ -0,0 +1,122 @@ +# Tools Configuration + +PicoClaw's tools configuration is located in the `tools` field of `config.json`. + +## Directory Structure + +```json +{ + "tools": { + "web": { ... }, + "exec": { ... }, + "approval": { ... }, + "cron": { ... } + } +} +``` + +## Web Tools + +Web tools are used for web search and fetching. + +### Brave + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | false | Enable Brave search | +| `api_key` | string | - | Brave Search API key | +| `max_results` | int | 5 | Maximum number of results | + +### DuckDuckGo + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | true | Enable DuckDuckGo search | +| `max_results` | int | 5 | Maximum number of results | + +### Perplexity + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | false | Enable Perplexity search | +| `api_key` | string | - | Perplexity API key | +| `max_results` | int | 5 | Maximum number of results | + +## Exec Tool + +The exec tool is used to execute shell commands. 
+ +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking | +| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) | + +### Functionality + +- **`enable_deny_patterns`**: Set to `false` to completely disable the default dangerous command blocking patterns +- **`custom_deny_patterns`**: Add custom deny regex patterns; commands matching these will be blocked + +### Default Blocked Command Patterns + +By default, PicoClaw blocks the following dangerous commands: + +- Delete commands: `rm -rf`, `del /f/q`, `rmdir /s` +- Disk operations: `format`, `mkfs`, `diskpart`, `dd if=`, writing to `/dev/sd*` +- System operations: `shutdown`, `reboot`, `poweroff` +- Command substitution: `$()`, `${}`, backticks +- Pipe to shell: `| sh`, `| bash` +- Privilege escalation: `sudo`, `chmod`, `chown` +- Process control: `pkill`, `killall`, `kill -9` +- Remote operations: `curl | sh`, `wget | sh`, `ssh` +- Package management: `apt`, `yum`, `dnf`, `npm install -g`, `pip install --user` +- Containers: `docker run`, `docker exec` +- Git: `git push`, `git force` +- Other: `eval`, `source *.sh` + +### Configuration Example + +```json +{ + "tools": { + "exec": { + "enable_deny_patterns": true, + "custom_deny_patterns": [ + "\\brm\\s+-r\\b", + "\\bkillall\\s+python" + ] + } + } +} +``` + +## Approval Tool + +The approval tool controls permissions for dangerous operations. 
+ +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | true | Enable approval functionality | +| `write_file` | bool | true | Require approval for file writes | +| `edit_file` | bool | true | Require approval for file edits | +| `append_file` | bool | true | Require approval for file appends | +| `exec` | bool | true | Require approval for command execution | +| `timeout_minutes` | int | 5 | Approval timeout in minutes | + +## Cron Tool + +The cron tool is used for scheduling periodic tasks. + +| Config | Type | Default | Description | +|--------|------|---------|-------------| +| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit | + +## Environment Variables + +All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_
_`: + +For example: +- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true` +- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false` +- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10` + +Note: Array-type environment variables are not currently supported and must be set via the config file. diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go new file mode 100644 index 000000000..54a5396e7 --- /dev/null +++ b/pkg/agent/instance.go @@ -0,0 +1,145 @@ +package agent + +import ( + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// AgentInstance represents a fully configured agent with its own workspace, +// session manager, context builder, and tool registry. +type AgentInstance struct { + ID string + Name string + Model string + Fallbacks []string + Workspace string + MaxIterations int + ContextWindow int + Provider providers.LLMProvider + Sessions *session.SessionManager + ContextBuilder *ContextBuilder + Tools *tools.ToolRegistry + Subagents *config.SubagentsConfig + SkillsFilter []string + Candidates []providers.FallbackCandidate +} + +// NewAgentInstance creates an agent instance from config. 
+func NewAgentInstance( + agentCfg *config.AgentConfig, + defaults *config.AgentDefaults, + cfg *config.Config, + provider providers.LLMProvider, +) *AgentInstance { + workspace := resolveAgentWorkspace(agentCfg, defaults) + os.MkdirAll(workspace, 0755) + + model := resolveAgentModel(agentCfg, defaults) + fallbacks := resolveAgentFallbacks(agentCfg, defaults) + + restrict := defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg)) + toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + + sessionsDir := filepath.Join(workspace, "sessions") + sessionsManager := session.NewSessionManager(sessionsDir) + + contextBuilder := NewContextBuilder(workspace) + contextBuilder.SetToolsRegistry(toolsRegistry) + + agentID := routing.DefaultAgentID + agentName := "" + var subagents *config.SubagentsConfig + var skillsFilter []string + + if agentCfg != nil { + agentID = routing.NormalizeAgentID(agentCfg.ID) + agentName = agentCfg.Name + subagents = agentCfg.Subagents + skillsFilter = agentCfg.Skills + } + + maxIter := defaults.MaxToolIterations + if maxIter == 0 { + maxIter = 20 + } + + // Resolve fallback candidates + modelCfg := providers.ModelConfig{ + Primary: model, + Fallbacks: fallbacks, + } + candidates := providers.ResolveCandidates(modelCfg, defaults.Provider) + + return &AgentInstance{ + ID: agentID, + Name: agentName, + Model: model, + Fallbacks: fallbacks, + Workspace: workspace, + MaxIterations: maxIter, + ContextWindow: defaults.MaxTokens, + Provider: provider, + Sessions: sessionsManager, + ContextBuilder: contextBuilder, + Tools: toolsRegistry, + Subagents: subagents, + SkillsFilter: 
skillsFilter, + Candidates: candidates, + } +} + +// resolveAgentWorkspace determines the workspace directory for an agent. +func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" { + return expandHome(strings.TrimSpace(agentCfg.Workspace)) + } + if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" { + return expandHome(defaults.Workspace) + } + home, _ := os.UserHomeDir() + id := routing.NormalizeAgentID(agentCfg.ID) + return filepath.Join(home, ".picoclaw", "workspace-"+id) +} + +// resolveAgentModel resolves the primary model for an agent. +func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string { + if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" { + return strings.TrimSpace(agentCfg.Model.Primary) + } + return defaults.Model +} + +// resolveAgentFallbacks resolves the fallback models for an agent. 
+func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string { + if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil { + return agentCfg.Model.Fallbacks + } + return defaults.ModelFallbacks +} + +func expandHome(path string) string { + if path == "" { + return path + } + if path[0] == '~' { + home, _ := os.UserHomeDir() + if len(path) > 1 && path[1] == '/' { + return home + path[1:] + } + return home + } + return path +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f4627f907..570ff6cd5 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,8 +10,6 @@ import ( "context" "encoding/json" "fmt" - "os" - "path/filepath" "strings" "sync" "sync/atomic" @@ -24,7 +22,7 @@ import ( "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" @@ -32,17 +30,12 @@ import ( type AgentLoop struct { bus *bus.MessageBus - provider providers.LLMProvider - workspace string - model string - contextWindow int // Maximum context window size in tokens - maxIterations int - sessions *session.SessionManager + cfg *config.Config + registry *AgentRegistry state *state.Manager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + summarizing sync.Map + fallback *providers.FallbackChain channelManager *channels.Manager } @@ -58,99 +51,83 @@ type processOptions struct { NoHistory bool // If true, don't load session history (for heartbeat) } -// createToolRegistry creates a tool registry with common tools. -// This is shared between main agent and subagents. 
-func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { - registry := tools.NewToolRegistry() - - // File system tools - registry.Register(tools.NewReadFileTool(workspace, restrict)) - registry.Register(tools.NewWriteFileTool(workspace, restrict)) - registry.Register(tools.NewListDirTool(workspace, restrict)) - registry.Register(tools.NewEditFileTool(workspace, restrict)) - registry.Register(tools.NewAppendFileTool(workspace, restrict)) - - // Shell execution - registry.Register(tools.NewExecTool(workspace, restrict)) - - if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ - BraveAPIKey: cfg.Tools.Web.Brave.APIKey, - BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, - BraveEnabled: cfg.Tools.Web.Brave.Enabled, - DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, - DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, - PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, - PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, - PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, - }); searchTool != nil { - registry.Register(searchTool) - } - registry.Register(tools.NewWebFetchTool(50000)) - - // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms - registry.Register(tools.NewI2CTool()) - registry.Register(tools.NewSPITool()) - - // Message tool - available to both agent and subagent - // Subagent uses it to communicate directly with user - messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - Content: content, - }) - return nil - }) - registry.Register(messageTool) - - return registry -} - func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) + registry := NewAgentRegistry(cfg, provider) - 
restrict := cfg.Agents.Defaults.RestrictToWorkspace + // Register shared tools to all agents + registerSharedTools(cfg, msgBus, registry, provider) - // Create tool registry for main agent - toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + // Set up shared fallback chain + cooldown := providers.NewCooldownTracker() + fallbackChain := providers.NewFallbackChain(cooldown) - // Create subagent manager with its own tool registry - subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) - subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) - // Subagent doesn't need spawn/subagent tools to avoid recursion - subagentManager.SetTools(subagentTools) - - // Register spawn tool (for main agent) - spawnTool := tools.NewSpawnTool(subagentManager) - toolsRegistry.Register(spawnTool) - - // Register subagent tool (synchronous execution) - subagentTool := tools.NewSubagentTool(subagentManager) - toolsRegistry.Register(subagentTool) - - sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) - - // Create state manager for atomic state persistence - stateManager := state.NewManager(workspace) - - // Create context builder and set tools registry - contextBuilder := NewContextBuilder(workspace) - contextBuilder.SetToolsRegistry(toolsRegistry) + // Create state manager using default agent's workspace for channel recording + defaultAgent := registry.GetDefaultAgent() + var stateManager *state.Manager + if defaultAgent != nil { + stateManager = state.NewManager(defaultAgent.Workspace) + } return &AgentLoop{ - bus: msgBus, - provider: provider, - workspace: workspace, - model: cfg.Agents.Defaults.Model, - contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - state: stateManager, - contextBuilder: contextBuilder, - tools: toolsRegistry, - summarizing: 
sync.Map{}, + bus: msgBus, + cfg: cfg, + registry: registry, + state: stateManager, + summarizing: sync.Map{}, + fallback: fallbackChain, + } +} + +// registerSharedTools registers tools that are shared across all agents (web, message, spawn). +func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) { + for _, agentID := range registry.ListAgentIDs() { + agent, ok := registry.GetAgent(agentID) + if !ok { + continue + } + + // Web tools + if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey, + PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, + PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, + }); searchTool != nil { + agent.Tools.Register(searchTool) + } + agent.Tools.Register(tools.NewWebFetchTool(50000)) + + // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms + agent.Tools.Register(tools.NewI2CTool()) + agent.Tools.Register(tools.NewSPITool()) + + // Message tool + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + agent.Tools.Register(messageTool) + + // Spawn tool with allowlist checker + subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus) + spawnTool := tools.NewSpawnTool(subagentManager) + currentAgentID := agentID + spawnTool.SetAllowlistChecker(func(targetAgentID string) bool { + return registry.CanSpawnSubagent(currentAgentID, targetAgentID) + }) + agent.Tools.Register(spawnTool) + + // Update context 
builder with the complete tools registry + agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } @@ -175,10 +152,14 @@ func (al *AgentLoop) Run(ctx context.Context) error { if response != "" { // Check if the message tool already sent a response during this round. // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). alreadySent := false - if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } } } @@ -201,7 +182,11 @@ func (al *AgentLoop) Stop() { } func (al *AgentLoop) RegisterTool(tool tools.Tool) { - al.tools.Register(tool) + for _, agentID := range al.registry.ListAgentIDs() { + if agent, ok := al.registry.GetAgent(agentID); ok { + agent.Tools.Register(tool) + } + } } func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { @@ -211,12 +196,18 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { + if al.state == nil { + return nil + } return al.state.SetLastChannel(channel) } // RecordLastChatID records the last active chat ID for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChatID(chatID string) error { + if al.state == nil { + return nil + } return al.state.SetLastChatID(chatID) } @@ -239,7 +230,8 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess // ProcessHeartbeat processes a heartbeat request without session history. 
// Each heartbeat is independent and doesn't accumulate context. func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { - return al.runAgentLoop(ctx, processOptions{ + agent := al.registry.GetDefaultAgent() + return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: "heartbeat", Channel: channel, ChatID: chatID, @@ -277,9 +269,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return response, nil } - // Process as user message - return al.runAgentLoop(ctx, processOptions{ - SessionKey: msg.SessionKey, + // Route to determine agent and session key + route := al.registry.ResolveRoute(routing.RouteInput{ + Channel: msg.Channel, + AccountID: msg.Metadata["account_id"], + Peer: extractPeer(msg), + ParentPeer: extractParentPeer(msg), + GuildID: msg.Metadata["guild_id"], + TeamID: msg.Metadata["team_id"], + }) + + agent, ok := al.registry.GetAgent(route.AgentID) + if !ok { + agent = al.registry.GetDefaultAgent() + } + + // Use routed session key, but honor pre-set agent-scoped keys (for ProcessDirect/cron) + sessionKey := route.SessionKey + if msg.SessionKey != "" && strings.HasPrefix(msg.SessionKey, "agent:") { + sessionKey = msg.SessionKey + } + + logger.InfoCF("agent", "Routed message", + map[string]interface{}{ + "agent_id": agent.ID, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + }) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, Channel: msg.Channel, ChatID: msg.ChatID, UserMessage: msg.Content, @@ -290,7 +309,6 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Verify this is a system message if msg.Channel != "system" { return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) } @@ -302,12 +320,13 @@ func (al *AgentLoop) 
processSystemMessage(ctx context.Context, msg bus.InboundMe }) // Parse origin channel from chat_id (format: "channel:chat_id") - var originChannel string + var originChannel, originChatID string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] + originChatID = msg.ChatID[idx+1:] } else { - // Fallback originChannel = "cli" + originChatID = msg.ChatID } // Extract subagent result from message content @@ -328,44 +347,47 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe return "", nil } - // Agent acts as dispatcher only - subagent handles user interaction via message tool - // Don't forward result here, subagent should use message tool to communicate with user - logger.InfoCF("agent", "Subagent completed", - map[string]interface{}{ - "sender_id": msg.SenderID, - "channel": originChannel, - "content_len": len(content), - }) + // Use default agent for system messages + agent := al.registry.GetDefaultAgent() - // Agent only logs, does not respond to user - return "", nil + // Use the origin session for context + sessionKey := routing.BuildAgentMainSessionKey(agent.ID) + + return al.runAgentLoop(ctx, agent, processOptions{ + SessionKey: sessionKey, + Channel: originChannel, + ChatID: originChatID, + UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + DefaultResponse: "Background task completed.", + EnableSummary: false, + SendResponse: true, + }) } // runAgentLoop is the core message processing logic. -// It handles context building, LLM calls, tool execution, and response handling. -func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { +func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opts processOptions) (string, error) { // 0. 
Record last channel for heartbeat notifications (skip internal channels) if opts.Channel != "" && opts.ChatID != "" { // Don't record internal channels (cli, system, subagent) if !constants.IsInternalChannel(opts.Channel) { channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) if err := al.RecordLastChannel(channelKey); err != nil { - logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + logger.WarnCF("agent", "Failed to record last channel", map[string]interface{}{"error": err.Error()}) } } } // 1. Update tool contexts - al.updateToolContexts(opts.Channel, opts.ChatID) + al.updateToolContexts(agent, opts.Channel, opts.ChatID) // 2. Build messages (skip history for heartbeat) var history []providers.Message var summary string if !opts.NoHistory { - history = al.sessions.GetHistory(opts.SessionKey) - summary = al.sessions.GetSummary(opts.SessionKey) + history = agent.Sessions.GetHistory(opts.SessionKey) + summary = agent.Sessions.GetSummary(opts.SessionKey) } - messages := al.contextBuilder.BuildMessages( + messages := agent.ContextBuilder.BuildMessages( history, summary, opts.UserMessage, @@ -375,10 +397,10 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str ) // 3. Save user message to session - al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) // 4. Run LLM iteration loop - finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts) + finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err } @@ -392,12 +414,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // 6. 
Save final assistant message to session - al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(opts.SessionKey) + agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + agent.Sessions.Save(opts.SessionKey) // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID) + al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus @@ -413,6 +435,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str responsePreview := utils.Truncate(finalContent, 120) logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]interface{}{ + "agent_id": agent.ID, "session_key": opts.SessionKey, "iterations": iteration, "final_length": len(finalContent), @@ -422,28 +445,29 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str } // runLLMIteration executes the LLM call loop with tool handling. -// Returns the final content, iteration count, and any error. 
-func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) { +func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) { iteration := 0 var finalContent string - for iteration < al.maxIterations { + for iteration < agent.MaxIterations { iteration++ logger.DebugCF("agent", "LLM iteration", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "max": al.maxIterations, + "max": agent.MaxIterations, }) // Build tool definitions - providerToolDefs := al.tools.ToProviderDefs() + providerToolDefs := agent.Tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, - "model": al.model, + "model": agent.Model, "messages_count": len(messages), "tools_count": len(providerToolDefs), "max_tokens": 8192, @@ -459,23 +483,45 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "tools_json": formatToolsForLog(providerToolDefs), }) + // Call LLM with fallback chain if candidates are configured. 
var response *providers.LLMResponse var err error + callLLM := func() (*providers.LLMResponse, error) { + if len(agent.Candidates) > 1 && al.fallback != nil { + fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates, + func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) { + return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + }, + ) + if fbErr != nil { + return nil, fbErr + } + if fbResult.Provider != "" && len(fbResult.Attempts) > 0 { + logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts", + fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1), + map[string]interface{}{"agent_id": agent.ID, "iteration": iteration}) + } + return fbResult.Response, nil + } + return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + } + // Retry loop for context/token errors maxRetries := 2 for retry := 0; retry <= maxRetries; retry++ { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) - + response, err = callLLM() if err == nil { - break // Success + break } errMsg := strings.ToLower(err.Error()) - // Check for context window errors (provider specific, but usually contain "token" or "invalid") isContextError := strings.Contains(errMsg, "token") || strings.Contains(errMsg, "context") || strings.Contains(errMsg, "invalidparameter") || @@ -487,107 +533,30 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "retry": retry, }) - // Notify user on first retry only - if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse { + if retry == 0 && !constants.IsInternalChannel(opts.Channel) { al.bus.PublishOutbound(bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, - Content: 
"⚠️ Context window exceeded. Compressing history and retrying...", + Content: "Context window exceeded. Compressing history and retrying...", }) } - // Force compression - al.forceCompression(opts.SessionKey) - - // Rebuild messages with compressed history - // Note: We need to reload history from session manager because forceCompression changed it - newHistory := al.sessions.GetHistory(opts.SessionKey) - newSummary := al.sessions.GetSummary(opts.SessionKey) - - // Re-create messages for the next attempt - // We keep the current user message (opts.UserMessage) effectively - messages = al.contextBuilder.BuildMessages( - newHistory, - newSummary, - opts.UserMessage, - nil, - opts.Channel, - opts.ChatID, + al.forceCompression(agent, opts.SessionKey) + newHistory := agent.Sessions.GetHistory(opts.SessionKey) + newSummary := agent.Sessions.GetSummary(opts.SessionKey) + messages = agent.ContextBuilder.BuildMessages( + newHistory, newSummary, "", + nil, opts.Channel, opts.ChatID, ) - - // Important: If we are in the middle of a tool loop (iteration > 1), - // rebuilding messages from session history might duplicate the flow or miss context - // if intermediate steps weren't saved correctly. - // However, al.sessions.AddFullMessage is called after every tool execution, - // so GetHistory should reflect the current state including partial tool execution. - // But we need to ensure we don't duplicate the user message which is appended in BuildMessages. - // BuildMessages(history...) takes the stored history and appends the *current* user message. - // If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop. - // So if we pass opts.UserMessage again, we might duplicate it? - // Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - // So GetHistory ALREADY contains the user message! 
- - // CORRECTION: - // BuildMessages combines: [System] + [History] + [CurrentMessage] - // But Step 3 added CurrentMessage to History. - // So if we use GetHistory now, it has the user message. - // If we pass opts.UserMessage to BuildMessages, it adds it AGAIN. - - // For retry in the middle of a loop, we should rely on what's in the session. - // BUT checking BuildMessages implementation: - // It appends history... then appends currentMessage. - - // Logic fix for retry: - // If iteration == 1, opts.UserMessage corresponds to the user input. - // If iteration > 1, we are processing tool results. The "messages" passed to Chat - // already accumulated tool outputs. - // Rebuilding from session history is safest because it persists state. - // Start fresh with rebuilt history. - - // Special case: standard BuildMessages appends "currentMessage". - // If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed. - // However, the "messages" argument passed to runLLMIteration is constructed by the caller. - // If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history. - - // In runAgentLoop: - // 3. sessions.AddMessage(userMsg) - // 4. runLLMIteration(..., UserMessage) - - // So History contains the user message. - // BuildMessages typically appends the user message as a *new* pending message. - // Wait, standard BuildMessages usage in runAgentLoop: - // messages := BuildMessages(history (has old), UserMessage) - // THEN AddMessage(UserMessage). - // So "history" passed to BuildMessages does NOT contain the current UserMessage yet. - - // But here, inside the loop, we have already saved it. - // So GetHistory() includes the current user message. - // If we call BuildMessages(GetHistory(), UserMessage), we get duplicates. - - // Hack/Fix: - // If we are retrying, we rebuild from Session History ONLY. 
- // We pass empty string as "currentMessage" to BuildMessages - // because the "current message" is already saved in history (step 3). - - messages = al.contextBuilder.BuildMessages( - newHistory, - newSummary, - "", // Empty because history already contains the relevant messages - nil, - opts.Channel, - opts.ChatID, - ) - continue } - - // Real error or success, break loop break } if err != nil { logger.ErrorCF("agent", "LLM call failed", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "error": err.Error(), }) @@ -599,6 +568,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M finalContent = response.Content logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]interface{}{ + "agent_id": agent.ID, "iteration": iteration, "content_chars": len(finalContent), }) @@ -617,6 +587,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ + "agent_id": agent.ID, "tools": toolNames, "count": len(normalizedToolCalls), "iteration": iteration, @@ -649,15 +620,15 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, assistantMsg) // Save assistant message with tool calls to session - al.sessions.AddFullMessage(opts.SessionKey, assistantMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg) // Execute tool calls for _, tc := range normalizedToolCalls { - // Log tool call with arguments preview argsJSON, _ := json.Marshal(tc.Arguments) argsPreview := utils.Truncate(string(argsJSON), 200) logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]interface{}{ + "agent_id": agent.ID, "tool": tc.Name, "iteration": iteration, }) @@ -678,7 +649,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } } - toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, 
tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + toolResult := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { @@ -708,7 +679,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M messages = append(messages, toolResultMsg) // Save tool result message to session - al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } } @@ -716,19 +687,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M } // updateToolContexts updates the context for tools that need channel/chatID info. -func (al *AgentLoop) updateToolContexts(channel, chatID string) { +func (al *AgentLoop) updateToolContexts(agent *AgentInstance, channel, chatID string) { // Use ContextualTool interface instead of type assertions - if tool, ok := al.tools.Get("message"); ok { + if tool, ok := agent.Tools.Get("message"); ok { if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("spawn"); ok { + if tool, ok := agent.Tools.Get("spawn"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } - if tool, ok := al.tools.Get("subagent"); ok { + if tool, ok := agent.Tools.Get("subagent"); ok { if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } @@ -736,24 +707,24 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) { } // maybeSummarize triggers summarization if the session history exceeds thresholds. 
-func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) { - newHistory := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) { + newHistory := agent.Sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) - threshold := al.contextWindow * 75 / 100 + threshold := agent.ContextWindow * 75 / 100 if len(newHistory) > 20 || tokenEstimate > threshold { - if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { + summarizeKey := agent.ID + ":" + sessionKey + if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading { go func() { - defer al.summarizing.Delete(sessionKey) - // Notify user about optimization if not an internal channel + defer al.summarizing.Delete(summarizeKey) if !constants.IsInternalChannel(channel) { al.bus.PublishOutbound(bus.OutboundMessage{ Channel: channel, ChatID: chatID, - Content: "⚠️ Memory threshold reached. Optimizing conversation history...", + Content: "Memory threshold reached. Optimizing conversation history...", }) } - al.summarizeSession(sessionKey) + al.summarizeSession(agent, sessionKey) }() } } @@ -761,8 +732,8 @@ func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) { // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). 
-func (al *AgentLoop) forceCompression(sessionKey string) { - history := al.sessions.GetHistory(sessionKey) +func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { + history := agent.Sessions.GetHistory(sessionKey) if len(history) <= 4 { return } @@ -799,8 +770,8 @@ func (al *AgentLoop) forceCompression(sessionKey string) { newHistory = append(newHistory, history[len(history)-1]) // Last message // Update session - al.sessions.SetHistory(sessionKey, newHistory) - al.sessions.Save(sessionKey) + agent.Sessions.SetHistory(sessionKey, newHistory) + agent.Sessions.Save(sessionKey) logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{ "session_key": sessionKey, @@ -813,15 +784,26 @@ func (al *AgentLoop) forceCompression(sessionKey string) { func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) + agent := al.registry.GetDefaultAgent() + if agent == nil { + return info + } + // Tools info - tools := al.tools.List() + toolsList := agent.Tools.List() info["tools"] = map[string]interface{}{ - "count": len(tools), - "names": tools, + "count": len(toolsList), + "names": toolsList, } // Skills info - info["skills"] = al.contextBuilder.GetSkillsInfo() + info["skills"] = agent.ContextBuilder.GetSkillsInfo() + + // Agents info + info["agents"] = map[string]interface{}{ + "count": len(al.registry.ListAgentIDs()), + "ids": al.registry.ListAgentIDs(), + } return info } @@ -878,12 +860,12 @@ func formatToolsForLog(tools []providers.ToolDefinition) string { } // summarizeSession summarizes the conversation history for a session. 
-func (al *AgentLoop) summarizeSession(sessionKey string) { +func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) + history := agent.Sessions.GetHistory(sessionKey) + summary := agent.Sessions.GetSummary(sessionKey) // Keep last 4 messages for continuity if len(history) <= 4 { @@ -893,8 +875,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { toSummarize := history[:len(history)-4] // Oversized Message Guard - // Skip messages larger than 50% of context window to prevent summarizer overflow - maxMessageTokens := al.contextWindow / 2 + maxMessageTokens := agent.ContextWindow / 2 validMessages := make([]providers.Message, 0) omitted := false @@ -902,8 +883,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if m.Role != "user" && m.Role != "assistant" { continue } - // Estimate tokens for this message - msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety) + msgTokens := len(m.Content) / 2 if msgTokens > maxMessageTokens { omitted = true continue @@ -916,19 +896,17 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } // Multi-Part Summarization - // Split into two parts if history is significant var finalSummary string if len(validMessages) > 10 { mid := len(validMessages) / 2 part1 := validMessages[:mid] part2 := validMessages[mid:] - s1, _ := al.summarizeBatch(ctx, part1, "") - s2, _ := al.summarizeBatch(ctx, part2, "") + s1, _ := al.summarizeBatch(ctx, agent, part1, "") + s2, _ := al.summarizeBatch(ctx, agent, part2, "") - // Merge them mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2) - resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, 
map[string]interface{}{ + resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -938,7 +916,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { finalSummary = s1 + " " + s2 } } else { - finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary) + finalSummary, _ = al.summarizeBatch(ctx, agent, validMessages, summary) } if omitted && finalSummary != "" { @@ -946,14 +924,14 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { } if finalSummary != "" { - al.sessions.SetSummary(sessionKey, finalSummary) - al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(sessionKey) + agent.Sessions.SetSummary(sessionKey, finalSummary) + agent.Sessions.TruncateHistory(sessionKey, 4) + agent.Sessions.Save(sessionKey) } } // summarizeBatch summarizes a batch of messages. -func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) { +func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) { prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n" if existingSummary != "" { prompt += "Existing context: " + existingSummary + "\n" @@ -963,7 +941,7 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content) } - response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{ + response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{ "max_tokens": 1024, "temperature": 0.3, }) @@ -1002,25 +980,31 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) switch cmd { case 
"/show": if len(args) < 1 { - return "Usage: /show [model|channel]", true + return "Usage: /show [model|channel|agents]", true } switch args[0] { case "model": - return fmt.Sprintf("Current model: %s", al.model), true + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + return fmt.Sprintf("Current model: %s", defaultAgent.Model), true case "channel": return fmt.Sprintf("Current channel: %s", msg.Channel), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true default: return fmt.Sprintf("Unknown show target: %s", args[0]), true } case "/list": if len(args) < 1 { - return "Usage: /list [models|channels]", true + return "Usage: /list [models|channels|agents]", true } switch args[0] { case "models": - // TODO: Fetch available models dynamically if possible - return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true + return "Available models: configured in config.json per agent", true case "channels": if al.channelManager == nil { return "Channel manager not initialized", true @@ -1030,6 +1014,9 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "No channels enabled", true } return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true + case "agents": + agentIDs := al.registry.ListAgentIDs() + return fmt.Sprintf("Registered agents: %s", strings.Join(agentIDs, ", ")), true default: return fmt.Sprintf("Unknown list target: %s", args[0]), true } @@ -1043,23 +1030,21 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) switch target { case "model": - oldModel := al.model - al.model = value + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + return "No default agent configured", true + } + oldModel := defaultAgent.Model + defaultAgent.Model = value return 
fmt.Sprintf("Switched model from %s to %s", oldModel, value), true case "channel": - // This changes the 'default' channel for some operations, or effectively redirects output? - // For now, let's just validate if the channel exists if al.channelManager == nil { return "Channel manager not initialized", true } if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { return fmt.Sprintf("Channel '%s' not found or not enabled", value), true } - - // If message came from CLI, maybe we want to redirect CLI output to this channel? - // That would require state persistence about "redirected channel" - // For now, just acknowledged. - return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true + return fmt.Sprintf("Switched target channel to %s", value), true default: return fmt.Sprintf("Unknown switch target: %s", target), true } @@ -1067,3 +1052,30 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } + +// extractPeer extracts the routing peer from inbound message metadata. +func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { + peerKind := msg.Metadata["peer_kind"] + if peerKind == "" { + return nil + } + peerID := msg.Metadata["peer_id"] + if peerID == "" { + if peerKind == "direct" { + peerID = msg.SenderID + } else { + peerID = msg.ChatID + } + } + return &routing.RoutePeer{Kind: peerKind, ID: peerID} +} + +// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. 
+func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { + parentKind := msg.Metadata["parent_peer_kind"] + parentID := msg.Metadata["parent_peer_id"] + if parentKind == "" || parentID == "" { + return nil + } + return &routing.RoutePeer{Kind: parentKind, ID: parentID} +} diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 0bd38abf4..f2257973c 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { {Role: "assistant", Content: "Old response 2"}, {Role: "user", Content: "Trigger message"}, } - al.sessions.SetHistory(sessionKey, history) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("No default agent found") + } + defaultAgent.Sessions.SetHistory(sessionKey, history) // Call ProcessDirectWithChannel // Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration @@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { } // Check final history length - finalHistory := al.sessions.GetHistory(sessionKey) + finalHistory := defaultAgent.Sessions.GetHistory(sessionKey) // We verify that the history has been modified (compressed) // Original length: 6 // Expected behavior: compression drops ~50% of history (mid slice) diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go new file mode 100644 index 000000000..4cf5a6fca --- /dev/null +++ b/pkg/agent/registry.go @@ -0,0 +1,114 @@ +package agent + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" +) + +// AgentRegistry manages multiple agent instances and routes messages to them. +type AgentRegistry struct { + agents map[string]*AgentInstance + resolver *routing.RouteResolver + mu sync.RWMutex +} + +// NewAgentRegistry creates a registry from config, instantiating all agents. 
+func NewAgentRegistry( + cfg *config.Config, + provider providers.LLMProvider, +) *AgentRegistry { + registry := &AgentRegistry{ + agents: make(map[string]*AgentInstance), + resolver: routing.NewRouteResolver(cfg), + } + + agentConfigs := cfg.Agents.List + if len(agentConfigs) == 0 { + implicitAgent := &config.AgentConfig{ + ID: "main", + Default: true, + } + instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider) + registry.agents["main"] = instance + logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil) + } else { + for i := range agentConfigs { + ac := &agentConfigs[i] + id := routing.NormalizeAgentID(ac.ID) + instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider) + registry.agents[id] = instance + logger.InfoCF("agent", "Registered agent", + map[string]interface{}{ + "agent_id": id, + "name": ac.Name, + "workspace": instance.Workspace, + "model": instance.Model, + }) + } + } + + return registry +} + +// GetAgent returns the agent instance for a given ID. +func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + id := routing.NormalizeAgentID(agentID) + agent, ok := r.agents[id] + return agent, ok +} + +// ResolveRoute determines which agent handles the message. +func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { + return r.resolver.ResolveRoute(input) +} + +// ListAgentIDs returns all registered agent IDs. +func (r *AgentRegistry) ListAgentIDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + ids := make([]string, 0, len(r.agents)) + for id := range r.agents { + ids = append(ids, id) + } + return ids +} + +// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID. 
+func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool { + parent, ok := r.GetAgent(parentAgentID) + if !ok { + return false + } + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return false + } + targetNorm := routing.NormalizeAgentID(targetAgentID) + for _, allowed := range parent.Subagents.AllowAgents { + if allowed == "*" { + return true + } + if routing.NormalizeAgentID(allowed) == targetNorm { + return true + } + } + return false +} + +// GetDefaultAgent returns the default agent instance. +func (r *AgentRegistry) GetDefaultAgent() *AgentInstance { + r.mu.RLock() + defer r.mu.RUnlock() + if agent, ok := r.agents["main"]; ok { + return agent + } + for _, agent := range r.agents { + return agent + } + return nil +} diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go new file mode 100644 index 000000000..f196d7fb7 --- /dev/null +++ b/pkg/agent/registry_test.go @@ -0,0 +1,199 @@ +package agent + +import ( + "context" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +type mockRegistryProvider struct{} + +func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil +} + +func (m *mockRegistryProvider) GetDefaultModel() string { + return "mock-model" +} + +func testCfg(agents []config.AgentConfig) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test-registry", + Model: "gpt-4", + MaxTokens: 8192, + MaxToolIterations: 10, + }, + List: agents, + }, + } +} + +func TestNewAgentRegistry_ImplicitMain(t *testing.T) { + cfg := testCfg(nil) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 1 
|| ids[0] != "main" { + t.Errorf("expected implicit main agent, got %v", ids) + } + + agent, ok := registry.GetAgent("main") + if !ok || agent == nil { + t.Fatal("expected to find 'main' agent") + } + if agent.ID != "main" { + t.Errorf("agent.ID = %q, want 'main'", agent.ID) + } +} + +func TestNewAgentRegistry_ExplicitAgents(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "sales", Default: true, Name: "Sales Bot"}, + {ID: "support", Name: "Support Bot"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + ids := registry.ListAgentIDs() + if len(ids) != 2 { + t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids) + } + + sales, ok := registry.GetAgent("sales") + if !ok || sales == nil { + t.Fatal("expected to find 'sales' agent") + } + if sales.Name != "Sales Bot" { + t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name) + } + + support, ok := registry.GetAgent("support") + if !ok || support == nil { + t.Fatal("expected to find 'support' agent") + } +} + +func TestAgentRegistry_GetAgent_Normalize(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "my-agent", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, ok := registry.GetAgent("My-Agent") + if !ok || agent == nil { + t.Fatal("expected to find agent with normalized ID") + } + if agent.ID != "my-agent" { + t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID) + } +} + +func TestAgentRegistry_GetDefaultAgent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + // GetDefaultAgent first checks for "main", then returns any + agent := registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected a default agent") + } +} + +func TestAgentRegistry_CanSpawnSubagent(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "parent", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: 
[]string{"child1", "child2"}, + }, + }, + {ID: "child1"}, + {ID: "child2"}, + {ID: "restricted"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("parent", "child1") { + t.Error("expected parent to be allowed to spawn child1") + } + if !registry.CanSpawnSubagent("parent", "child2") { + t.Error("expected parent to be allowed to spawn child2") + } + if registry.CanSpawnSubagent("parent", "restricted") { + t.Error("expected parent to NOT be allowed to spawn restricted") + } + if registry.CanSpawnSubagent("child1", "child2") { + t.Error("expected child1 to NOT be allowed to spawn (no subagents config)") + } +} + +func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + { + ID: "admin", + Default: true, + Subagents: &config.SubagentsConfig{ + AllowAgents: []string{"*"}, + }, + }, + {ID: "any-agent"}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + if !registry.CanSpawnSubagent("admin", "any-agent") { + t.Error("expected wildcard to allow spawning any agent") + } + if !registry.CanSpawnSubagent("admin", "nonexistent") { + t.Error("expected wildcard to allow spawning even nonexistent agents") + } +} + +func TestAgentInstance_Model(t *testing.T) { + model := &config.AgentModelConfig{Primary: "claude-opus"} + cfg := testCfg([]config.AgentConfig{ + {ID: "custom", Default: true, Model: model}, + }) + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("custom") + if agent.Model != "claude-opus" { + t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model) + } +} + +func TestAgentInstance_FallbackInheritance(t *testing.T) { + cfg := testCfg([]config.AgentConfig{ + {ID: "inherit", Default: true}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("inherit") + if 
len(agent.Fallbacks) != 2 { + t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks)) + } +} + +func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) { + model := &config.AgentModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{}, // explicitly empty = disable + } + cfg := testCfg([]config.AgentConfig{ + {ID: "no-fallback", Default: true, Model: model}, + }) + cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"} + registry := NewAgentRegistry(cfg, &mockRegistryProvider{}) + + agent, _ := registry.GetAgent("no-fallback") + if len(agent.Fallbacks) != 0 { + t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks) + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 8d2d9a65b..4925099a3 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -2,7 +2,6 @@ package channels import ( "context" - "fmt" "strings" "github.com/sipeed/picoclaw/pkg/bus" @@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // Build session key: channel:chatID - sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) - msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - SessionKey: sessionKey, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index 00aa8ab4d..9ddec662c 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "os" - "strings" + "sync" "time" "github.com/bwmarrin/discordgo" @@ -26,6 +26,8 @@ type DiscordChannel struct { config config.DiscordConfig transcriber *voice.GroqTranscriber ctx context.Context + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal } func NewDiscordChannel(cfg 
config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -42,6 +44,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC config: cfg, transcriber: nil, ctx: context.Background(), + typingStop: make(map[string]chan struct{}), }, nil } @@ -84,6 +87,14 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") c.setRunning(false) + // Stop all typing goroutines before closing session + c.typingMu.Lock() + for chatID, stop := range c.typingStop { + close(stop) + delete(c.typingStop, chatID) + } + c.typingMu.Unlock() + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ -92,6 +103,8 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + c.stopTyping(msg.ChatID) + if !c.IsRunning() { return fmt.Errorf("discord bot not running") } @@ -106,7 +119,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return nil } - chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. 
code blocks + chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars for _, chunk := range chunks { if err := c.sendChunk(ctx, channelID, chunk); err != nil { @@ -117,132 +130,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return nil } -// splitMessage splits long messages into chunks, preserving code block integrity -// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks -func splitMessage(content string, limit int) []string { - var messages []string - - for len(content) > 0 { - if len(content) <= limit { - messages = append(messages, content) - break - } - - msgEnd := limit - - // Find natural split point within the limit - msgEnd = findLastNewline(content[:limit], 200) - if msgEnd <= 0 { - msgEnd = findLastSpace(content[:limit], 100) - } - if msgEnd <= 0 { - msgEnd = limit - } - - // Check if this would end with an incomplete code block - candidate := content[:msgEnd] - unclosedIdx := findLastUnclosedCodeBlock(candidate) - - if unclosedIdx >= 0 { - // Message would end with incomplete code block - // Try to extend to include the closing ``` (with some buffer) - extendedLimit := limit + 500 // Allow 500 char buffer for code blocks - if len(content) > extendedLimit { - closingIdx := findNextClosingCodeBlock(content, msgEnd) - if closingIdx > 0 && closingIdx <= extendedLimit { - // Extend to include the closing ``` - msgEnd = closingIdx - } else { - // Can't find closing, split before the code block - msgEnd = findLastNewline(content[:unclosedIdx], 200) - if msgEnd <= 0 { - msgEnd = findLastSpace(content[:unclosedIdx], 100) - } - if msgEnd <= 0 { - msgEnd = unclosedIdx - } - } - } else { - // Remaining content fits within extended limit - msgEnd = len(content) - } - } - - if msgEnd <= 0 { - msgEnd = limit - } - - messages = append(messages, content[:msgEnd]) - content = strings.TrimSpace(content[msgEnd:]) - } - - return 
messages -} - -// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` -// Returns the position of the opening ``` or -1 if all code blocks are complete -func findLastUnclosedCodeBlock(text string) int { - count := 0 - lastOpenIdx := -1 - - for i := 0; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - if count == 0 { - lastOpenIdx = i - } - count++ - i += 2 - } - } - - // If odd number of ``` markers, last one is unclosed - if count%2 == 1 { - return lastOpenIdx - } - return -1 -} - -// findNextClosingCodeBlock finds the next closing ``` starting from a position -// Returns the position after the closing ``` or -1 if not found -func findNextClosingCodeBlock(text string, startIdx int) int { - for i := startIdx; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - return i + 3 - } - } - return -1 -} - -// findLastNewline finds the last newline character within the last N characters -// Returns the position of the newline or -1 if not found -func findLastNewline(s string, searchWindow int) int { - searchStart := len(s) - searchWindow - if searchStart < 0 { - searchStart = 0 - } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == '\n' { - return i - } - } - return -1 -} - -// findLastSpace finds the last space character within the last N characters -// Returns the position of the space or -1 if not found -func findLastSpace(s string, searchWindow int) int { - searchStart := len(s) - searchWindow - if searchStart < 0 { - searchStart = 0 - } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == ' ' || s[i] == '\t' { - return i - } - } - return -1 -} - func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // 使用传入的 ctx 进行超时控制 sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) @@ -282,12 +169,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return 
} - if err := c.session.ChannelTyping(m.ChannelID); err != nil { - logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{ - "error": err.Error(), - }) - } - // 检查白名单,避免为被拒绝的用户下载附件和转录 if !c.IsAllowed(m.Author.ID) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ @@ -370,12 +251,22 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } + // Start typing after all early returns — guaranteed to have a matching Send() + c.startTyping(m.ChannelID) + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_id": m.ID, "user_id": senderID, @@ -384,11 +275,59 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } +// startTyping starts a continuous typing indicator loop for the given chatID. +// It stops any existing typing loop for that chatID before starting a new one. 
+func (c *DiscordChannel) startTyping(chatID string) { + c.typingMu.Lock() + // Stop existing loop for this chatID if any + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + } + stop := make(chan struct{}) + c.typingStop[chatID] = stop + c.typingMu.Unlock() + + go func() { + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + ticker := time.NewTicker(8 * time.Second) + defer ticker.Stop() + timeout := time.After(5 * time.Minute) + for { + select { + case <-stop: + return + case <-timeout: + return + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + } + } + }() +} + +// stopTyping stops the typing indicator loop for the given chatID. +func (c *DiscordChannel) stopTyping(chatID string) { + c.typingMu.Lock() + defer c.typingMu.Unlock() + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + delete(c.typingStop, chatID) + } +} + func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam.go index 5fc19adbe..01e570b25 100644 --- a/pkg/channels/maixcam.go +++ b/pkg/channels/maixcam.go @@ -18,7 +18,6 @@ type MaixCamChannel struct { listener net.Listener clients map[net.Conn]bool clientsMux sync.RWMutex - running bool } type MaixCamMessage struct { @@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC BaseChannel: base, config: cfg, clients: make(map[net.Conn]bool), - running: false, }, nil } diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot.go index 5d97fab9c..53e82b44d 100644 --- a/pkg/channels/onebot.go +++ b/pkg/channels/onebot.go @@ -4,9 +4,11 @@ import ( "context" 
"encoding/json" "fmt" + "os" "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -14,20 +16,28 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" ) type OneBotChannel struct { *BaseChannel - config config.OneBotConfig - conn *websocket.Conn - ctx context.Context - cancel context.CancelFunc - dedup map[string]struct{} - dedupRing []string - dedupIdx int - mu sync.Mutex - writeMu sync.Mutex - echoCounter int64 + config config.OneBotConfig + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + dedup map[string]struct{} + dedupRing []string + dedupIdx int + mu sync.Mutex + writeMu sync.Mutex + echoCounter int64 + selfID int64 + pending map[string]chan json.RawMessage + pendingMu sync.Mutex + transcriber *voice.GroqTranscriber + lastMessageID sync.Map + pendingEmojiMsg sync.Map } type oneBotRawEvent struct { @@ -43,9 +53,11 @@ type oneBotRawEvent struct { SelfID json.RawMessage `json:"self_id"` Time json.RawMessage `json:"time"` MetaEventType string `json:"meta_event_type"` + NoticeType string `json:"notice_type"` Echo string `json:"echo"` RetCode json.RawMessage `json:"retcode"` - Status BotStatus `json:"status"` + Status json.RawMessage `json:"status"` + Data json.RawMessage `json:"data"` } type BotStatus struct { @@ -53,42 +65,36 @@ type BotStatus struct { Good bool `json:"good"` } +func isAPIResponse(raw json.RawMessage) bool { + if len(raw) == 0 { + return false + } + var s string + if json.Unmarshal(raw, &s) == nil { + return s == "ok" || s == "failed" + } + var bs BotStatus + if json.Unmarshal(raw, &bs) == nil { + return bs.Online || bs.Good + } + return false +} + type oneBotSender struct { UserID json.RawMessage `json:"user_id"` Nickname string `json:"nickname"` Card string `json:"card"` } -type oneBotEvent struct { - PostType string - MessageType 
string - SubType string - MessageID string - UserID int64 - GroupID int64 - Content string - RawContent string - IsBotMentioned bool - Sender oneBotSender - SelfID int64 - Time int64 - MetaEventType string -} - type oneBotAPIRequest struct { Action string `json:"action"` Params interface{} `json:"params"` Echo string `json:"echo,omitempty"` } -type oneBotSendPrivateMsgParams struct { - UserID int64 `json:"user_id"` - Message string `json:"message"` -} - -type oneBotSendGroupMsgParams struct { - GroupID int64 `json:"group_id"` - Message string `json:"message"` +type oneBotMessageSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { @@ -101,9 +107,30 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One dedup: make(map[string]struct{}, dedupSize), dedupRing: make([]string, dedupSize), dedupIdx: 0, + pending: make(map[string]chan json.RawMessage), }, nil } +func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { + go func() { + _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{ + "message_id": messageID, + "emoji_id": emojiID, + "set": set, + }, 5*time.Second) + if err != nil { + logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{ + "message_id": messageID, + "error": err.Error(), + }) + } + }() +} + func (c *OneBotChannel) Start(ctx context.Context) error { if c.config.WSUrl == "" { return fmt.Errorf("OneBot ws_url not configured") @@ -121,12 +148,12 @@ func (c *OneBotChannel) Start(ctx context.Context) error { }) } else { go c.listen() + c.fetchSelfID() } if c.config.ReconnectInterval > 0 { go c.reconnectLoop() } else { - // If reconnect is disabled but initial connection failed, we cannot recover if c.conn == nil { 
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled") } @@ -152,14 +179,141 @@ func (c *OneBotChannel) connect() error { return err } + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.mu.Lock() c.conn = conn c.mu.Unlock() + go c.pinger(conn) + logger.InfoC("onebot", "WebSocket connected") return nil } +func (c *OneBotChannel) pinger(conn *websocket.Conn) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.writeMu.Lock() + err := conn.WriteMessage(websocket.PingMessage, nil) + c.writeMu.Unlock() + if err != nil { + logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{ + "error": err.Error(), + }) + return + } + } + } +} + +func (c *OneBotChannel) fetchSelfID() { + resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) + if err != nil { + logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + type loginInfo struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + } + for _, extract := range []func() (*loginInfo, error){ + func() (*loginInfo, error) { + var w struct { + Data loginInfo `json:"data"` + } + err := json.Unmarshal(resp, &w) + return &w.Data, err + }, + func() (*loginInfo, error) { + var f loginInfo + err := json.Unmarshal(resp, &f) + return &f, err + }, + } { + info, err := extract() + if err != nil || len(info.UserID) == 0 { + continue + } + if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { + atomic.StoreInt64(&c.selfID, uid) + logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{ + "self_id": uid, + "nickname": info.Nickname, + }) + return + } + } + + logger.WarnCF("onebot", "Could not parse self ID from get_login_info 
response", map[string]interface{}{ + "response": string(resp), + }) +} + +func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("WebSocket not connected") + } + + echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1)) + + ch := make(chan json.RawMessage, 1) + c.pendingMu.Lock() + c.pending[echo] = ch + c.pendingMu.Unlock() + + defer func() { + c.pendingMu.Lock() + delete(c.pending, echo) + c.pendingMu.Unlock() + }() + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal API request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + return nil, fmt.Errorf("failed to write API request: %w", err) + } + + select { + case resp := <-ch: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) + case <-c.ctx.Done(): + return nil, fmt.Errorf("context cancelled") + } +} + func (c *OneBotChannel) reconnectLoop() { interval := time.Duration(c.config.ReconnectInterval) * time.Second if interval < 5*time.Second { @@ -183,6 +337,7 @@ func (c *OneBotChannel) reconnectLoop() { }) } else { go c.listen() + c.fetchSelfID() } } } @@ -197,6 +352,13 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { c.cancel() } + c.pendingMu.Lock() + for echo, ch := range c.pending { + close(ch) + delete(c.pending, echo) + } + c.pendingMu.Unlock() + c.mu.Lock() if c.conn != nil { c.conn.Close() @@ -225,10 +387,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return err } - c.writeMu.Lock() - c.echoCounter++ - echo := fmt.Sprintf("send_%d", c.echoCounter) - c.writeMu.Unlock() + echo := 
fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) req := oneBotAPIRequest{ Action: action, @@ -252,67 +411,78 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return err } + if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { + if mid, ok := msgID.(string); ok && mid != "" { + c.setMsgEmojiLike(mid, 289, false) + } + } + return nil } +func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { + var segments []oneBotMessageSegment + + if lastMsgID, ok := c.lastMessageID.Load(chatID); ok { + if msgID, ok := lastMsgID.(string); ok && msgID != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "reply", + Data: map[string]interface{}{"id": msgID}, + }) + } + } + + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]interface{}{"text": content}, + }) + + return segments +} + func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) { chatID := msg.ChatID + segments := c.buildMessageSegments(chatID, msg.Content) - if len(chatID) > 6 && chatID[:6] == "group:" { - groupID, err := strconv.ParseInt(chatID[6:], 10, 64) - if err != nil { - return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID) - } - return "send_group_msg", oneBotSendGroupMsgParams{ - GroupID: groupID, - Message: msg.Content, - }, nil + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID } - if len(chatID) > 8 && chatID[:8] == "private:" { - userID, err := strconv.ParseInt(chatID[8:], 10, 64) - if err != nil { - return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID) - } - return "send_private_msg", 
oneBotSendPrivateMsgParams{ - UserID: userID, - Message: msg.Content, - }, nil - } - - userID, err := strconv.ParseInt(chatID, 10, 64) + id, err := strconv.ParseInt(rawID, 10, 64) if err != nil { - return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID) + return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) } - - return "send_private_msg", oneBotSendPrivateMsgParams{ - UserID: userID, - Message: msg.Content, - }, nil + return action, map[string]interface{}{idKey: id, "message": segments}, nil } func (c *OneBotChannel) listen() { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") + return + } + for { select { case <-c.ctx.Done(): return default: - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") - return - } - _, message, err := conn.ReadMessage() if err != nil { logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{ "error": err.Error(), }) c.mu.Lock() - if c.conn != nil { + if c.conn == conn { c.conn.Close() c.conn = nil } @@ -320,10 +490,7 @@ func (c *OneBotChannel) listen() { return } - logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{ - "length": len(message), - "payload": string(message), - }) + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) var raw oneBotRawEvent if err := json.Unmarshal(message, &raw); err != nil { @@ -334,20 +501,37 @@ func (c *OneBotChannel) listen() { continue } - if raw.Echo != "" || raw.Status.Online || raw.Status.Good { - logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{ - "echo": raw.Echo, - "status": raw.Status, - }) + logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{ + "length": len(message), + "post_type": raw.PostType, + "sub_type": raw.SubType, + }) + + if raw.Echo != "" { + c.pendingMu.Lock() + ch, ok := 
c.pending[raw.Echo] + c.pendingMu.Unlock() + + if ok { + select { + case ch <- message: + default: + } + } else { + logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{ + "echo": raw.Echo, + "status": string(raw.Status), + }) + } continue } - logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{ - "post_type": raw.PostType, - "message_type": raw.MessageType, - "sub_type": raw.SubType, - "meta_event_type": raw.MetaEventType, - }) + if isAPIResponse(raw.Status) { + logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{ + "status": string(raw.Status), + }) + continue + } c.handleRawEvent(&raw) } @@ -386,9 +570,12 @@ func parseJSONString(raw json.RawMessage) string { type parseMessageResult struct { Text string IsBotMentioned bool + Media []string + LocalFiles []string + ReplyTo string } -func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult { +func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { if len(raw) == 0 { return parseMessageResult{} } @@ -408,60 +595,155 @@ func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult } var segments []map[string]interface{} - if err := json.Unmarshal(raw, &segments); err == nil { - var text string - mentioned := false - selfIDStr := strconv.FormatInt(selfID, 10) - for _, seg := range segments { - segType, _ := seg["type"].(string) - data, _ := seg["data"].(map[string]interface{}) - switch segType { - case "text": - if data != nil { - if t, ok := data["text"].(string); ok { - text += t - } + if err := json.Unmarshal(raw, &segments); err != nil { + return parseMessageResult{} + } + + var textParts []string + mentioned := false + selfIDStr := strconv.FormatInt(selfID, 10) + var media []string + var localFiles []string + var replyTo string + + for _, seg := range segments { + segType, _ := seg["type"].(string) + data, _ := 
seg["data"].(map[string]interface{}) + + switch segType { + case "text": + if data != nil { + if t, ok := data["text"].(string); ok { + textParts = append(textParts, t) } - case "at": - if data != nil && selfID > 0 { - qqVal := fmt.Sprintf("%v", data["qq"]) - if qqVal == selfIDStr || qqVal == "all" { - mentioned = true + } + + case "at": + if data != nil && selfID > 0 { + qqVal := fmt.Sprintf("%v", data["qq"]) + if qqVal == selfIDStr || qqVal == "all" { + mentioned = true + } + } + + case "image", "video", "file": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"} + filename := defaults[segType] + if f, ok := data["file"].(string); ok && f != "" { + filename = f + } else if n, ok := data["name"].(string); ok && n != "" { + filename = n + } + localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + media = append(media, localPath) + localFiles = append(localFiles, localPath) + textParts = append(textParts, fmt.Sprintf("[%s]", segType)) } } } + + case "record": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + localFiles = append(localFiles, localPath) + if c.transcriber != nil && c.transcriber.IsAvailable() { + tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) + result, err := c.transcriber.Transcribe(tctx, localPath) + tcancel() + if err != nil { + logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + }) + textParts = append(textParts, "[voice (transcription failed)]") + media = append(media, localPath) + } else { + textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + } + } else { + textParts = append(textParts, "[voice]") + media = append(media, 
localPath) + } + } + } + } + + case "reply": + if data != nil { + if id, ok := data["id"]; ok { + replyTo = fmt.Sprintf("%v", id) + } + } + + case "face": + if data != nil { + faceID, _ := data["id"] + textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID)) + } + + case "forward": + textParts = append(textParts, "[forward message]") + + default: + } - return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned} } - return parseMessageResult{} + + return parseMessageResult{ + Text: strings.TrimSpace(strings.Join(textParts, "")), + IsBotMentioned: mentioned, + Media: media, + LocalFiles: localFiles, + ReplyTo: replyTo, + } } func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { switch raw.PostType { case "message": - evt, err := c.normalizeMessageEvent(raw) - if err != nil { - logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{ - "error": err.Error(), - }) - return + if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { + if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{ + "user_id": userID, + }) + return + } } - c.handleMessage(evt) + c.handleMessage(raw) + + case "message_sent": + logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{ + "message_type": raw.MessageType, + "message_id": parseJSONString(raw.MessageID), + }) + case "meta_event": c.handleMetaEvent(raw) + case "notice": - logger.DebugCF("onebot", "Notice event received", map[string]interface{}{ - "sub_type": raw.SubType, - }) + c.handleNoticeEvent(raw) + case "request": logger.DebugCF("onebot", "Request event received", map[string]interface{}{ "sub_type": raw.SubType, }) + case "": logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{ "echo": raw.Echo, "status": raw.Status, }) + default: logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{ 
"post_type": raw.PostType, @@ -469,18 +751,51 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { } } -func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) { +func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { + if raw.MetaEventType == "lifecycle" { + logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType}) + } else if raw.MetaEventType != "heartbeat" { + logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil) + } +} + +func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { + fields := map[string]interface{}{ + "notice_type": raw.NoticeType, + "sub_type": raw.SubType, + "group_id": parseJSONString(raw.GroupID), + "user_id": parseJSONString(raw.UserID), + "message_id": parseJSONString(raw.MessageID), + } + switch raw.NoticeType { + case "group_recall", "group_increase", "group_decrease", + "friend_add", "group_admin", "group_ban": + logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields) + default: + logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields) + } +} + +func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { + // Parse fields from raw event userID, err := parseJSONInt64(raw.UserID) if err != nil { - return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID)) + logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{ + "error": err.Error(), + "raw": string(raw.UserID), + }) + return } groupID, _ := parseJSONInt64(raw.GroupID) selfID, _ := parseJSONInt64(raw.SelfID) - ts, _ := parseJSONInt64(raw.Time) messageID := parseJSONString(raw.MessageID) - parsed := parseMessageContentEx(raw.Message, selfID) + if selfID == 0 { + selfID = atomic.LoadInt64(&c.selfID) + } + + parsed := c.parseMessageSegments(raw.Message, selfID) isBotMentioned := parsed.IsBotMentioned content := raw.RawMessage @@ -495,6 +810,10 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent } } + if 
parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") { + content = parsed.Text + } + var sender oneBotSender if len(raw.Sender) > 0 { if err := json.Unmarshal(raw.Sender, &sender); err != nil { @@ -505,137 +824,107 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent } } - logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{ - "message_type": raw.MessageType, - "user_id": userID, - "group_id": groupID, - "message_id": messageID, - "content_len": len(content), - "nickname": sender.Nickname, - }) - - return &oneBotEvent{ - PostType: raw.PostType, - MessageType: raw.MessageType, - SubType: raw.SubType, - MessageID: messageID, - UserID: userID, - GroupID: groupID, - Content: content, - RawContent: raw.RawMessage, - IsBotMentioned: isBotMentioned, - Sender: sender, - SelfID: selfID, - Time: ts, - MetaEventType: raw.MetaEventType, - }, nil -} - -func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { - switch raw.MetaEventType { - case "lifecycle": - logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{ - "sub_type": raw.SubType, - }) - case "heartbeat": - logger.DebugC("onebot", "Heartbeat received") - default: - logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{ - "meta_event_type": raw.MetaEventType, - }) + // Clean up temp files when done + if len(parsed.LocalFiles) > 0 { + defer func() { + for _, f := range parsed.LocalFiles { + if err := os.Remove(f); err != nil { + logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{ + "path": f, + "error": err.Error(), + }) + } + } + }() } -} -func (c *OneBotChannel) handleMessage(evt *oneBotEvent) { - if c.isDuplicate(evt.MessageID) { + if c.isDuplicate(messageID) { logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{ - "message_id": evt.MessageID, + "message_id": messageID, }) return } - content := evt.Content if content == "" { 
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{ - "message_id": evt.MessageID, + "message_id": messageID, }) return } - senderID := strconv.FormatInt(evt.UserID, 10) + senderID := strconv.FormatInt(userID, 10) var chatID string metadata := map[string]string{ - "message_id": evt.MessageID, + "message_id": messageID, } - switch evt.MessageType { + if parsed.ReplyTo != "" { + metadata["reply_to_message_id"] = parsed.ReplyTo + } + + switch raw.MessageType { case "private": chatID = "private:" + senderID - logger.InfoCF("onebot", "Received private message", map[string]interface{}{ - "sender": senderID, - "message_id": evt.MessageID, - "length": len(content), - "content": truncate(content, 100), - }) case "group": - groupIDStr := strconv.FormatInt(evt.GroupID, 10) + groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr metadata["group_id"] = groupIDStr - senderUserID, _ := parseJSONInt64(evt.Sender.UserID) + senderUserID, _ := parseJSONInt64(sender.UserID) if senderUserID > 0 { metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10) } - if evt.Sender.Card != "" { - metadata["sender_name"] = evt.Sender.Card - } else if evt.Sender.Nickname != "" { - metadata["sender_name"] = evt.Sender.Nickname + if sender.Card != "" { + metadata["sender_name"] = sender.Card + } else if sender.Nickname != "" { + metadata["sender_name"] = sender.Nickname } - triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned) + triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) if !triggered { logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{ "sender": senderID, "group": groupIDStr, - "is_mentioned": evt.IsBotMentioned, + "is_mentioned": isBotMentioned, "content": truncate(content, 100), }) return } content = strippedContent - logger.InfoCF("onebot", "Received group message", map[string]interface{}{ - "sender": senderID, - "group": groupIDStr, - 
"message_id": evt.MessageID, - "is_mentioned": evt.IsBotMentioned, - "length": len(content), - "content": truncate(content, 100), - }) - default: logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{ - "type": evt.MessageType, - "message_id": evt.MessageID, - "user_id": evt.UserID, + "type": raw.MessageType, + "message_id": messageID, + "user_id": userID, }) return } - if evt.Sender.Nickname != "" { - metadata["nickname"] = evt.Sender.Nickname - } - - logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{ - "sender_id": senderID, - "chat_id": chatID, - "content": truncate(content, 100), + logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{ + "sender": senderID, + "chat_id": chatID, + "message_id": messageID, + "length": len(content), + "content": truncate(content, 100), + "media_count": len(parsed.Media), }) - c.HandleMessage(senderID, chatID, content, []string{}, metadata) + if sender.Nickname != "" { + metadata["nickname"] = sender.Nickname + } + + c.lastMessageID.Store(chatID, messageID) + + if raw.MessageType == "group" && messageID != "" && messageID != "0" { + c.setMsgEmojiLike(messageID, 289, true) + c.pendingEmojiMsg.Store(chatID, messageID) + } + + c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index 5387e9213..0060972ed 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -25,6 +25,7 @@ type SlackChannel struct { api *slack.Client socketClient *socketmode.Client botUserID string + teamID string transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc @@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { return fmt.Errorf("slack auth test failed: %w", err) } c.botUserID = authResp.UserID + c.teamID = authResp.TeamID logger.InfoCF("slack", "Slack bot connected", 
map[string]interface{}{ "bot_user_id": c.botUserID, @@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, } logger.DebugCF("slack", "Received message", map[string]interface{}{ @@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, } c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, } logger.DebugCF("slack", "Slash command received", map[string]interface{}{ diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index e096a0a7a..20bbf6830 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -354,12 +354,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes c.placeholders.Store(chatIDStr, pID) } + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) + } + metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": 
fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, } c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) diff --git a/pkg/config/config.go b/pkg/config/config.go index 0e6063e73..577799fac 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -46,6 +46,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { type Config struct { Agents AgentsConfig `json:"agents"` + Bindings []AgentBinding `json:"bindings,omitempty"` + Session SessionConfig `json:"session,omitempty"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig `json:"providers"` ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration @@ -59,16 +61,97 @@ type Config struct { type AgentsConfig struct { Defaults AgentDefaults `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` +} + +// AgentModelConfig supports both string and structured model config. 
+// String format: "gpt-4" (just primary, no fallbacks) +// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]} +type AgentModelConfig struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` +} + +func (m *AgentModelConfig) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + m.Primary = s + m.Fallbacks = nil + return nil + } + type raw struct { + Primary string `json:"primary"` + Fallbacks []string `json:"fallbacks"` + } + var r raw + if err := json.Unmarshal(data, &r); err != nil { + return err + } + m.Primary = r.Primary + m.Fallbacks = r.Fallbacks + return nil +} + +func (m AgentModelConfig) MarshalJSON() ([]byte, error) { + if len(m.Fallbacks) == 0 && m.Primary != "" { + return json.Marshal(m.Primary) + } + type raw struct { + Primary string `json:"primary,omitempty"` + Fallbacks []string `json:"fallbacks,omitempty"` + } + return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks}) +} + +type AgentConfig struct { + ID string `json:"id"` + Default bool `json:"default,omitempty"` + Name string `json:"name,omitempty"` + Workspace string `json:"workspace,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` + Skills []string `json:"skills,omitempty"` + Subagents *SubagentsConfig `json:"subagents,omitempty"` +} + +type SubagentsConfig struct { + AllowAgents []string `json:"allow_agents,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` +} + +type PeerMatch struct { + Kind string `json:"kind"` + ID string `json:"id"` +} + +type BindingMatch struct { + Channel string `json:"channel"` + AccountID string `json:"account_id,omitempty"` + Peer *PeerMatch `json:"peer,omitempty"` + GuildID string `json:"guild_id,omitempty"` + TeamID string `json:"team_id,omitempty"` +} + +type AgentBinding struct { + AgentID string `json:"agent_id"` + Match BindingMatch `json:"match"` +} + +type SessionConfig struct { + DMScope string 
`json:"dm_scope,omitempty"` + IdentityLinks map[string][]string `json:"identity_links,omitempty"` } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + ModelFallbacks []string `json:"model_fallbacks,omitempty"` + ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"` + ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -170,23 +253,23 @@ type DevicesConfig struct { } type ProvidersConfig struct { - Anthropic ProviderConfig `json:"anthropic"` - OpenAI ProviderConfig `json:"openai"` - OpenRouter ProviderConfig `json:"openrouter"` - Groq ProviderConfig `json:"groq"` - Zhipu ProviderConfig `json:"zhipu"` - VLLM ProviderConfig `json:"vllm"` - Gemini ProviderConfig `json:"gemini"` - Nvidia ProviderConfig `json:"nvidia"` - 
Ollama ProviderConfig `json:"ollama"` - Moonshot ProviderConfig `json:"moonshot"` - ShengSuanYun ProviderConfig `json:"shengsuanyun"` - DeepSeek ProviderConfig `json:"deepseek"` - Cerebras ProviderConfig `json:"cerebras"` - VolcEngine ProviderConfig `json:"volcengine"` - GitHubCopilot ProviderConfig `json:"github_copilot"` - Antigravity ProviderConfig `json:"antigravity"` - Qwen ProviderConfig `json:"qwen"` + Anthropic ProviderConfig `json:"anthropic"` + OpenAI OpenAIProviderConfig `json:"openai"` + OpenRouter ProviderConfig `json:"openrouter"` + Groq ProviderConfig `json:"groq"` + Zhipu ProviderConfig `json:"zhipu"` + VLLM ProviderConfig `json:"vllm"` + Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` + Ollama ProviderConfig `json:"ollama"` + Moonshot ProviderConfig `json:"moonshot"` + ShengSuanYun ProviderConfig `json:"shengsuanyun"` + DeepSeek ProviderConfig `json:"deepseek"` + Cerebras ProviderConfig `json:"cerebras"` + VolcEngine ProviderConfig `json:"volcengine"` + GitHubCopilot ProviderConfig `json:"github_copilot"` + Antigravity ProviderConfig `json:"antigravity"` + Qwen ProviderConfig `json:"qwen"` } type ProviderConfig struct { @@ -197,6 +280,11 @@ type ProviderConfig struct { ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc` } +type OpenAIProviderConfig struct { + ProviderConfig + WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"` +} + // ModelConfig represents a model-centric provider configuration. // It allows adding new providers (especially OpenAI-compatible ones) via configuration only. 
// The model field uses protocol prefix format: [protocol/]model-identifier @@ -265,9 +353,15 @@ type CronToolsConfig struct { ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout } +type ExecConfig struct { + EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"` + CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"` +} + type ToolsConfig struct { Web WebToolsConfig `json:"web"` Cron CronToolsConfig `json:"cron"` + Exec ExecConfig `json:"exec"` } func LoadConfig(path string) (*Config, error) { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index febfd0456..47916d155 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,12 +1,193 @@ package config import ( + "encoding/json" "os" "path/filepath" "runtime" "testing" ) +func TestAgentModelConfig_UnmarshalString(t *testing.T) { + var m AgentModelConfig + if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil { + t.Fatalf("unmarshal string: %v", err) + } + if m.Primary != "gpt-4" { + t.Errorf("Primary = %q, want 'gpt-4'", m.Primary) + } + if m.Fallbacks != nil { + t.Errorf("Fallbacks = %v, want nil", m.Fallbacks) + } +} + +func TestAgentModelConfig_UnmarshalObject(t *testing.T) { + var m AgentModelConfig + data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}` + if err := json.Unmarshal([]byte(data), &m); err != nil { + t.Fatalf("unmarshal object: %v", err) + } + if m.Primary != "claude-opus" { + t.Errorf("Primary = %q, want 'claude-opus'", m.Primary) + } + if len(m.Fallbacks) != 2 { + t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks)) + } + if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" { + t.Errorf("Fallbacks = %v", m.Fallbacks) + } +} + +func TestAgentModelConfig_MarshalString(t *testing.T) { + m := AgentModelConfig{Primary: "gpt-4"} + data, err := 
json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(data) != `"gpt-4"` { + t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data)) + } +} + +func TestAgentModelConfig_MarshalObject(t *testing.T) { + m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}} + data, err := json.Marshal(m) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var result map[string]interface{} + json.Unmarshal(data, &result) + if result["primary"] != "claude-opus" { + t.Errorf("primary = %v", result["primary"]) + } +} + +func TestAgentConfig_FullParse(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + }, + "list": [ + { + "id": "sales", + "default": true, + "name": "Sales Bot", + "model": "gpt-4" + }, + { + "id": "support", + "name": "Support Bot", + "model": { + "primary": "claude-opus", + "fallbacks": ["haiku"] + }, + "subagents": { + "allow_agents": ["sales"] + } + } + ] + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "account_id": "*", + "peer": {"kind": "direct", "id": "user123"} + } + } + ], + "session": { + "dm_scope": "per-peer", + "identity_links": { + "john": ["telegram:123", "discord:john#1234"] + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 2 { + t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List)) + } + + sales := cfg.Agents.List[0] + if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" { + t.Errorf("sales = %+v", sales) + } + if sales.Model == nil || sales.Model.Primary != "gpt-4" { + t.Errorf("sales.Model = %+v", sales.Model) + } + + support := cfg.Agents.List[1] + if support.ID != "support" || support.Name != "Support Bot" { + t.Errorf("support = %+v", support) + } + if support.Model == nil || 
support.Model.Primary != "claude-opus" { + t.Errorf("support.Model = %+v", support.Model) + } + if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" { + t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks) + } + if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 { + t.Errorf("support.Subagents = %+v", support.Subagents) + } + + if len(cfg.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings)) + } + binding := cfg.Bindings[0] + if binding.AgentID != "support" || binding.Match.Channel != "telegram" { + t.Errorf("binding = %+v", binding) + } + if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" { + t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer) + } + + if cfg.Session.DMScope != "per-peer" { + t.Errorf("Session.DMScope = %q", cfg.Session.DMScope) + } + if len(cfg.Session.IdentityLinks) != 1 { + t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks) + } + links := cfg.Session.IdentityLinks["john"] + if len(links) != 2 { + t.Errorf("john links = %v", links) + } +} + +func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "max_tool_iterations": 20 + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if len(cfg.Agents.List) != 0 { + t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List)) + } + if len(cfg.Bindings) != 0 { + t.Errorf("bindings should be empty, got %d", len(cfg.Bindings)) + } +} + // TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { cfg := DefaultConfig() @@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { func TestDefaultConfig_WorkspacePath(t 
*testing.T) { cfg := DefaultConfig() - // Just verify the workspace is set, don't compare exact paths - // since expandHome behavior may differ based on environment if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } @@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) { func TestDefaultConfig_Providers(t *testing.T) { cfg := DefaultConfig() - // Verify all providers are empty by default if cfg.Providers.Anthropic.APIKey != "" { t.Error("Anthropic API key should be empty by default") } @@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) { if cfg.Providers.OpenRouter.APIKey != "" { t.Error("OpenRouter API key should be empty by default") } - if cfg.Providers.Groq.APIKey != "" { - t.Error("Groq API key should be empty by default") - } - if cfg.Providers.Zhipu.APIKey != "" { - t.Error("Zhipu API key should be empty by default") - } - if cfg.Providers.VLLM.APIKey != "" { - t.Error("VLLM API key should be empty by default") - } - if cfg.Providers.Gemini.APIKey != "" { - t.Error("Gemini API key should be empty by default") - } } // TestDefaultConfig_Channels verifies channels are disabled by default func TestDefaultConfig_Channels(t *testing.T) { cfg := DefaultConfig() - // Verify all channels are disabled by default - if cfg.Channels.WhatsApp.Enabled { - t.Error("WhatsApp should be disabled by default") - } if cfg.Channels.Telegram.Enabled { t.Error("Telegram should be disabled by default") } - if cfg.Channels.Feishu.Enabled { - t.Error("Feishu should be disabled by default") - } if cfg.Channels.Discord.Enabled { t.Error("Discord should be disabled by default") } - if cfg.Channels.MaixCam.Enabled { - t.Error("MaixCam should be disabled by default") - } - if cfg.Channels.QQ.Enabled { - t.Error("QQ should be disabled by default") - } - if cfg.Channels.DingTalk.Enabled { - t.Error("DingTalk should be disabled by default") - } if cfg.Channels.Slack.Enabled { t.Error("Slack should be disabled by default") } @@ -178,7 
+328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) { func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() - // Verify complete config structure if cfg.Agents.Defaults.Workspace == "" { t.Error("Workspace should not be empty") } @@ -204,3 +353,42 @@ func TestConfig_Complete(t *testing.T) { t.Error("Heartbeat should be enabled by default") } } + +func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Providers.OpenAI.WebSearch { + t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true") + } +} + +func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if !cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should remain true when unset in config file") + } +} + +func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Providers.OpenAI.WebSearch { + t.Fatal("OpenAI codex web search should be false when disabled in config file") + } +} diff --git a/pkg/constants/channels.go b/pkg/constants/channels.go index 3e3df3839..0a46e6cd9 100644 --- a/pkg/constants/channels.go +++ b/pkg/constants/channels.go @@ -1,15 +1,16 @@ // Package constants provides shared constants across the codebase. 
package constants -// InternalChannels defines channels that are used for internal communication +// internalChannels defines channels that are used for internal communication // and should not be exposed to external users or recorded as last active channel. -var InternalChannels = map[string]bool{ - "cli": true, - "system": true, - "subagent": true, +var internalChannels = map[string]struct{}{ + "cli": {}, + "system": {}, + "subagent": {}, } // IsInternalChannel returns true if the channel is an internal channel. func IsInternalChannel(channel string) bool { - return InternalChannels[channel] + _, found := internalChannels[channel] + return found } diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index 8bb5e14c0..604178496 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -110,7 +110,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error case "anthropic": cfg.Providers.Anthropic = pc case "openai": - cfg.Providers.OpenAI = pc + cfg.Providers.OpenAI = config.OpenAIProviderConfig{ + ProviderConfig: pc, + WebSearch: getBoolOrDefault(pMap, "web_search", true), + } case "openrouter": cfg.Providers.OpenRouter = pc case "groq": @@ -374,6 +377,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) { return b, ok } +func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool { + if v, ok := getBool(data, key); ok { + return v + } + return defaultVal +} + func getStringSlice(data map[string]interface{}, key string) []string { v, ok := data[key] if !ok { diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index cd36043f7..d57ea1c03 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) { }) } +func TestSupportedProvidersCompatibility(t *testing.T) { + expected := []string{ + "anthropic", + "openai", + "openrouter", + "groq", + "zhipu", + "vllm", + "gemini", + } + + for _, provider 
:= range expected { + if !supportedProviders[provider] { + t.Fatalf("supportedProviders missing expected key %q", provider) + } + } +} + func TestMergeConfig(t *testing.T) { t.Run("fills empty fields", func(t *testing.T) { existing := config.DefaultConfig() diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go new file mode 100644 index 000000000..8f46aa70c --- /dev/null +++ b/pkg/providers/anthropic/provider.go @@ -0,0 +1,248 @@ +package anthropicprovider + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + +const defaultBaseURL = "https://api.anthropic.com" + +type Provider struct { + client *anthropic.Client + tokenSource func() (string, error) + baseURL string +} + +func NewProvider(token string) *Provider { + return NewProviderWithBaseURL(token, "") +} + +func NewProviderWithBaseURL(token, apiBase string) *Provider { + baseURL := normalizeBaseURL(apiBase) + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL(baseURL), + ) + return &Provider{ + client: &client, + baseURL: baseURL, + } +} + +func NewProviderWithClient(client *anthropic.Client) *Provider { + return &Provider{ + client: client, + baseURL: defaultBaseURL, + } +} + +func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { + return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "") +} + +func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase 
string) *Provider { + p := NewProviderWithBaseURL(token, apiBase) + p.tokenSource = tokenSource + return p +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(resp), nil +} + +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func (p *Provider) BaseURL() string { + return p.baseURL +} + +func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, 
anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateTools(tools) + } + + return params, nil +} + +func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + 
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err) + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + } + if base == "" { + return defaultBaseURL + } + + return base +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go new file mode 100644 index 000000000..6a1dabafb --- /dev/null +++ b/pkg/providers/anthropic/provider_test.go @@ -0,0 +1,265 @@ +package anthropicprovider + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, 
"claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildParams([]Message{{Role: 
"user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": 
[]map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestProvider_GetDefaultModel(t *testing.T) { + p := NewProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) { + p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/") + if got := p.BaseURL(); got != "https://api.anthropic.com" { + t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com") + } +} + +func TestProvider_ChatUsesTokenSource(t *testing.T) { + var requests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + atomic.AddInt32(&requests, 1) + + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody 
map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "ok"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 1, + "output_tokens": 1, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if got := atomic.LoadInt32(&requests); got != 1 { + t.Fatalf("requests = %d, want 1", got) + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index ae6aca96d..3ca54d5a3 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -2,200 +2,58 @@ package providers import ( "context" - "encoding/json" "fmt" - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/sipeed/picoclaw/pkg/auth" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) type ClaudeProvider struct { - client *anthropic.Client - tokenSource func() (string, error) + delegate *anthropicprovider.Provider } func NewClaudeProvider(token string) *ClaudeProvider { - client := anthropic.NewClient( - option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), - ) - return &ClaudeProvider{client: &client} + return &ClaudeProvider{ 
+ delegate: anthropicprovider.NewProvider(token), + } +} + +func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase), + } } func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { - p := NewClaudeProvider(token) - p.tokenSource = tokenSource - return p + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), + } +} + +func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase), + } +} + +func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { + return &ClaudeProvider{delegate: delegate} } func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - var opts []option.RequestOption - if p.tokenSource != nil { - tok, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("refreshing token: %w", err) - } - opts = append(opts, option.WithAuthToken(tok)) - } - - params, err := buildClaudeParams(messages, tools, model, options) + resp, err := p.delegate.Chat(ctx, messages, tools, model, options) if err != nil { return nil, err } - - resp, err := p.client.Messages.New(ctx, params, opts...) 
- if err != nil { - return nil, fmt.Errorf("claude API call: %w", err) - } - - return parseClaudeResponse(resp), nil + return resp, nil } func (p *ClaudeProvider) GetDefaultModel() string { - return "claude-sonnet-4-5-20250929" -} - -func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { - var system []anthropic.TextBlockParam - var anthropicMessages []anthropic.MessageParam - - for _, msg := range messages { - switch msg.Role { - case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) - case "user": - if msg.ToolCallID != "" { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - var blocks []anthropic.ContentBlockParamUnion - if msg.Content != "" { - blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) - } - for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "tool": - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } - } - - maxTokens := int64(4096) - if mt, ok := options["max_tokens"].(int); ok { - maxTokens = int64(mt) - } - - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: anthropicMessages, - MaxTokens: maxTokens, - } - - if len(system) > 0 { - params.System = system - } - - if temp, ok := options["temperature"].(float64); ok { 
- params.Temperature = anthropic.Float(temp) - } - - if len(tools) > 0 { - params.Tools = translateToolsForClaude(tools) - } - - return params, nil -} - -func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { - result := make([]anthropic.ToolUnionParam, 0, len(tools)) - for _, t := range tools { - tool := anthropic.ToolParam{ - Name: t.Function.Name, - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: t.Function.Parameters["properties"], - }, - } - if desc := t.Function.Description; desc != "" { - tool.Description = anthropic.String(desc) - } - if req, ok := t.Function.Parameters["required"].([]interface{}); ok { - required := make([]string, 0, len(req)) - for _, r := range req { - if s, ok := r.(string); ok { - required = append(required, s) - } - } - tool.InputSchema.Required = required - } - result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) - } - return result -} - -func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { - var content string - var toolCalls []ToolCall - - for _, block := range resp.Content { - switch block.Type { - case "text": - tb := block.AsText() - content += tb.Text - case "tool_use": - tu := block.AsToolUse() - var args map[string]interface{} - if err := json.Unmarshal(tu.Input, &args); err != nil { - args = map[string]interface{}{"raw": string(tu.Input)} - } - toolCalls = append(toolCalls, ToolCall{ - ID: tu.ID, - Name: tu.Name, - Arguments: args, - }) - } - } - - finishReason := "stop" - switch resp.StopReason { - case anthropic.StopReasonToolUse: - finishReason = "tool_calls" - case anthropic.StopReasonMaxTokens: - finishReason = "length" - case anthropic.StopReasonEndTurn: - finishReason = "stop" - } - - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.InputTokens + 
resp.Usage.OutputTokens), - }, - } + return p.delegate.GetDefaultModel() } func createClaudeTokenSource() func() (string, error) { return func() (string, error) { - cred, err := auth.GetCredential("anthropic") + cred, err := getCredential("anthropic") if err != nil { return "", fmt.Errorf("loading auth credentials: %w", err) } diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go index bbad2d269..13bbde1fc 100644 --- a/pkg/providers/claude_provider_test.go +++ b/pkg/providers/claude_provider_test.go @@ -8,140 +8,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) -func TestBuildClaudeParams_BasicMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "Hello"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ - "max_tokens": 1024, - }) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if string(params.Model) != "claude-sonnet-4-5-20250929" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") - } - if params.MaxTokens != 1024 { - t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_SystemMessage(t *testing.T) { - messages := []Message{ - {Role: "system", Content: "You are helpful"}, - {Role: "user", Content: "Hi"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.System) != 1 { - t.Fatalf("len(System) = %d, want 1", len(params.System)) - } - if params.System[0].Text != "You are helpful" { - t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") - } - if 
len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{ - { - ID: "call_1", - Name: "get_weather", - Arguments: map[string]interface{}{"city": "SF"}, - }, - }, - }, - {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Messages) != 3 { - t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) - } -} - -func TestBuildClaudeParams_WithTools(t *testing.T) { - tools := []ToolDefinition{ - { - Type: "function", - Function: ToolFunctionDefinition{ - Name: "get_weather", - Description: "Get weather for a city", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, - }, - "required": []interface{}{"city"}, - }, - }, - }, - } - params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Tools) != 1 { - t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) - } -} - -func TestParseClaudeResponse_TextOnly(t *testing.T) { - resp := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{}, - Usage: anthropic.Usage{ - InputTokens: 10, - OutputTokens: 20, - }, - } - result := parseClaudeResponse(resp) - if result.Usage.PromptTokens != 10 { - t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) - } - if result.Usage.CompletionTokens != 20 { - t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) - } - if result.FinishReason != "stop" { - 
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") - } -} - -func TestParseClaudeResponse_StopReasons(t *testing.T) { - tests := []struct { - stopReason anthropic.StopReason - want string - }{ - {anthropic.StopReasonEndTurn, "stop"}, - {anthropic.StopReasonMaxTokens, "length"}, - {anthropic.StopReasonToolUse, "tool_calls"}, - } - for _, tt := range tests { - resp := &anthropic.Message{ - StopReason: tt.stopReason, - } - result := parseClaudeResponse(resp) - if result.FinishReason != tt.want { - t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) - } - } -} - func TestClaudeProvider_ChatRoundTrip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/messages" { @@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { })) defer server.Close() - provider := NewClaudeProvider("test-token") - provider.client = createAnthropicTestClient(server.URL, "test-token") + delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + provider := newClaudeProviderWithDelegate(delegate) messages := []Message{{Role: "user", Content: "Hello"}} resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) diff --git a/pkg/providers/codex_cli_provider_integration_test.go b/pkg/providers/codex_cli_provider_integration_test.go new file mode 100644 index 000000000..0267c730f --- /dev/null +++ b/pkg/providers/codex_cli_provider_integration_test.go @@ -0,0 +1,119 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI. +// Run with: go test -tags=integration ./pkg/providers/... 
+func TestIntegration_RealCodexCLI(t *testing.T) { + path, err := exec.LookPath("codex") + if err != nil { + t.Skip("codex CLI not found in PATH, skipping integration test") + } + t.Logf("Using codex CLI at: %s", path) + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage != nil { + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + p := NewCodexCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. 
No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) { + if _, err := exec.LookPath("codex"); err != nil { + t.Skip("codex CLI not found in PATH") + } + + // Run codex directly and verify our parser handles real output + cmd := exec.Command("codex", "exec", + "--json", + "--dangerously-bypass-approvals-and-sandbox", + "--skip-git-repo-check", + "--color", "never", + "-C", t.TempDir(), + "-") + cmd.Stdin = strings.NewReader("Say hi") + + output, err := cmd.Output() + if err != nil { + // codex may write diagnostic noise to stderr but still produce valid output + if len(output) == 0 { + t.Fatalf("codex CLI failed: %v", err) + } + } + + t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)])) + + // Verify our parser can handle real output + p := NewCodexCliProvider("") + resp, err := p.parseJSONLEvents(string(output)) + if err != nil { + t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 6dff3a52e..e3526cfb5 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2" const codexDefaultInstructions = "You are Codex, a coding assistant." 
type CodexProvider struct { - client *openai.Client - accountID string - tokenSource func() (string, string, error) + client *openai.Client + accountID string + tokenSource func() (string, string, error) + enableWebSearch bool } const defaultCodexInstructions = "You are Codex, a coding assistant." @@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider { } client := openai.NewClient(opts...) return &CodexProvider{ - client: &client, - accountID: accountID, + client: &client, + accountID: accountID, + enableWebSearch: true, } } @@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To }) } - params := buildCodexParams(messages, tools, resolvedModel, options) + params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch) stream := p.client.Responses.NewStreaming(ctx, params, opts...) defer stream.Close() @@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) { return codexDefaultModel, "unsupported model family" } -func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams { var inputItems responses.ResponseInputParam var instructions string @@ -217,12 +219,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, }) } for _, tc := range msg.ToolCalls { - argsJSON, _ := json.Marshal(tc.Arguments) + name, args, ok := resolveCodexToolCall(tc) + if !ok { + logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{ + "call_id": tc.ID, + }) + continue + } inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ OfFunctionCall: &responses.ResponseFunctionToolCallParam{ CallID: tc.ID, - Name: tc.Name, - Arguments: string(argsJSON), + Name: name, 
+ Arguments: args, }, }) } @@ -260,20 +268,50 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, params.Instructions = openai.Opt(defaultCodexInstructions) } - if maxTokens, ok := options["max_tokens"].(int); ok { - params.MaxOutputTokens = openai.Opt(int64(maxTokens)) - } - - if len(tools) > 0 { - params.Tools = translateToolsForCodex(tools) + if len(tools) > 0 || enableWebSearch { + params.Tools = translateToolsForCodex(tools, enableWebSearch) } return params } -func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { - result := make([]responses.ToolUnionParam, 0, len(tools)) +func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) { + name = tc.Name + if name == "" && tc.Function != nil { + name = tc.Function.Name + } + if name == "" { + return "", "", false + } + + if len(tc.Arguments) > 0 { + argsJSON, err := json.Marshal(tc.Arguments) + if err != nil { + return "", "", false + } + return name, string(argsJSON), true + } + + if tc.Function != nil && tc.Function.Arguments != "" { + return name, tc.Function.Arguments, true + } + + return name, "{}", true +} + +func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam { + capHint := len(tools) + if enableWebSearch { + capHint++ + } + result := make([]responses.ToolUnionParam, 0, capHint) for _, t := range tools { + if t.Type != "function" { + continue + } + if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") { + continue + } ft := responses.FunctionToolParam{ Name: t.Function.Name, Parameters: t.Function.Parameters, @@ -284,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { } result = append(result, responses.ToolUnionParam{OfFunction: &ft}) } + if enableWebSearch { + result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch)) + } return result } diff --git a/pkg/providers/codex_provider_test.go 
b/pkg/providers/codex_provider_test.go index 317b1a5de..92e276165 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ "max_tokens": 2048, "temperature": 0.7, - }) + }, true) if params.Model != "gpt-4o" { t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") } @@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { if params.Instructions.Or("") != defaultCodexInstructions { t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions) } + if params.MaxOutputTokens.Valid() { + t.Fatalf("MaxOutputTokens should not be set for Codex backend") + } } func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { @@ -36,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { {Role: "system", Content: "You are helpful"}, {Role: "user", Content: "Hi"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true) if !params.Instructions.Valid() { t.Fatal("Instructions should be set") } @@ -56,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) { }, {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, } - params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false) if params.Input.OfInputItemList == nil { t.Fatal("Input.OfInputItemList should not be nil") } @@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) { } } +func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Read a file"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Type: "function", + Function: 
&FunctionCall{ + Name: "read_file", + Arguments: `{"path":"README.md"}`, + }, + }, + }, + }, + {Role: "tool", Content: "ok", ToolCallID: "call_1"}, + } + + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } + + fc := params.Input.OfInputItemList[1].OfFunctionCall + if fc == nil { + t.Fatal("assistant tool call should be converted to function_call input item") + } + if fc.Name != "read_file" { + t.Errorf("Function call name = %q, want %q", fc.Name, "read_file") + } + if fc.Arguments != `{"path":"README.md"}` { + t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`) + } +} + func TestBuildCodexParams_WithTools(t *testing.T) { tools := []ToolDefinition{ { @@ -81,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) { }, }, } - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false) if len(params.Tools) != 1 { t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) } @@ -94,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) { } func TestBuildCodexParams_StoreIsFalse(t *testing.T) { - params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false) if !params.Store.Valid() || params.Store.Or(true) != false { t.Error("Store should be explicitly set to false") } } +func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", 
map[string]interface{}{}, true) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfWebSearch == nil { + t.Fatal("Tool should include built-in web_search") + } + if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch { + t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch) + } +} + +func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "web_search", + Description: "local web search", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }, + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "read_file", + Description: "read file", + Parameters: map[string]interface{}{ + "type": "object", + }, + }, + }, + } + + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true) + if len(params.Tools) != 2 { + t.Fatalf("len(Tools) = %d, want 2", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" { + t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0]) + } + if params.Tools[1].OfWebSearch == nil { + t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1]) + } +} + func TestParseCodexResponse_TextOutput(t *testing.T) { respJSON := `{ "id": "resp_test", @@ -214,6 +305,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { http.Error(w, "stream must be true", http.StatusBadRequest) return } + if _, ok := reqBody["max_output_tokens"]; ok { + http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest) + return + } + toolsAny, ok := reqBody["tools"].([]interface{}) + if !ok || len(toolsAny) != 1 { + http.Error(w, "missing default web search tool", http.StatusBadRequest) + return + } + toolObj, ok := 
toolsAny[0].(map[string]interface{}) + if !ok || toolObj["type"] != "web_search" { + http.Error(w, "expected web_search tool", http.StatusBadRequest) + return + } resp := map[string]interface{}{ "id": "resp_test", @@ -261,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["tools"]; ok { + http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 4, + "output_tokens": 3, + "total_tokens": 7, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.enableWebSearch = false + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/responses" { @@ -293,6 +456,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) http.Error(w, "temperature is not supported", http.StatusBadRequest) return } + if _, ok := reqBody["max_output_tokens"]; ok { + http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest) + return + } if reqBody["stream"] != true { http.Error(w, "stream must be true", http.StatusBadRequest) return diff --git a/pkg/providers/cooldown.go b/pkg/providers/cooldown.go new file mode 100644 index 000000000..b0d8608dc --- /dev/null +++ b/pkg/providers/cooldown.go @@ -0,0 +1,207 @@ +package providers + +import ( + "math" + "sync" + "time" +) + +const ( + defaultFailureWindow = 24 * time.Hour +) + +// CooldownTracker manages per-provider cooldown state for the fallback chain. +// Thread-safe via sync.RWMutex. In-memory only (resets on restart). +type CooldownTracker struct { + mu sync.RWMutex + entries map[string]*cooldownEntry + failureWindow time.Duration + nowFunc func() time.Time // for testing +} + +type cooldownEntry struct { + ErrorCount int + FailureCounts map[FailoverReason]int + CooldownEnd time.Time // standard cooldown expiry + DisabledUntil time.Time // billing-specific disable expiry + DisabledReason FailoverReason // reason for disable (billing) + LastFailure time.Time +} + +// NewCooldownTracker creates a tracker with default 24h failure window. +func NewCooldownTracker() *CooldownTracker { + return &CooldownTracker{ + entries: make(map[string]*cooldownEntry), + failureWindow: defaultFailureWindow, + nowFunc: time.Now, + } +} + +// MarkFailure records a failure for a provider and sets appropriate cooldown. 
+// Resets error counts if last failure was more than failureWindow ago. +func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) { + ct.mu.Lock() + defer ct.mu.Unlock() + + now := ct.nowFunc() + entry := ct.getOrCreate(provider) + + // 24h failure window reset: if no failure in failureWindow, reset counters. + if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow { + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + } + + entry.ErrorCount++ + entry.FailureCounts[reason]++ + entry.LastFailure = now + + if reason == FailoverBilling { + billingCount := entry.FailureCounts[FailoverBilling] + entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount)) + entry.DisabledReason = FailoverBilling + } else { + entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount)) + } +} + +// MarkSuccess resets all counters and cooldowns for a provider. +func (ct *CooldownTracker) MarkSuccess(provider string) { + ct.mu.Lock() + defer ct.mu.Unlock() + + entry := ct.entries[provider] + if entry == nil { + return + } + + entry.ErrorCount = 0 + entry.FailureCounts = make(map[FailoverReason]int) + entry.CooldownEnd = time.Time{} + entry.DisabledUntil = time.Time{} + entry.DisabledReason = "" +} + +// IsAvailable returns true if the provider is not in cooldown or disabled. +func (ct *CooldownTracker) IsAvailable(provider string) bool { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return true + } + + now := ct.nowFunc() + + // Billing disable takes precedence (longer cooldown). + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + return false + } + + // Standard cooldown. + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + return false + } + + return true +} + +// CooldownRemaining returns how long until the provider becomes available. +// Returns 0 if already available. 
+func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + + now := ct.nowFunc() + var remaining time.Duration + + if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) { + d := entry.DisabledUntil.Sub(now) + if d > remaining { + remaining = d + } + } + + if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) { + d := entry.CooldownEnd.Sub(now) + if d > remaining { + remaining = d + } + } + + return remaining +} + +// ErrorCount returns the current error count for a provider. +func (ct *CooldownTracker) ErrorCount(provider string) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.ErrorCount +} + +// FailureCount returns the failure count for a specific reason. +func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int { + ct.mu.RLock() + defer ct.mu.RUnlock() + + entry := ct.entries[provider] + if entry == nil { + return 0 + } + return entry.FailureCounts[reason] +} + +func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry { + entry := ct.entries[provider] + if entry == nil { + entry = &cooldownEntry{ + FailureCounts: make(map[FailoverReason]int), + } + ct.entries[provider] = entry + } + return entry +} + +// calculateStandardCooldown computes standard exponential backoff. +// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3)) +// +// 1 error → 1 min +// 2 errors → 5 min +// 3 errors → 25 min +// 4+ errors → 1 hour (cap) +func calculateStandardCooldown(errorCount int) time.Duration { + n := max(1, errorCount) + exp := min(n-1, 3) + ms := 60_000 * int(math.Pow(5, float64(exp))) + ms = min(3_600_000, ms) // cap at 1 hour + return time.Duration(ms) * time.Millisecond +} + +// calculateBillingCooldown computes billing-specific exponential backoff. 
+// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10)) +// +// 1 error → 5 hours +// 2 errors → 10 hours +// 3 errors → 20 hours +// 4+ errors → 24 hours (cap) +func calculateBillingCooldown(billingErrorCount int) time.Duration { + const baseMs = 5 * 60 * 60 * 1000 // 5 hours + const maxMs = 24 * 60 * 60 * 1000 // 24 hours + + n := max(1, billingErrorCount) + exp := min(n-1, 10) + raw := float64(baseMs) * math.Pow(2, float64(exp)) + ms := int(math.Min(float64(maxMs), raw)) + return time.Duration(ms) * time.Millisecond +} diff --git a/pkg/providers/cooldown_test.go b/pkg/providers/cooldown_test.go new file mode 100644 index 000000000..47f43ad5c --- /dev/null +++ b/pkg/providers/cooldown_test.go @@ -0,0 +1,269 @@ +package providers + +import ( + "sync" + "testing" + "time" +) + +func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) { + current := now + ct := NewCooldownTracker() + ct.nowFunc = func() time.Time { return current } + return ct, ¤t +} + +func TestCooldown_InitiallyAvailable(t *testing.T) { + ct := NewCooldownTracker() + if !ct.IsAvailable("openai") { + t.Error("new provider should be available") + } + if ct.ErrorCount("openai") != 0 { + t.Error("new provider should have 0 errors") + } +} + +func TestCooldown_StandardEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st error → 1 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown after 1st error") + } + + // Advance 61 seconds → available + *current = now.Add(61 * time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 1 min cooldown") + } + + // 2nd error → 5 min cooldown + ct.MarkFailure("openai", FailoverRateLimit) + *current = now.Add(61*time.Second + 4*time.Minute) + if ct.IsAvailable("openai") { + t.Error("should be in cooldown (5 min) after 2nd error") + } + *current = now.Add(61*time.Second + 6*time.Minute) + if !ct.IsAvailable("openai") { + 
t.Error("should be available after 5 min cooldown") + } +} + +func TestCooldown_StandardCap(t *testing.T) { + // Verify formula: 1m, 5m, 25m, 1h, 1h, 1h... + expected := []time.Duration{ + 1 * time.Minute, + 5 * time.Minute, + 25 * time.Minute, + 1 * time.Hour, + 1 * time.Hour, + } + + for i, want := range expected { + got := calculateStandardCooldown(i + 1) + if got != want { + t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_BillingEscalation(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 1st billing error → 5h cooldown + ct.MarkFailure("openai", FailoverBilling) + if ct.IsAvailable("openai") { + t.Error("should be disabled after billing error") + } + + // Advance 4h → still disabled + *current = now.Add(4 * time.Hour) + if ct.IsAvailable("openai") { + t.Error("should still be disabled (5h cooldown)") + } + + // Advance 5h + 1s → available + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after 5h billing cooldown") + } +} + +func TestCooldown_BillingCap(t *testing.T) { + expected := []time.Duration{ + 5 * time.Hour, + 10 * time.Hour, + 20 * time.Hour, + 24 * time.Hour, + 24 * time.Hour, + } + + for i, want := range expected { + got := calculateBillingCooldown(i + 1) + if got != want { + t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want) + } + } +} + +func TestCooldown_SuccessReset(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + if ct.ErrorCount("openai") != 2 { + t.Errorf("error count = %d, want 2", ct.ErrorCount("openai")) + } + + ct.MarkSuccess("openai") + if ct.ErrorCount("openai") != 0 { + t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai")) + } + if !ct.IsAvailable("openai") { + t.Error("should be available after success") + } + if ct.FailureCount("openai", FailoverRateLimit) 
!= 0 { + t.Error("failure counts should be reset after success") + } + if ct.FailureCount("openai", FailoverBilling) != 0 { + t.Error("billing failure count should be reset after success") + } +} + +func TestCooldown_FailureWindowReset(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // 4 errors → 1h cooldown + for i := 0; i < 4; i++ { + ct.MarkFailure("openai", FailoverRateLimit) + *current = current.Add(2 * time.Second) // small advance between errors + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("error count = %d, want 4", ct.ErrorCount("openai")) + } + + // Advance 25 hours (past 24h failure window) + *current = now.Add(25 * time.Hour) + + // Next error should reset counters first, then increment to 1 + ct.MarkFailure("openai", FailoverRateLimit) + if ct.ErrorCount("openai") != 1 { + t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai")) + } +} + +func TestCooldown_PerReasonTracking(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("openai", FailoverBilling) + ct.MarkFailure("openai", FailoverAuth) + + if ct.FailureCount("openai", FailoverRateLimit) != 2 { + t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit)) + } + if ct.FailureCount("openai", FailoverBilling) != 1 { + t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling)) + } + if ct.FailureCount("openai", FailoverAuth) != 1 { + t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth)) + } + if ct.ErrorCount("openai") != 4 { + t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai")) + } +} + +func TestCooldown_BillingTakesPrecedence(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // Standard cooldown (1 min) + billing disable (5h) + ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown + 
ct.MarkFailure("openai", FailoverBilling) // 5h disable + + // After 2 min: standard cooldown expired but billing still active + *current = now.Add(2 * time.Minute) + if ct.IsAvailable("openai") { + t.Error("billing disable should take precedence over standard cooldown") + } + + // After 5h + 1s: both expired + *current = now.Add(5*time.Hour + 1*time.Second) + if !ct.IsAvailable("openai") { + t.Error("should be available after all cooldowns expire") + } +} + +func TestCooldown_CooldownRemaining(t *testing.T) { + now := time.Now() + ct, current := newTestTracker(now) + + // No failures → 0 remaining + if ct.CooldownRemaining("openai") != 0 { + t.Error("expected 0 remaining for new provider") + } + + ct.MarkFailure("openai", FailoverRateLimit) + + *current = now.Add(30 * time.Second) + remaining := ct.CooldownRemaining("openai") + if remaining <= 0 || remaining > 1*time.Minute { + t.Errorf("remaining = %v, expected ~30s", remaining) + } +} + +func TestCooldown_SuccessOnUnknownProvider(t *testing.T) { + ct := NewCooldownTracker() + // Should not panic + ct.MarkSuccess("nonexistent") + if !ct.IsAvailable("nonexistent") { + t.Error("nonexistent provider should be available") + } +} + +func TestCooldown_ConcurrentAccess(t *testing.T) { + ct := NewCooldownTracker() + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(3) + go func() { + defer wg.Done() + ct.MarkFailure("openai", FailoverRateLimit) + }() + go func() { + defer wg.Done() + ct.IsAvailable("openai") + }() + go func() { + defer wg.Done() + ct.MarkSuccess("openai") + }() + } + + wg.Wait() + // If we got here without panic, concurrent access is safe +} + +func TestCooldown_MultipleProviders(t *testing.T) { + ct := NewCooldownTracker() + + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + if ct.IsAvailable("openai") { + t.Error("openai should be in cooldown") + } + if ct.IsAvailable("anthropic") { + t.Error("anthropic should be in cooldown") + } + // groq was 
never touched + if !ct.IsAvailable("groq") { + t.Error("groq should be available") + } +} diff --git a/pkg/providers/error_classifier.go b/pkg/providers/error_classifier.go new file mode 100644 index 000000000..a0f003006 --- /dev/null +++ b/pkg/providers/error_classifier.go @@ -0,0 +1,253 @@ +package providers + +import ( + "context" + "regexp" + "strings" +) + +// errorPattern defines a single pattern (string or regex) for error classification. +type errorPattern struct { + substring string + regex *regexp.Regexp +} + +func substr(s string) errorPattern { return errorPattern{substring: s} } +func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} } + +// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns). +var ( + rateLimitPatterns = []errorPattern{ + rxp(`rate[_ ]limit`), + substr("too many requests"), + substr("429"), + substr("exceeded your current quota"), + rxp(`exceeded.*quota`), + rxp(`resource has been exhausted`), + rxp(`resource.*exhausted`), + substr("resource_exhausted"), + substr("quota exceeded"), + substr("usage limit"), + } + + overloadedPatterns = []errorPattern{ + rxp(`overloaded_error`), + rxp(`"type"\s*:\s*"overloaded_error"`), + substr("overloaded"), + } + + timeoutPatterns = []errorPattern{ + substr("timeout"), + substr("timed out"), + substr("deadline exceeded"), + substr("context deadline exceeded"), + } + + billingPatterns = []errorPattern{ + rxp(`\b402\b`), + substr("payment required"), + substr("insufficient credits"), + substr("credit balance"), + substr("plans & billing"), + substr("insufficient balance"), + } + + authPatterns = []errorPattern{ + rxp(`invalid[_ ]?api[_ ]?key`), + substr("incorrect api key"), + substr("invalid token"), + substr("authentication"), + substr("re-authenticate"), + substr("oauth token refresh failed"), + substr("unauthorized"), + substr("forbidden"), + substr("access denied"), + substr("expired"), + substr("token has expired"), + 
rxp(`\b401\b`), + rxp(`\b403\b`), + substr("no credentials found"), + substr("no api key found"), + } + + formatPatterns = []errorPattern{ + substr("string should match pattern"), + substr("tool_use.id"), + substr("tool_use_id"), + substr("messages.1.content.1.tool_use.id"), + substr("invalid request format"), + } + + imageDimensionPatterns = []errorPattern{ + rxp(`image dimensions exceed max`), + } + + imageSizePatterns = []errorPattern{ + rxp(`image exceeds.*mb`), + } + + // Transient HTTP status codes that map to timeout (server-side failures). + transientStatusCodes = map[int]bool{ + 500: true, 502: true, 503: true, + 521: true, 522: true, 523: true, 524: true, + 529: true, + } +) + +// ClassifyError classifies an error into a FailoverError with reason. +// Returns nil if the error is not classifiable (unknown errors should not trigger fallback). +func ClassifyError(err error, provider, model string) *FailoverError { + if err == nil { + return nil + } + + // Context cancellation: user abort, never fallback. + if err == context.Canceled { + return nil + } + + // Context deadline exceeded: treat as timeout, always fallback. + if err == context.DeadlineExceeded { + return &FailoverError{ + Reason: FailoverTimeout, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + msg := strings.ToLower(err.Error()) + + // Image dimension/size errors: non-retriable, non-fallback. + if IsImageDimensionError(msg) || IsImageSizeError(msg) { + return &FailoverError{ + Reason: FailoverFormat, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + // Try HTTP status code extraction first. + if status := extractHTTPStatus(msg); status > 0 { + if reason := classifyByStatus(status); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Status: status, + Wrapped: err, + } + } + } + + // Message pattern matching (priority order from OpenClaw). 
+ if reason := classifyByMessage(msg); reason != "" { + return &FailoverError{ + Reason: reason, + Provider: provider, + Model: model, + Wrapped: err, + } + } + + return nil +} + +// classifyByStatus maps HTTP status codes to FailoverReason. +func classifyByStatus(status int) FailoverReason { + switch { + case status == 401 || status == 403: + return FailoverAuth + case status == 402: + return FailoverBilling + case status == 408: + return FailoverTimeout + case status == 429: + return FailoverRateLimit + case status == 400: + return FailoverFormat + case transientStatusCodes[status]: + return FailoverTimeout + } + return "" +} + +// classifyByMessage matches error messages against patterns. +// Priority order matters (from OpenClaw classifyFailoverReason). +func classifyByMessage(msg string) FailoverReason { + if matchesAny(msg, rateLimitPatterns) { + return FailoverRateLimit + } + if matchesAny(msg, overloadedPatterns) { + return FailoverRateLimit // Overloaded treated as rate_limit + } + if matchesAny(msg, billingPatterns) { + return FailoverBilling + } + if matchesAny(msg, timeoutPatterns) { + return FailoverTimeout + } + if matchesAny(msg, authPatterns) { + return FailoverAuth + } + if matchesAny(msg, formatPatterns) { + return FailoverFormat + } + return "" +} + +// extractHTTPStatus extracts an HTTP status code from an error message. +// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429". +func extractHTTPStatus(msg string) int { + // Common patterns in Go HTTP error messages + patterns := []*regexp.Regexp{ + regexp.MustCompile(`status[:\s]+(\d{3})`), + regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`), + } + + for _, p := range patterns { + if m := p.FindStringSubmatch(msg); len(m) > 1 { + return parseDigits(m[1]) + } + } + + return 0 +} + +// IsImageDimensionError returns true if the message indicates an image dimension error. 
+func IsImageDimensionError(msg string) bool { + return matchesAny(msg, imageDimensionPatterns) +} + +// IsImageSizeError returns true if the message indicates an image file size error. +func IsImageSizeError(msg string) bool { + return matchesAny(msg, imageSizePatterns) +} + +// matchesAny checks if msg matches any of the patterns. +func matchesAny(msg string, patterns []errorPattern) bool { + for _, p := range patterns { + if p.regex != nil { + if p.regex.MatchString(msg) { + return true + } + } else if p.substring != "" { + if strings.Contains(msg, p.substring) { + return true + } + } + } + return false +} + +// parseDigits converts a string of digits to an int. +func parseDigits(s string) int { + n := 0 + for _, c := range s { + if c >= '0' && c <= '9' { + n = n*10 + int(c-'0') + } + } + return n +} diff --git a/pkg/providers/error_classifier_test.go b/pkg/providers/error_classifier_test.go new file mode 100644 index 000000000..865aea57a --- /dev/null +++ b/pkg/providers/error_classifier_test.go @@ -0,0 +1,337 @@ +package providers + +import ( + "context" + "errors" + "fmt" + "testing" +) + +func TestClassifyError_Nil(t *testing.T) { + result := ClassifyError(nil, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for nil error, got %+v", result) + } +} + +func TestClassifyError_ContextCanceled(t *testing.T) { + result := ClassifyError(context.Canceled, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for context.Canceled (user abort), got %+v", result) + } +} + +func TestClassifyError_ContextDeadlineExceeded(t *testing.T) { + result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4") + if result == nil { + t.Fatal("expected non-nil for deadline exceeded") + } + if result.Reason != FailoverTimeout { + t.Errorf("reason = %q, want timeout", result.Reason) + } +} + +func TestClassifyError_StatusCodes(t *testing.T) { + tests := []struct { + status int + reason FailoverReason + }{ + {401, FailoverAuth}, + {403, FailoverAuth}, 
+ {402, FailoverBilling}, + {408, FailoverTimeout}, + {429, FailoverRateLimit}, + {400, FailoverFormat}, + {500, FailoverTimeout}, + {502, FailoverTimeout}, + {503, FailoverTimeout}, + {521, FailoverTimeout}, + {522, FailoverTimeout}, + {523, FailoverTimeout}, + {524, FailoverTimeout}, + {529, FailoverTimeout}, + } + + for _, tt := range tests { + err := fmt.Errorf("API error: status: %d something went wrong", tt.status) + result := ClassifyError(err, "test", "model") + if result == nil { + t.Errorf("status %d: expected non-nil", tt.status) + continue + } + if result.Reason != tt.reason { + t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason) + } + } +} + +func TestClassifyError_RateLimitPatterns(t *testing.T) { + patterns := []string{ + "rate limit exceeded", + "rate_limit reached", + "too many requests", + "exceeded your current quota", + "resource has been exhausted", + "resource_exhausted", + "quota exceeded", + "usage limit reached", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_OverloadedPatterns(t *testing.T) { + patterns := []string{ + "overloaded_error", + `{"type": "overloaded_error"}`, + "server is overloaded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + // Overloaded is treated as rate_limit + if result.Reason != FailoverRateLimit { + t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason) + } + } +} + +func TestClassifyError_BillingPatterns(t *testing.T) { + patterns := []string{ + "payment required", + "insufficient credits", + "credit balance too 
low", + "plans & billing page", + "insufficient balance", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverBilling { + t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason) + } + } +} + +func TestClassifyError_TimeoutPatterns(t *testing.T) { + patterns := []string{ + "request timeout", + "connection timed out", + "deadline exceeded", + "context deadline exceeded", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverTimeout { + t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason) + } + } +} + +func TestClassifyError_AuthPatterns(t *testing.T) { + patterns := []string{ + "invalid api key", + "invalid_api_key", + "incorrect api key", + "invalid token", + "authentication failed", + "re-authenticate", + "oauth token refresh failed", + "unauthorized access", + "forbidden", + "access denied", + "expired", + "token has expired", + "no credentials found", + "no api key found", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "openai", "gpt-4") + if result == nil { + t.Errorf("pattern %q: expected non-nil", msg) + continue + } + if result.Reason != FailoverAuth { + t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason) + } + } +} + +func TestClassifyError_FormatPatterns(t *testing.T) { + patterns := []string{ + "string should match pattern", + "tool_use.id is required", + "invalid tool_use_id", + "messages.1.content.1.tool_use.id must be valid", + "invalid request format", + } + + for _, msg := range patterns { + err := errors.New(msg) + result := ClassifyError(err, "anthropic", "claude") + if result == nil { + t.Errorf("pattern %q: 
expected non-nil", msg) + continue + } + if result.Reason != FailoverFormat { + t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason) + } + } +} + +func TestClassifyError_ImageDimensionError(t *testing.T) { + err := errors.New("image dimensions exceed max allowed 2048x2048") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image dimension error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } + if result.IsRetriable() { + t.Error("image dimension error should not be retriable") + } +} + +func TestClassifyError_ImageSizeError(t *testing.T) { + err := errors.New("image exceeds 20 mb limit") + result := ClassifyError(err, "openai", "gpt-4o") + if result == nil { + t.Fatal("expected non-nil for image size error") + } + if result.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", result.Reason) + } +} + +func TestClassifyError_UnknownError(t *testing.T) { + err := errors.New("some completely random error") + result := ClassifyError(err, "openai", "gpt-4") + if result != nil { + t.Errorf("expected nil for unknown error, got %+v", result) + } +} + +func TestClassifyError_ProviderModelPropagation(t *testing.T) { + err := errors.New("rate limit exceeded") + result := ClassifyError(err, "my-provider", "my-model") + if result == nil { + t.Fatal("expected non-nil") + } + if result.Provider != "my-provider" { + t.Errorf("provider = %q, want my-provider", result.Provider) + } + if result.Model != "my-model" { + t.Errorf("model = %q, want my-model", result.Model) + } +} + +func TestFailoverError_IsRetriable(t *testing.T) { + tests := []struct { + reason FailoverReason + retriable bool + }{ + {FailoverAuth, true}, + {FailoverRateLimit, true}, + {FailoverBilling, true}, + {FailoverTimeout, true}, + {FailoverOverloaded, true}, + {FailoverFormat, false}, + {FailoverUnknown, true}, + } + + for _, tt := range tests { + fe := &FailoverError{Reason: 
tt.reason} + if fe.IsRetriable() != tt.retriable { + t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable) + } + } +} + +func TestFailoverError_ErrorString(t *testing.T) { + fe := &FailoverError{ + Reason: FailoverRateLimit, + Provider: "openai", + Model: "gpt-4", + Status: 429, + Wrapped: errors.New("too many requests"), + } + s := fe.Error() + if s == "" { + t.Error("expected non-empty error string") + } +} + +func TestFailoverError_Unwrap(t *testing.T) { + inner := errors.New("inner error") + fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner} + if fe.Unwrap() != inner { + t.Error("Unwrap should return wrapped error") + } +} + +func TestExtractHTTPStatus(t *testing.T) { + tests := []struct { + msg string + want int + }{ + {"status: 429 rate limited", 429}, + {"status 401 unauthorized", 401}, + {"HTTP/1.1 502 Bad Gateway", 502}, + {"no status code here", 0}, + {"random number 12345", 0}, + } + + for _, tt := range tests { + got := extractHTTPStatus(tt.msg) + if got != tt.want { + t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want) + } + } +} + +func TestIsImageDimensionError(t *testing.T) { + if !IsImageDimensionError("image dimensions exceed max 4096x4096") { + t.Error("should match image dimensions exceed max") + } + if IsImageDimensionError("normal error message") { + t.Error("should not match normal error") + } +} + +func TestIsImageSizeError(t *testing.T) { + if !IsImageSizeError("image exceeds 20 mb") { + t.Error("should match image exceeds mb") + } + if IsImageSizeError("normal error message") { + t.Error("should not match normal error") + } +} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go new file mode 100644 index 000000000..e39cfe32b --- /dev/null +++ b/pkg/providers/factory.go @@ -0,0 +1,360 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +const defaultAnthropicAPIBase = 
"https://api.anthropic.com/v1" + +var getCredential = auth.GetCredential + +type providerType int + +const ( + providerTypeHTTPCompat providerType = iota + providerTypeClaudeAuth + providerTypeCodexAuth + providerTypeCodexCLIToken + providerTypeClaudeCLI + providerTypeCodexCLI + providerTypeGitHubCopilot +) + +type providerSelection struct { + providerType providerType + apiKey string + apiBase string + proxy string + model string + workspace string + connectMode string + enableWebSearch bool +} + +func createClaudeAuthProvider(apiBase string) (LLMProvider, error) { + if apiBase == "" { + apiBase = defaultAnthropicAPIBase + } + cred, err := getCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil +} + +func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) { + cred, err := getCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()) + p.enableWebSearch = enableWebSearch + return p, nil +} + +func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { + model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + lowerModel := strings.ToLower(model) + + sel := providerSelection{ + providerType: providerTypeHTTPCompat, + model: model, + } + + // First, prefer explicit provider configuration. 
+ if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + sel.apiKey = cfg.Providers.Zhipu.APIKey + 
sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + sel.apiKey = cfg.Providers.ShengSuanYun.APIKey + sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + sel.proxy = cfg.Providers.ShengSuanYun.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "nvidia": + if cfg.Providers.Nvidia.APIKey != "" { + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + } + case "claude-cli", "claude-code", "claudecode": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeClaudeCLI + sel.workspace = workspace + return sel, nil + case "codex-cli", "codex-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." 
+ } + sel.providerType = providerTypeCodexCLI + sel.workspace = workspace + return sel, nil + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + sel.apiKey = cfg.Providers.DeepSeek.APIKey + sel.apiBase = cfg.Providers.DeepSeek.APIBase + sel.proxy = cfg.Providers.DeepSeek.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + sel.model = "deepseek-chat" + } + } + case "github_copilot", "copilot": + sel.providerType = providerTypeGitHubCopilot + if cfg.Providers.GitHubCopilot.APIBase != "" { + sel.apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + sel.apiBase = "localhost:4321" + } + sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode + return sel, nil + } + } + + // Fallback: infer provider from model and configured keys. + if sel.apiKey == "" && sel.apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + sel.apiKey = cfg.Providers.Moonshot.APIKey + sel.apiBase = cfg.Providers.Moonshot.APIBase + sel.proxy = cfg.Providers.Moonshot.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.moonshot.cn/v1" + } + case strings.HasPrefix(model, "openrouter/") || + strings.HasPrefix(model, "anthropic/") || + strings.HasPrefix(model, "openai/") || + strings.HasPrefix(model, "meta-llama/") || + strings.HasPrefix(model, "deepseek/") || + strings.HasPrefix(model, "google/"): + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && + (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if 
cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && + (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) 
&& cfg.Providers.Groq.APIKey != "": + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": + sel.apiKey = cfg.Providers.Ollama.APIKey + sel.apiBase = cfg.Providers.Ollama.APIBase + sel.proxy = cfg.Providers.Ollama.Proxy + if sel.apiBase == "" { + sel.apiBase = "http://localhost:11434/v1" + } + case cfg.Providers.VLLM.APIBase != "": + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + default: + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } else { + return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if sel.providerType == providerTypeHTTPCompat { + if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model) + } + if sel.apiBase == "" { + return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) + } + } + + return sel, nil +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + sel, err := resolveProviderSelection(cfg) + if err != nil { + return nil, err 
+ } + + switch sel.providerType { + case providerTypeClaudeAuth: + return createClaudeAuthProvider(sel.apiBase) + case providerTypeCodexAuth: + return createCodexAuthProvider(sel.enableWebSearch) + case providerTypeCodexCLIToken: + c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()) + c.enableWebSearch = sel.enableWebSearch + return c, nil + case providerTypeClaudeCLI: + return NewClaudeCliProvider(sel.workspace), nil + case providerTypeCodexCLI: + return NewCodexCliProvider(sel.workspace), nil + case providerTypeGitHubCopilot: + return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model) + default: + return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil + } +} diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go new file mode 100644 index 000000000..e31737eb9 --- /dev/null +++ b/pkg/providers/factory_test.go @@ -0,0 +1,299 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestResolveProviderSelection(t *testing.T) { + tests := []struct { + name string + setup func(*config.Config) + wantType providerType + wantAPIBase string + wantProxy string + wantErrSubstr string + }{ + { + name: "explicit claude-cli provider routes to cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeClaudeCLI, + }, + { + name: "explicit copilot provider routes to github copilot type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "copilot" + }, + wantType: providerTypeGitHubCopilot, + wantAPIBase: "localhost:4321", + }, + { + name: "explicit deepseek provider uses deepseek defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "deepseek" + cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" + cfg.Providers.DeepSeek.APIKey = "deepseek-key" + 
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.deepseek.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit shengsuanyun provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "shengsuanyun" + cfg.Providers.ShengSuanYun.APIKey = "ssy-key" + cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://router.shengsuanyun.com/api/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit nvidia provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "nvidia" + cfg.Providers.Nvidia.APIKey = "nvapi-test" + cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://integrate.api.nvidia.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "openrouter model uses openrouter defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://openrouter.ai/api/v1", + }, + { + name: "anthropic oauth routes to claude auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929" + cfg.Providers.Anthropic.AuthMethod = "oauth" + }, + wantType: providerTypeClaudeAuth, + }, + { + name: "openai oauth routes to codex auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "oauth" + }, + wantType: providerTypeCodexAuth, + }, + { + name: "openai codex-cli auth routes to codex cli token provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "codex-cli" + }, + wantType: providerTypeCodexCLIToken, + }, + { + name: "explicit codex-code provider routes to 
codex cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "codex-code" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeCodexCLI, + }, + { + name: "zhipu model uses zhipu base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "glm-4.7" + cfg.Providers.Zhipu.APIKey = "zhipu-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + name: "groq model uses groq base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" + cfg.Providers.Groq.APIKey = "gsk-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.groq.com/openai/v1", + }, + { + name: "ollama model uses ollama base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b" + cfg.Providers.Ollama.APIKey = "ollama-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:11434/v1", + }, + { + name: "moonshot model keeps proxy and default base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5" + cfg.Providers.Moonshot.APIKey = "moonshot-key" + cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.moonshot.cn/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "missing keys returns model config error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "custom-model" + }, + wantErrSubstr: "no API key configured for model", + }, + { + name: "openrouter prefix without key returns provider key error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + }, + wantErrSubstr: "no API key configured for provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.DefaultConfig() + tt.setup(cfg) + + got, err := resolveProviderSelection(cfg) + if 
tt.wantErrSubstr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) + } + return + } + + if err != nil { + t.Fatalf("resolveProviderSelection() error = %v", err) + } + if got.providerType != tt.wantType { + t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) + } + if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { + t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) + } + if tt.wantProxy != "" && got.proxy != tt.wantProxy { + t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) + } + }) + } +} + +func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("provider type = %T, want *HTTPProvider", provider) + } +} + +func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "codex-code" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexCliProvider); !ok { + t.Fatalf("provider type = %T, want *CodexCliProvider", provider) + } +} + +func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "codex-cli" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} + +func 
TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "anthropic" { + t.Fatalf("provider = %q, want anthropic", provider) + } + return &auth.AuthCredential{ + AccessToken: "anthropic-token", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "anthropic" + cfg.Providers.Anthropic.AuthMethod = "oauth" + cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + claudeProvider, ok := provider.(*ClaudeProvider) + if !ok { + t.Fatalf("provider type = %T, want *ClaudeProvider", provider) + } + if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" { + t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com") + } +} + +func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "openai" { + t.Fatalf("provider = %q, want openai", provider) + } + return &auth.AuthCredential{ + AccessToken: "openai-token", + AccountID: "acct_123", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "oauth" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go new file mode 100644 index 000000000..9b07f9153 --- /dev/null +++ b/pkg/providers/fallback.go @@ -0,0 +1,283 @@ +package 
providers + +import ( + "context" + "fmt" + "strings" + "time" +) + +// FallbackChain orchestrates model fallback across multiple candidates. +type FallbackChain struct { + cooldown *CooldownTracker +} + +// FallbackCandidate represents one model/provider to try. +type FallbackCandidate struct { + Provider string + Model string +} + +// FallbackResult contains the successful response and metadata about all attempts. +type FallbackResult struct { + Response *LLMResponse + Provider string + Model string + Attempts []FallbackAttempt +} + +// FallbackAttempt records one attempt in the fallback chain. +type FallbackAttempt struct { + Provider string + Model string + Error error + Reason FailoverReason + Duration time.Duration + Skipped bool // true if skipped due to cooldown +} + +// NewFallbackChain creates a new fallback chain with the given cooldown tracker. +func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain { + return &FallbackChain{cooldown: cooldown} +} + +// ResolveCandidates parses model config into a deduplicated candidate list. +func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate { + seen := make(map[string]bool) + var candidates []FallbackCandidate + + addCandidate := func(raw string) { + ref := ParseModelRef(raw, defaultProvider) + if ref == nil { + return + } + key := ModelKey(ref.Provider, ref.Model) + if seen[key] { + return + } + seen[key] = true + candidates = append(candidates, FallbackCandidate{ + Provider: ref.Provider, + Model: ref.Model, + }) + } + + // Primary first. + addCandidate(cfg.Primary) + + // Then fallbacks. + for _, fb := range cfg.Fallbacks { + addCandidate(fb) + } + + return candidates +} + +// Execute runs the fallback chain for text/chat requests. +// It tries each candidate in order, respecting cooldowns and error classification. +// +// Behavior: +// - Candidates in cooldown are skipped (logged as skipped attempt). +// - context.Canceled aborts immediately (user abort, no fallback). 
+// - Non-retriable errors (format) abort immediately. +// - Retriable errors trigger fallback to next candidate. +// - Success marks provider as good (resets cooldown). +// - If all fail, returns aggregate error with all attempts. +func (fc *FallbackChain) Execute( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + // Check context before each attempt. + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + // Check cooldown. + if !fc.cooldown.IsAvailable(candidate.Provider) { + remaining := fc.cooldown.CooldownRemaining(candidate.Provider) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Skipped: true, + Reason: FailoverRateLimit, + Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)), + }) + continue + } + + // Execute the run function. + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + // Success. + fc.cooldown.MarkSuccess(candidate.Provider) + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + // Context cancellation: abort immediately, no fallback. + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Classify the error. 
+ failErr := ClassifyError(err, candidate.Provider, candidate.Model) + + if failErr == nil { + // Unclassifiable error: do not fallback, return immediately. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w", + candidate.Provider, candidate.Model, err) + } + + // Non-retriable error: abort immediately. + if !failErr.IsRetriable() { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + return nil, failErr + } + + // Retriable error: mark failure and continue to next candidate. + fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason) + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: failErr, + Reason: failErr.Reason, + Duration: elapsed, + }) + + // If this was the last candidate, return aggregate error. + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + // All candidates were skipped (all in cooldown). + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// ExecuteImage runs the fallback chain for image/vision requests. +// Simpler than Execute: no cooldown checks (image endpoints have different rate limits). +// Image dimension/size errors abort immediately (non-retriable). 
+func (fc *FallbackChain) ExecuteImage( + ctx context.Context, + candidates []FallbackCandidate, + run func(ctx context.Context, provider, model string) (*LLMResponse, error), +) (*FallbackResult, error) { + if len(candidates) == 0 { + return nil, fmt.Errorf("image fallback: no candidates configured") + } + + result := &FallbackResult{ + Attempts: make([]FallbackAttempt, 0, len(candidates)), + } + + for i, candidate := range candidates { + if ctx.Err() == context.Canceled { + return nil, context.Canceled + } + + start := time.Now() + resp, err := run(ctx, candidate.Provider, candidate.Model) + elapsed := time.Since(start) + + if err == nil { + result.Response = resp + result.Provider = candidate.Provider + result.Model = candidate.Model + return result, nil + } + + if ctx.Err() == context.Canceled { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + return nil, context.Canceled + } + + // Image dimension/size errors are non-retriable. + errMsg := strings.ToLower(err.Error()) + if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) { + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Reason: FailoverFormat, + Duration: elapsed, + }) + return nil, &FailoverError{ + Reason: FailoverFormat, + Provider: candidate.Provider, + Model: candidate.Model, + Wrapped: err, + } + } + + // Any other error: record and try next. + result.Attempts = append(result.Attempts, FallbackAttempt{ + Provider: candidate.Provider, + Model: candidate.Model, + Error: err, + Duration: elapsed, + }) + + if i == len(candidates)-1 { + return nil, &FallbackExhaustedError{Attempts: result.Attempts} + } + } + + return nil, &FallbackExhaustedError{Attempts: result.Attempts} +} + +// FallbackExhaustedError indicates all fallback candidates were tried and failed. 
+type FallbackExhaustedError struct { + Attempts []FallbackAttempt +} + +func (e *FallbackExhaustedError) Error() string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts))) + for i, a := range e.Attempts { + if a.Skipped { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model)) + } else { + sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)", + i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond))) + } + } + return sb.String() +} diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go new file mode 100644 index 000000000..ea81e0d48 --- /dev/null +++ b/pkg/providers/fallback_test.go @@ -0,0 +1,473 @@ +package providers + +import ( + "context" + "errors" + "testing" + "time" +) + +func makeCandidate(provider, model string) FallbackCandidate { + return FallbackCandidate{Provider: provider, Model: model} +} + +func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return &LLMResponse{Content: content, FinishReason: "stop"}, nil + } +} + +func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, err + } +} + +func TestFallback_SingleCandidate_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("hello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "hello" { + t.Errorf("content = %q, want hello", result.Response.Content) + } + if result.Provider != "openai" || result.Model != "gpt-4" { + t.Errorf("provider/model = %s/%s, want 
openai/gpt-4", result.Provider, result.Model) + } +} + +func TestFallback_SecondCandidateSuccess(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude-opus"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + if result.Response.Content != "from claude" { + t.Errorf("content = %q, want 'from claude'", result.Response.Content) + } + if len(result.Attempts) != 1 { + t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts)) + } +} + +func TestFallback_AllFail(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + makeCandidate("groq", "llama"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + return nil, errors.New("rate limit exceeded") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error when all candidates fail") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err) + } + if len(exhausted.Attempts) != 3 { + t.Errorf("attempts = %d, want 3", len(exhausted.Attempts)) + } +} + +func TestFallback_ContextCanceled(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + ctx, cancel := context.WithCancel(context.Background()) + 
candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + cancel() // cancel context + return nil, context.Canceled + } + t.Error("should not reach second candidate after cancel") + return nil, nil + } + + _, err := fc.Execute(ctx, candidates, run) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } +} + +func TestFallback_NonRetriableError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("string should match pattern") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for non-retriable") + } + var fe *FailoverError + if !errors.As(err, &fe) { + t.Fatalf("expected FailoverError, got %T", err) + } + if fe.Reason != FailoverFormat { + t.Errorf("reason = %q, want format", fe.Reason) + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt) + } +} + +func TestFallback_CooldownSkip(t *testing.T) { + now := time.Now() + ct, _ := newTestTracker(now) + fc := NewFallbackChain(ct) + + // Put openai in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + if provider == "openai" { + t.Error("should not call openai (in cooldown)") + } + return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil + } + + result, err := fc.Execute(context.Background(), 
candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } + // Should have 1 skipped attempt + skipped := 0 + for _, a := range result.Attempts { + if a.Skipped { + skipped++ + } + } + if skipped != 1 { + t.Errorf("skipped = %d, want 1", skipped) + } +} + +func TestFallback_AllInCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + // Put all providers in cooldown + ct.MarkFailure("openai", FailoverRateLimit) + ct.MarkFailure("anthropic", FailoverBilling) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + _, err := fc.Execute(context.Background(), candidates, + func(ctx context.Context, provider, model string) (*LLMResponse, error) { + t.Error("should not call any provider (all in cooldown)") + return nil, nil + }) + + if err == nil { + t.Fatal("expected error when all in cooldown") + } + var exhausted *FallbackExhaustedError + if !errors.As(err, &exhausted) { + t.Fatalf("expected FallbackExhaustedError, got %T", err) + } +} + +func TestFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.Execute(context.Background(), nil, successRun("ok")) + if err == nil { + t.Error("expected error for empty candidates") + } +} + +func TestFallback_EmptyFallbacks(t *testing.T) { + // Single primary, no fallbacks: should work like direct call + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + result, err := fc.Execute(context.Background(), candidates, successRun("ok")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "ok" { + t.Error("expected success with single candidate") + } +} + +func TestFallback_UnclassifiedError(t *testing.T) { + ct := NewCooldownTracker() + fc := 
NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("completely unknown internal error") + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for unclassified error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt) + } +} + +func TestFallback_SuccessResetsCooldown(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")} + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere + } + return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + + _, err := fc.Execute(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ct.IsAvailable("openai") { + t.Error("success should reset cooldown") + } +} + +// --- Image Fallback Tests --- + +func TestImageFallback_Success(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")} + result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Response.Content != "image result" { + t.Error("expected image result") + } +} + +func TestImageFallback_DimensionError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx 
context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image dimensions exceed max 4096x4096") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image dimension error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt) + } +} + +func TestImageFallback_SizeError(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + return nil, errors.New("image exceeds 20 mb") + } + + _, err := fc.ExecuteImage(context.Background(), candidates, run) + if err == nil { + t.Fatal("expected error for image size error") + } + if attempt != 1 { + t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt) + } +} + +func TestImageFallback_RetryOnOtherErrors(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + candidates := []FallbackCandidate{ + makeCandidate("openai", "gpt-4o"), + makeCandidate("anthropic", "claude-sonnet"), + } + + attempt := 0 + run := func(ctx context.Context, provider, model string) (*LLMResponse, error) { + attempt++ + if attempt == 1 { + return nil, errors.New("rate limit exceeded") + } + return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil + } + + result, err := fc.ExecuteImage(context.Background(), candidates, run) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Provider != "anthropic" { + t.Errorf("provider = %q, want anthropic", result.Provider) + } +} + +func TestImageFallback_NoCandidates(t *testing.T) { + ct := NewCooldownTracker() + fc := NewFallbackChain(ct) + + _, err := fc.ExecuteImage(context.Background(), nil, successRun("ok")) + if err == nil { + 
t.Error("expected error for empty candidates") + } +} + +// --- ResolveCandidates Tests --- + +func TestResolveCandidates_Simple(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 3 { + t.Fatalf("candidates = %d, want 3", len(candidates)) + } + + if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" { + t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model) + } + if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" { + t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model) + } + if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" { + t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model) + } +} + +func TestResolveCandidates_Deduplication(t *testing.T) { + cfg := ModelConfig{ + Primary: "openai/gpt-4", + Fallbacks: []string{"openai/gpt-4", "anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "default") + if len(candidates) != 2 { + t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates)) + } +} + +func TestResolveCandidates_EmptyFallbacks(t *testing.T) { + cfg := ModelConfig{ + Primary: "gpt-4", + Fallbacks: nil, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestResolveCandidates_EmptyPrimary(t *testing.T) { + cfg := ModelConfig{ + Primary: "", + Fallbacks: []string{"anthropic/claude"}, + } + + candidates := ResolveCandidates(cfg, "openai") + if len(candidates) != 1 { + t.Errorf("candidates = %d, want 1", len(candidates)) + } +} + +func TestFallbackExhaustedError_Message(t *testing.T) { + e := &FallbackExhaustedError{ + Attempts: []FallbackAttempt{ + {Provider: "openai", Model: "gpt-4", 
Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond}, + {Provider: "anthropic", Model: "claude", Skipped: true}, + }, + } + msg := e.Error() + if msg == "" { + t.Error("expected non-empty error message") + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 15b22e3a0..eeaa9690a 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -7,201 +7,29 @@ package providers import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" + + "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) type HTTPProvider struct { - apiKey string - apiBase string - maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models) - httpClient *http.Client + delegate *openai_compat.Provider } func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { - return NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, "") + return &HTTPProvider{ + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), + } } func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider { - client := &http.Client{ - Timeout: 120 * time.Second, - } - - if proxy != "" { - proxyURL, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - } - } - return &HTTPProvider{ - apiKey: apiKey, - apiBase: strings.TrimRight(apiBase, "/"), - maxTokensField: maxTokensField, - httpClient: client, + delegate: openai_compat.NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, 
groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b) - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" || prefix == "qwen" || prefix == "cerebras" { - model = model[idx+1:] - } - } - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - // Use configured max_tokens_field if specified, otherwise fallback to model-based detection - fieldName := p.maxTokensField - if fieldName == "" { - // Fallback: detect from model name for backward compatibility - lowerModel := strings.ToLower(model) - if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { - fieldName = "max_completion_tokens" - } else { - fieldName = "max_tokens" - } - } - requestBody[fieldName] = maxTokens - } - - if temperature, ok := options["temperature"].(float64); ok { - lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1 - if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { - requestBody["temperature"] = 1.0 - } else { - requestBody["temperature"] = temperature - } - } - - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := 
io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) - } - - return p.parseResponse(body) -} - -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - ThoughtSignature string `json:"thought_signature"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - thoughtSignature := "" - argsStr := "" - - if tc.Function != nil { - name = tc.Function.Name - thoughtSignature = tc.Function.ThoughtSignature - argsStr = tc.Function.Arguments - if argsStr != "" { - if err := json.Unmarshal([]byte(argsStr), &arguments); err != nil { - arguments["raw"] = argsStr - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Function: &FunctionCall{ - Name: name, - Arguments: argsStr, - ThoughtSignature: thoughtSignature, - }, - Name: name, - Arguments: arguments, - }) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - 
// ModelRef is a parsed model reference: a normalized provider identifier
// plus the model name.
type ModelRef struct {
	Provider string
	Model    string
}

// ParseModelRef parses a reference such as "anthropic/claude-opus" into
// {Provider: "anthropic", Model: "claude-opus"}.
//
// The reference is split at the first slash, so anything after it (including
// further slashes) belongs to the model name. Surrounding whitespace is
// trimmed from both parts and the provider is passed through
// NormalizeProvider. When raw contains no slash, defaultProvider (normalized)
// is used and raw is the model name.
//
// It returns nil when raw is empty or when either side of the slash is empty
// ("openai/", "/gpt-4").
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil
	}

	idx := strings.Index(raw, "/")
	if idx < 0 {
		return &ModelRef{
			Provider: NormalizeProvider(defaultProvider),
			Model:    raw,
		}
	}

	provider := NormalizeProvider(raw[:idx])
	model := strings.TrimSpace(raw[idx+1:])
	// A missing provider is as malformed as a missing model; do not fall back
	// to the default provider with a model name that starts with "/".
	if provider == "" || model == "" {
		return nil
	}
	return &ModelRef{Provider: provider, Model: model}
}

// NormalizeProvider lower-cases and trims a provider identifier and maps the
// known aliases to their canonical names (for example "z.ai" -> "zai",
// "gpt" -> "openai", "claude" -> "anthropic"). Unknown identifiers are
// returned lower-cased and trimmed.
func NormalizeProvider(provider string) string {
	p := strings.ToLower(strings.TrimSpace(provider))

	switch p {
	case "z.ai", "z-ai":
		return "zai"
	case "opencode-zen":
		return "opencode"
	case "qwen":
		return "qwen-portal"
	case "kimi-code":
		return "kimi-coding"
	case "gpt":
		return "openai"
	case "claude":
		return "anthropic"
	case "glm":
		return "zhipu"
	case "google":
		return "gemini"
	}

	return p
}

// ModelKey returns a canonical "provider/model" key for deduplication: the
// provider is normalized and the model name is trimmed and lower-cased.
func ModelKey(provider, model string) string {
	return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
}
"zai"}, + {"opencode-zen", "opencode"}, + {"qwen", "qwen-portal"}, + {"kimi-code", "kimi-coding"}, + {"gpt", "openai"}, + {"claude", "anthropic"}, + {"glm", "zhipu"}, + {"google", "gemini"}, + {"groq", "groq"}, + {"", ""}, + } + + for _, tt := range tests { + got := NormalizeProvider(tt.input) + if got != tt.want { + t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestModelKey(t *testing.T) { + tests := []struct { + provider string + model string + want string + }{ + {"openai", "gpt-4", "openai/gpt-4"}, + {"Anthropic", "Claude-Opus", "anthropic/claude-opus"}, + {"claude", "sonnet", "anthropic/sonnet"}, + {"z.ai", "Model-X", "zai/model-x"}, + } + + for _, tt := range tests { + got := ModelKey(tt.provider, tt.model) + if got != tt.want { + t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want) + } + } +} + +func TestParseModelRef_ProviderNormalization(t *testing.T) { + ref := ParseModelRef("Z.AI/model-x", "default") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "zai" { + t.Errorf("provider = %q, want zai", ref.Provider) + } +} + +func TestParseModelRef_DefaultProviderNormalization(t *testing.T) { + ref := ParseModelRef("gpt-4o", "GPT") + if ref == nil { + t.Fatal("expected non-nil ref") + } + if ref.Provider != "openai" { + t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider) + } +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go new file mode 100644 index 000000000..73fac3435 --- /dev/null +++ b/pkg/providers/openai_compat/provider.go @@ -0,0 +1,232 @@ +package openai_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type 
UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + +type Provider struct { + apiKey string + apiBase string + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase, proxy string) *Provider { + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } else { + log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, + } +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeModel(model, p.apiBase) + + requestBody := map[string]interface{}{ + "model": model, + "messages": messages, + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + if maxTokens, ok := asInt(options["max_tokens"]); ok { + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } + } + + if temperature, ok := asFloat(options["temperature"]); ok { + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1. 
+ if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments 
// normalizeModel strips a known vendor prefix ("vendor/model") from model.
//
// The model is returned untouched when it has no "/" separator, when apiBase
// points at openrouter.ai, or when the part before the first "/" is not one of
// the recognised vendor names. Only the first path segment is removed, so
// nested identifiers keep the remainder intact.
func normalizeModel(model, apiBase string) string {
	vendor, remainder, hasPrefix := strings.Cut(model, "/")
	if !hasPrefix {
		return model
	}
	if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
		return model
	}

	switch strings.ToLower(vendor) {
	case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
		return remainder
	}
	return model
}

// asInt converts the numeric option value v to an int.
// Floating-point values are truncated; the boolean is false for any other type.
func asInt(v interface{}) (int, bool) {
	if n, ok := v.(int); ok {
		return n, true
	}
	if n, ok := v.(int64); ok {
		return int(n), true
	}
	if f, ok := v.(float64); ok {
		return int(f), true
	}
	if f, ok := v.(float32); ok {
		return int(f), true
	}
	return 0, false
}

// asFloat converts the numeric option value v to a float64.
// The boolean is false when v is not one of the supported numeric types.
func asFloat(v interface{}) (float64, bool) {
	var out float64
	switch n := v.(type) {
	case float64:
		out = n
	case float32:
		out = float64(n)
	case int:
		out = float64(n)
	case int64:
		out = float64(n)
	default:
		return 0, false
	}
	return out, true
}
{ + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234}) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if _, ok := requestBody["max_completion_tokens"]; !ok { + t.Fatalf("expected max_completion_tokens in request body") + } + if _, ok := requestBody["max_tokens"]; ok { + t.Fatalf("did not expect max_tokens key for glm model") + } +} + +func TestProviderChat_ParsesToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": "", + "tool_calls": []map[string]interface{}{ + { + "id": "call_1", + "type": "function", + "function": map[string]interface{}{ + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + }, + }, + }, + }, + "finish_reason": "tool_calls", + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + out, err := p.Chat(t.Context(), []Message{{Role: 
"user", Content: "hi"}}, nil, "gpt-4o", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if len(out.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls)) + } + if out.ToolCalls[0].Name != "get_weather" { + t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather") + } + if out.ToolCalls[0].Arguments["city"] != "SF" { + t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"]) + } +} + +func TestProviderChat_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad request", http.StatusBadRequest) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "moonshot/kimi-k2.5", + map[string]interface{}{"temperature": 0.3}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != "kimi-k2.5" { + t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"]) + } + if requestBody["temperature"] != 1.0 { + 
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"]) + } +} + +func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) { + tests := []struct { + name string + input string + wantModel string + }{ + { + name: "strips groq prefix and keeps nested model", + input: "groq/openai/gpt-oss-120b", + wantModel: "openai/gpt-oss-120b", + }, + { + name: "strips ollama prefix", + input: "ollama/qwen2.5:14b", + wantModel: "qwen2.5:14b", + }, + { + name: "strips deepseek prefix", + input: "deepseek/deepseek-chat", + wantModel: "deepseek-chat", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["model"] != tt.wantModel { + t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel) + } + }) + } +} + +func TestProvider_ProxyConfigured(t *testing.T) { + proxyURL := "http://127.0.0.1:8080" + p := NewProvider("key", "https://example.com", proxyURL) + + transport, ok := p.httpClient.Transport.(*http.Transport) + if !ok || transport == nil { + t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport) + } + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}} + gotProxy, err := transport.Proxy(req) + if err != nil { + 
t.Fatalf("proxy function returned error: %v", err) + } + if gotProxy == nil || gotProxy.String() != proxyURL { + t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL) + } +} + +func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) { + var requestBody map[string]interface{} + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + resp := map[string]interface{}{ + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{"content": "ok"}, + "finish_reason": "stop", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProvider("key", server.URL, "") + _, err := p.Chat( + t.Context(), + []Message{{Role: "user", Content: "hi"}}, + nil, + "gpt-4o", + map[string]interface{}{"max_tokens": float64(512), "temperature": 1}, + ) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + if requestBody["max_tokens"] != float64(512) { + t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"]) + } + if requestBody["temperature"] != float64(1) { + t.Fatalf("temperature = %v, want 1", requestBody["temperature"]) + } +} + +func TestNormalizeModel_UsesAPIBase(t *testing.T) { + if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" { + t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat") + } + if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" { + t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto") + } +} diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go new file mode 100644 index 000000000..6b33ae734 --- /dev/null +++ b/pkg/providers/protocoltypes/types.go @@ -0,0 +1,45 @@ +package protocoltypes + +type 
// Wire-level request/response types shared by the LLM provider packages.
type (
	// ToolCall is one tool invocation attached to a message or response.
	ToolCall struct {
		ID        string                 `json:"id"`
		Type      string                 `json:"type,omitempty"`
		Function  *FunctionCall          `json:"function,omitempty"`
		Name      string                 `json:"name,omitempty"`
		Arguments map[string]interface{} `json:"arguments,omitempty"`
	}

	// FunctionCall carries a function name with its arguments as a string.
	//
	// NOTE(review): the providers.FunctionCall definition this type replaces
	// also had a ThoughtSignature field (`thought_signature`); confirm that
	// dropping it here is intentional.
	FunctionCall struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	}

	// LLMResponse is the provider-independent result of a chat request.
	LLMResponse struct {
		Content      string     `json:"content"`
		ToolCalls    []ToolCall `json:"tool_calls,omitempty"`
		FinishReason string     `json:"finish_reason"`
		Usage        *UsageInfo `json:"usage,omitempty"`
	}

	// UsageInfo holds the token counts reported for a request.
	UsageInfo struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}

	// Message is a single chat message exchanged with a provider.
	Message struct {
		Role       string     `json:"role"`
		Content    string     `json:"content"`
		ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
		ToolCallID string     `json:"tool_call_id,omitempty"`
	}

	// ToolDefinition describes a tool offered to the model.
	ToolDefinition struct {
		Type     string                 `json:"type"`
		Function ToolFunctionDefinition `json:"function"`
	}

	// ToolFunctionDefinition is the function payload of a ToolDefinition.
	ToolFunctionDefinition struct {
		Name        string                 `json:"name"`
		Description string                 `json:"description"`
		Parameters  map[string]interface{} `json:"parameters"`
	}
)
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string

// Known failure classes. The string values are the identifiers that appear
// when a reason is formatted.
const (
	FailoverAuth       = FailoverReason("auth")
	FailoverRateLimit  = FailoverReason("rate_limit")
	FailoverBilling    = FailoverReason("billing")
	FailoverTimeout    = FailoverReason("timeout")
	FailoverFormat     = FailoverReason("format")
	FailoverOverloaded = FailoverReason("overloaded")
	FailoverUnknown    = FailoverReason("unknown")
)
const (
	DefaultAgentID   = "main"
	DefaultMainKey   = "main"
	DefaultAccountID = "default"
	MaxAgentIDLength = 64
)

var (
	validIDRe      = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`)
	invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`)
	leadingDashRe  = regexp.MustCompile(`^-+`)
	trailingDashRe = regexp.MustCompile(`-+$`)
)

// NormalizeAgentID sanitizes an agent ID.
//
// The input is trimmed and lower-cased. Runs of characters outside
// [a-z0-9_-] collapse to a single "-", leading and trailing dashes are
// stripped, and the result is capped at MaxAgentIDLength bytes. Empty input,
// or input with nothing left after sanitizing, yields DefaultAgentID ("main").
func NormalizeAgentID(id string) string {
	return normalizeID(id, DefaultAgentID)
}

// NormalizeAccountID sanitizes an account ID with the same rules as
// NormalizeAgentID. Empty or fully invalid input yields DefaultAccountID.
func NormalizeAccountID(id string) string {
	return normalizeID(id, DefaultAccountID)
}

// normalizeID implements the shared sanitizing rules; fallback is returned
// when nothing usable remains.
//
// NOTE(review): a leading underscore survives sanitizing (for example "_x"),
// so the result does not always match validIDRe. Confirm whether that is
// intended before tightening it.
func normalizeID(id, fallback string) string {
	lower := strings.ToLower(strings.TrimSpace(id))
	if lower == "" {
		return fallback
	}
	if validIDRe.MatchString(lower) {
		return lower
	}

	// After this replacement the string is pure ASCII, so the byte-based
	// truncation below cannot split a multi-byte rune.
	result := invalidCharsRe.ReplaceAllString(lower, "-")
	result = leadingDashRe.ReplaceAllString(result, "")
	if len(result) > MaxAgentIDLength {
		result = result[:MaxAgentIDLength]
	}
	// Strip trailing dashes only after truncating: cutting at the length limit
	// can expose a dash that was in the middle of the untruncated string.
	result = trailingDashRe.ReplaceAllString(result, "")
	if result == "" {
		return fallback
	}
	return result
}
"0test"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_InvalidChars(t *testing.T) { + tests := []struct { + input, want string + }{ + {"Hello World", "hello-world"}, + {"agent@123", "agent-123"}, + {"foo.bar.baz", "foo-bar-baz"}, + {"--leading", "leading"}, + {"--both--", "both"}, + } + for _, tt := range tests { + if got := NormalizeAgentID(tt.input); got != tt.want { + t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNormalizeAgentID_AllInvalid(t *testing.T) { + if got := NormalizeAgentID("@@@"); got != DefaultAgentID { + t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID) + } +} + +func TestNormalizeAgentID_TruncatesAt64(t *testing.T) { + long := "" + for i := 0; i < 100; i++ { + long += "a" + } + got := NormalizeAgentID(long) + if len(got) > MaxAgentIDLength { + t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength) + } +} + +func TestNormalizeAccountID_Empty(t *testing.T) { + if got := NormalizeAccountID(""); got != DefaultAccountID { + t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID) + } +} + +func TestNormalizeAccountID_Valid(t *testing.T) { + if got := NormalizeAccountID("MyBot"); got != "mybot" { + t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got) + } +} + +func TestNormalizeAccountID_InvalidChars(t *testing.T) { + if got := NormalizeAccountID("bot@home"); got != "bot-home" { + t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got) + } +} diff --git a/pkg/routing/route.go b/pkg/routing/route.go new file mode 100644 index 000000000..9eb060c53 --- /dev/null +++ b/pkg/routing/route.go @@ -0,0 +1,252 @@ +package routing + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// RouteInput contains the routing context from an inbound message. 
+type RouteInput struct { + Channel string + AccountID string + Peer *RoutePeer + ParentPeer *RoutePeer + GuildID string + TeamID string +} + +// ResolvedRoute is the result of agent routing. +type ResolvedRoute struct { + AgentID string + Channel string + AccountID string + SessionKey string + MainSessionKey string + MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" +} + +// RouteResolver determines which agent handles a message based on config bindings. +type RouteResolver struct { + cfg *config.Config +} + +// NewRouteResolver creates a new route resolver. +func NewRouteResolver(cfg *config.Config) *RouteResolver { + return &RouteResolver{cfg: cfg} +} + +// ResolveRoute determines which agent handles the message and constructs session keys. +// Implements the 7-level priority cascade: +// peer > parent_peer > guild > team > account > channel_wildcard > default +func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(input.Channel)) + accountID := NormalizeAccountID(input.AccountID) + peer := input.Peer + + dmScope := DMScope(r.cfg.Session.DMScope) + if dmScope == "" { + dmScope = DMScopeMain + } + identityLinks := r.cfg.Session.IdentityLinks + + bindings := r.filterBindings(channel, accountID) + + choose := func(agentID string, matchedBy string) ResolvedRoute { + resolvedAgentID := r.pickAgentID(agentID) + sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + Peer: peer, + DMScope: dmScope, + IdentityLinks: identityLinks, + })) + mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) + return ResolvedRoute{ + AgentID: resolvedAgentID, + Channel: channel, + AccountID: accountID, + SessionKey: sessionKey, + MainSessionKey: mainSessionKey, + MatchedBy: matchedBy, + } + } + + // Priority 1: Peer binding 
+ if peer != nil && strings.TrimSpace(peer.ID) != "" { + if match := r.findPeerMatch(bindings, peer); match != nil { + return choose(match.AgentID, "binding.peer") + } + } + + // Priority 2: Parent peer binding + parentPeer := input.ParentPeer + if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { + if match := r.findPeerMatch(bindings, parentPeer); match != nil { + return choose(match.AgentID, "binding.peer.parent") + } + } + + // Priority 3: Guild binding + guildID := strings.TrimSpace(input.GuildID) + if guildID != "" { + if match := r.findGuildMatch(bindings, guildID); match != nil { + return choose(match.AgentID, "binding.guild") + } + } + + // Priority 4: Team binding + teamID := strings.TrimSpace(input.TeamID) + if teamID != "" { + if match := r.findTeamMatch(bindings, teamID); match != nil { + return choose(match.AgentID, "binding.team") + } + } + + // Priority 5: Account binding + if match := r.findAccountMatch(bindings); match != nil { + return choose(match.AgentID, "binding.account") + } + + // Priority 6: Channel wildcard binding + if match := r.findChannelWildcardMatch(bindings); match != nil { + return choose(match.AgentID, "binding.channel") + } + + // Priority 7: Default agent + return choose(r.resolveDefaultAgentID(), "default") +} + +func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding { + var filtered []config.AgentBinding + for _, b := range r.cfg.Bindings { + matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel)) + if matchChannel == "" || matchChannel != channel { + continue + } + if !matchesAccountID(b.Match.AccountID, accountID) { + continue + } + filtered = append(filtered, b) + } + return filtered +} + +func matchesAccountID(matchAccountID, actual string) bool { + trimmed := strings.TrimSpace(matchAccountID) + if trimmed == "" { + return actual == DefaultAccountID + } + if trimmed == "*" { + return true + } + return strings.ToLower(trimmed) == strings.ToLower(actual) +} + +func (r 
*RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + if b.Match.Peer == nil { + continue + } + peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind)) + peerID := strings.TrimSpace(b.Match.Peer.ID) + if peerKind == "" || peerID == "" { + continue + } + if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID { + return b + } + } + return nil +} + +func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchGuild := strings.TrimSpace(b.Match.GuildID) + if matchGuild != "" && matchGuild == guildID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + matchTeam := strings.TrimSpace(b.Match.TeamID) + if matchTeam != "" && matchTeam == teamID { + return &bindings[i] + } + } + return nil +} + +func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID == "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding { + for i := range bindings { + b := &bindings[i] + accountID := strings.TrimSpace(b.Match.AccountID) + if accountID != "*" { + continue + } + if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { + continue + } + return &bindings[i] + } + return nil +} + +func (r *RouteResolver) pickAgentID(agentID string) string { + trimmed := strings.TrimSpace(agentID) + if trimmed == "" { + return NormalizeAgentID(r.resolveDefaultAgentID()) + 
} + normalized := NormalizeAgentID(trimmed) + agents := r.cfg.Agents.List + if len(agents) == 0 { + return normalized + } + for _, a := range agents { + if NormalizeAgentID(a.ID) == normalized { + return normalized + } + } + return NormalizeAgentID(r.resolveDefaultAgentID()) +} + +func (r *RouteResolver) resolveDefaultAgentID() string { + agents := r.cfg.Agents.List + if len(agents) == 0 { + return DefaultAgentID + } + for _, a := range agents { + if a.Default { + id := strings.TrimSpace(a.ID) + if id != "" { + return NormalizeAgentID(id) + } + } + } + if id := strings.TrimSpace(agents[0].ID); id != "" { + return NormalizeAgentID(id) + } + return DefaultAgentID +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go new file mode 100644 index 000000000..8255db5f9 --- /dev/null +++ b/pkg/routing/route_test.go @@ -0,0 +1,297 @@ +package routing + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config { + return &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: "/tmp/picoclaw-test", + Model: "gpt-4", + }, + List: agents, + }, + Bindings: bindings, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } +} + +func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { + cfg := testConfig(nil, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != DefaultAgentID { + t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID) + } + if route.MatchedBy != "default" { + t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) + } +} + +func TestResolveRoute_PeerBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "sales", Default: true}, + {ID: "support"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "support", + Match: config.BindingMatch{ + Channel: 
"telegram", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user123"}, + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + }) + + if route.AgentID != "support" { + t.Errorf("AgentID = %q, want 'support'", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_GuildBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-abc", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-abc", + Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, + }) + + if route.AgentID != "gaming" { + t.Errorf("AgentID = %q, want 'gaming'", route.AgentID) + } + if route.MatchedBy != "binding.guild" { + t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy) + } +} + +func TestResolveRoute_TeamBinding(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "work"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "work", + Match: config.BindingMatch{ + Channel: "slack", + AccountID: "*", + TeamID: "T12345", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "slack", + TeamID: "T12345", + Peer: &RoutePeer{Kind: "channel", ID: "C001"}, + }) + + if route.AgentID != "work" { + t.Errorf("AgentID = %q, want 'work'", route.AgentID) + } + if route.MatchedBy != "binding.team" { + t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy) + } +} + +func TestResolveRoute_AccountBinding(t *testing.T) 
{ + agents := []config.AgentConfig{ + {ID: "default-agent", Default: true}, + {ID: "premium"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "premium", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "bot2", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + AccountID: "bot2", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "premium" { + t.Errorf("AgentID = %q, want 'premium'", route.AgentID) + } + if route.MatchedBy != "binding.account" { + t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy) + } +} + +func TestResolveRoute_ChannelWildcard(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "telegram-bot"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "telegram-bot", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + }) + + if route.AgentID != "telegram-bot" { + t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID) + } + if route.MatchedBy != "binding.channel" { + t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy) + } +} + +func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "general", Default: true}, + {ID: "vip"}, + {ID: "gaming"}, + } + bindings := []config.AgentBinding{ + { + AgentID: "vip", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"}, + }, + }, + { + AgentID: "gaming", + Match: config.BindingMatch{ + Channel: "discord", + AccountID: "*", + GuildID: "guild-1", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := 
r.ResolveRoute(RouteInput{ + Channel: "discord", + GuildID: "guild-1", + Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + }) + + if route.AgentID != "vip" { + t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID) + } + if route.MatchedBy != "binding.peer" { + t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + } +} + +func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "main", Default: true}, + } + bindings := []config.AgentBinding{ + { + AgentID: "nonexistent", + Match: config.BindingMatch{ + Channel: "telegram", + AccountID: "*", + }, + }, + } + cfg := testConfig(agents, bindings) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "telegram", + }) + + if route.AgentID != "main" { + t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) + } +} + +func TestResolveRoute_DefaultAgentSelection(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta", Default: true}, + {ID: "gamma"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "beta" { + t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) + } +} + +func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { + agents := []config.AgentConfig{ + {ID: "alpha"}, + {ID: "beta"}, + } + cfg := testConfig(agents, nil) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(RouteInput{ + Channel: "cli", + }) + + if route.AgentID != "alpha" { + t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) + } +} diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go new file mode 100644 index 000000000..e12f0d1d8 --- /dev/null +++ b/pkg/routing/session_key.go @@ -0,0 +1,183 @@ +package routing + +import ( + "fmt" + "strings" +) + +// DMScope controls DM session isolation granularity. 
+type DMScope string + +const ( + DMScopeMain DMScope = "main" + DMScopePerPeer DMScope = "per-peer" + DMScopePerChannelPeer DMScope = "per-channel-peer" + DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" +) + +// RoutePeer represents a chat peer with kind and ID. +type RoutePeer struct { + Kind string // "direct", "group", "channel" + ID string +} + +// SessionKeyParams holds all inputs for session key construction. +type SessionKeyParams struct { + AgentID string + Channel string + AccountID string + Peer *RoutePeer + DMScope DMScope + IdentityLinks map[string][]string +} + +// ParsedSessionKey is the result of parsing an agent-scoped session key. +type ParsedSessionKey struct { + AgentID string + Rest string +} + +// BuildAgentMainSessionKey returns "agent::main". +func BuildAgentMainSessionKey(agentID string) string { + return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) +} + +// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. 
+func BuildAgentPeerSessionKey(params SessionKeyParams) string { + agentID := NormalizeAgentID(params.AgentID) + + peer := params.Peer + if peer == nil { + peer = &RoutePeer{Kind: "direct"} + } + peerKind := strings.TrimSpace(peer.Kind) + if peerKind == "" { + peerKind = "direct" + } + + if peerKind == "direct" { + dmScope := params.DMScope + if dmScope == "" { + dmScope = DMScopeMain + } + peerID := strings.TrimSpace(peer.ID) + + // Resolve identity links (cross-platform collapse) + if dmScope != DMScopeMain && peerID != "" { + if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" { + peerID = linked + } + } + peerID = strings.ToLower(peerID) + + switch dmScope { + case DMScopePerAccountChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + accountID := NormalizeAccountID(params.AccountID) + return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) + } + case DMScopePerChannelPeer: + if peerID != "" { + channel := normalizeChannel(params.Channel) + return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) + } + case DMScopePerPeer: + if peerID != "" { + return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) + } + } + return BuildAgentMainSessionKey(agentID) + } + + // Group/channel peers always get per-peer sessions + channel := normalizeChannel(params.Channel) + peerID := strings.ToLower(strings.TrimSpace(peer.ID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// ParseAgentSessionKey extracts agentId and rest from "agent::". 
+func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 { + return nil + } + if parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedSessionKey{AgentID: agentID, Rest: rest} +} + +// IsSubagentSessionKey returns true if the session key represents a subagent. +func IsSubagentSessionKey(sessionKey string) bool { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "subagent:") { + return true + } + parsed := ParseAgentSessionKey(raw) + if parsed == nil { + return false + } + return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") +} + +func normalizeChannel(channel string) string { + c := strings.TrimSpace(strings.ToLower(channel)) + if c == "" { + return "unknown" + } + return c +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) + candidates[scopedCandidate] = true + } + if len(candidates) == 0 { + return "" + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} diff --git a/pkg/routing/session_key_test.go 
b/pkg/routing/session_key_test.go new file mode 100644 index 000000000..81e4ce018 --- /dev/null +++ b/pkg/routing/session_key_test.go @@ -0,0 +1,162 @@ +package routing + +import "testing" + +func TestBuildAgentMainSessionKey(t *testing.T) { + got := BuildAgentMainSessionKey("sales") + want := "agent:sales:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) + } +} + +func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { + got := BuildAgentMainSessionKey("Sales Bot") + want := "agent:sales-bot:main" + if got != want { + t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopeMain, + }) + want := "agent:main:main" + if got != want { + t.Errorf("DMScopeMain = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:direct:user123" + if got != want { + t.Errorf("DMScopePerPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerChannelPeer, + }) + want := "agent:main:telegram:direct:user123" + if got != want { + t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + AccountID: "bot1", + Peer: &RoutePeer{Kind: "direct", ID: 
"User123"}, + DMScope: DMScopePerAccountChannelPeer, + }) + want := "agent:main:telegram:bot1:direct:user123" + if got != want { + t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "group", ID: "chat456"}, + DMScope: DMScopePerPeer, + }) + want := "agent:main:telegram:group:chat456" + if got != want { + t.Errorf("GroupPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: nil, + DMScope: DMScopePerPeer, + }) + // nil peer defaults to direct with empty ID, falls to main + want := "agent:main:main" + if got != want { + t.Errorf("NilPeer = %q, want %q", got, want) + } +} + +func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:user123", "discord:john#1234"}, + } + got := BuildAgentPeerSessionKey(SessionKeyParams{ + AgentID: "main", + Channel: "telegram", + Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + DMScope: DMScopePerPeer, + IdentityLinks: links, + }) + want := "agent:main:direct:john" + if got != want { + t.Errorf("IdentityLink = %q, want %q", got, want) + } +} + +func TestParseAgentSessionKey_Valid(t *testing.T) { + parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected non-nil result") + } + if parsed.AgentID != "sales" { + t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) + } +} + +func TestParseAgentSessionKey_Invalid(t *testing.T) { + tests := []string{ + "", + "foo:bar", + "notprefix:sales:main", + "agent::main", + "agent:sales:", + } + for _, input := range tests { + 
if got := ParseAgentSessionKey(input); got != nil { + t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) + } + } +} + +func TestIsSubagentSessionKey(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"subagent:task-1", true}, + {"agent:main:subagent:task-1", true}, + {"agent:main:main", false}, + {"agent:main:telegram:direct:user123", false}, + {"", false}, + } + for _, tt := range tests { + if got := IsSubagentSessionKey(tt.input); got != tt.want { + t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go index a3263c525..0856254e8 100644 --- a/pkg/skills/installer.go +++ b/pkg/skills/installer.go @@ -8,7 +8,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "time" ) @@ -24,12 +23,6 @@ type AvailableSkill struct { Tags []string `json:"tags"` } -type BuiltinSkill struct { - Name string `json:"name"` - Path string `json:"path"` - Enabled bool `json:"enabled"` -} - func NewSkillInstaller(workspace string) *SkillInstaller { return &SkillInstaller{ workspace: workspace, @@ -123,49 +116,3 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS return skills, nil } - -func (si *SkillInstaller) ListBuiltinSkills() []BuiltinSkill { - builtinSkillsDir := filepath.Join(filepath.Dir(si.workspace), "picoclaw", "skills") - - entries, err := os.ReadDir(builtinSkillsDir) - if err != nil { - return nil - } - - var skills []BuiltinSkill - for _, entry := range entries { - if entry.IsDir() { - _ = entry - skillName := entry.Name() - skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md") - - data, err := os.ReadFile(skillFile) - description := "" - if err == nil { - content := string(data) - if idx := strings.Index(content, "\n"); idx > 0 { - firstLine := content[:idx] - if strings.Contains(firstLine, "description:") { - descLine := strings.Index(content[idx:], "\n") - if descLine > 0 { - description = 
strings.TrimSpace(content[idx+descLine : idx+descLine]) - } - } - } - } - - // skill := BuiltinSkill{ - // Name: skillName, - // Path: description, - // Enabled: true, - // } - - status := "✓" - fmt.Printf(" %s %s\n", status, entry.Name()) - if description != "" { - fmt.Printf(" %s\n", description) - } - } - } - return skills -} diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index 0c63ae067..bb0abbdcc 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -9,6 +9,8 @@ import ( "path/filepath" "regexp" "strings" + + "github.com/sipeed/picoclaw/pkg/logger" ) var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) @@ -251,6 +253,11 @@ func (sl *SkillsLoader) BuildSkillsSummary() string { func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { content, err := os.ReadFile(skillPath) if err != nil { + logger.WarnCF("skills", "Failed to read skill metadata", + map[string]interface{}{ + "skill_path": skillPath, + "error": err.Error(), + }) return nil } @@ -283,10 +290,15 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata { // parseSimpleYAML parses simple key: value YAML format // Example: name: github\n description: "..." +// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac) func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { result := make(map[string]string) - for _, line := range strings.Split(content, "\n") { + // Normalize line endings: convert \r\n and \r to \n + normalized := strings.ReplaceAll(content, "\r\n", "\n") + normalized = strings.ReplaceAll(normalized, "\r", "\n") + + for _, line := range strings.Split(normalized, "\n") { line = strings.TrimSpace(line) if line == "" || strings.HasPrefix(line, "#") { continue @@ -306,9 +318,10 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string { } func (sl *SkillsLoader) extractFrontmatter(content string) string { - // (?s) enables DOTALL mode so . 
matches newlines - // Match first ---, capture everything until next --- on its own line - re := regexp.MustCompile(`(?s)^---\n(.*)\n---`) + // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks + // (?s) enables DOTALL so . matches newlines; + // ^--- at start, then ... --- at start of line, honoring all three line ending types + re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`) match := re.FindStringSubmatch(content) if len(match) > 1 { return match[1] @@ -317,7 +330,11 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string { } func (sl *SkillsLoader) stripFrontmatter(content string) string { - re := regexp.MustCompile(`^---\n.*?\n---\n`) + // Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks + // (?s) enables DOTALL so . matches newlines; + // ^--- at start, then ... --- at start of line, honoring all three line ending types + // Match zero or more trailing line endings after closing --- (handles both with and without blank lines) + re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`) return re.ReplaceAllString(content, "") } diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go index e0e7109cf..efadcdbf2 100644 --- a/pkg/skills/loader_test.go +++ b/pkg/skills/loader_test.go @@ -75,3 +75,105 @@ func TestSkillsInfoValidate(t *testing.T) { }) } } + +func TestExtractFrontmatter(t *testing.T) { + sl := &SkillsLoader{} + + testcases := []struct { + name string + content string + expectedName string + expectedDesc string + lineEndingType string + }{ + { + name: "unix-line-endings", + lineEndingType: "Unix (\\n)", + content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content", + expectedName: "test-skill", + expectedDesc: "A test skill", + }, + { + name: "windows-line-endings", + lineEndingType: "Windows (\\r\\n)", + content: "---\r\nname: test-skill\r\ndescription: A test 
skill\r\n---\r\n\r\n# Skill Content", + expectedName: "test-skill", + expectedDesc: "A test skill", + }, + { + name: "classic-mac-line-endings", + lineEndingType: "Classic Mac (\\r)", + content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content", + expectedName: "test-skill", + expectedDesc: "A test skill", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // Extract frontmatter + frontmatter := sl.extractFrontmatter(tc.content) + assert.NotEmpty(t, frontmatter, "Frontmatter should be extracted for %s line endings", tc.lineEndingType) + + // Parse YAML to get name and description (parseSimpleYAML now handles all line ending types) + yamlMeta := sl.parseSimpleYAML(frontmatter) + assert.Equal(t, tc.expectedName, yamlMeta["name"], "Name should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType) + assert.Equal(t, tc.expectedDesc, yamlMeta["description"], "Description should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType) + }) + } +} + +func TestStripFrontmatter(t *testing.T) { + sl := &SkillsLoader{} + + testcases := []struct { + name string + content string + expectedContent string + lineEndingType string + }{ + { + name: "unix-line-endings", + lineEndingType: "Unix (\\n)", + content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content", + expectedContent: "# Skill Content", + }, + { + name: "windows-line-endings", + lineEndingType: "Windows (\\r\\n)", + content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content", + expectedContent: "# Skill Content", + }, + { + name: "classic-mac-line-endings", + lineEndingType: "Classic Mac (\\r)", + content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content", + expectedContent: "# Skill Content", + }, + { + name: "unix-line-endings-without-trailing-newline", + lineEndingType: "Unix (\\n) without trailing newline", + content: "---\nname: 
test-skill\ndescription: A test skill\n---\n# Skill Content", + expectedContent: "# Skill Content", + }, + { + name: "windows-line-endings-without-trailing-newline", + lineEndingType: "Windows (\\r\\n) without trailing newline", + content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n# Skill Content", + expectedContent: "# Skill Content", + }, + { + name: "no-frontmatter", + lineEndingType: "No frontmatter", + content: "# Skill Content\n\nSome content here.", + expectedContent: "# Skill Content\n\nSome content here.", + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + result := sl.stripFrontmatter(tc.content) + assert.Equal(t, tc.expectedContent, result, "Frontmatter should be stripped correctly for %s", tc.lineEndingType) + }) + } +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 21bee42ef..e2764d8ac 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -29,9 +30,9 @@ type CronTool struct { // NewCronTool creates a new CronTool // execTimeout: 0 means no timeout, >0 sets the timeout duration -func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *CronTool { - execTool := NewExecTool(workspace, restrict) - execTool.SetTimeout(execTimeout) // 0 means no timeout +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool { + execTool := NewExecToolWithConfig(workspace, restrict, config) + execTool.SetTimeout(execTimeout) return &CronTool{ cronService: cronService, executor: executor, diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 713850f97..d9430672f 100644 --- 
a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -11,6 +11,8 @@ import ( "runtime" "strings" "time" + + "github.com/sipeed/picoclaw/pkg/config" ) type ExecTool struct { @@ -21,16 +23,82 @@ type ExecTool struct { restrictToWorkspace bool } +var defaultDenyPatterns = []*regexp.Regexp{ + regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), + regexp.MustCompile(`\bdel\s+/[fq]\b`), + regexp.MustCompile(`\brmdir\s+/s\b`), + regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) + regexp.MustCompile(`\bdd\s+if=`), + regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) + regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), + regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), + regexp.MustCompile(`\$\([^)]+\)`), + regexp.MustCompile(`\$\{[^}]+\}`), + regexp.MustCompile("`[^`]+`"), + regexp.MustCompile(`\|\s*sh\b`), + regexp.MustCompile(`\|\s*bash\b`), + regexp.MustCompile(`;\s*rm\s+-[rf]`), + regexp.MustCompile(`&&\s*rm\s+-[rf]`), + regexp.MustCompile(`\|\|\s*rm\s+-[rf]`), + regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`), + regexp.MustCompile(`<<\s*EOF`), + regexp.MustCompile(`\$\(\s*cat\s+`), + regexp.MustCompile(`\$\(\s*curl\s+`), + regexp.MustCompile(`\$\(\s*wget\s+`), + regexp.MustCompile(`\$\(\s*which\s+`), + regexp.MustCompile(`\bsudo\b`), + regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`), + regexp.MustCompile(`\bchown\b`), + regexp.MustCompile(`\bpkill\b`), + regexp.MustCompile(`\bkillall\b`), + regexp.MustCompile(`\bkill\s+-[9]\b`), + regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`), + regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`), + regexp.MustCompile(`\bnpm\s+install\s+-g\b`), + regexp.MustCompile(`\bpip\s+install\s+--user\b`), + regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`), + regexp.MustCompile(`\byum\s+(install|remove)\b`), + regexp.MustCompile(`\bdnf\s+(install|remove)\b`), + regexp.MustCompile(`\bdocker\s+run\b`), + regexp.MustCompile(`\bdocker\s+exec\b`), + 
regexp.MustCompile(`\bgit\s+push\b`), + regexp.MustCompile(`\bgit\s+force\b`), + regexp.MustCompile(`\bssh\b.*@`), + regexp.MustCompile(`\beval\b`), + regexp.MustCompile(`\bsource\s+.*\.sh\b`), +} + func NewExecTool(workingDir string, restrict bool) *ExecTool { - denyPatterns := []*regexp.Regexp{ - regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), - regexp.MustCompile(`\bdel\s+/[fq]\b`), - regexp.MustCompile(`\brmdir\s+/s\b`), - regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) - regexp.MustCompile(`\bdd\s+if=`), - regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) - regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), - regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), + return NewExecToolWithConfig(workingDir, restrict, nil) +} + +func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool { + denyPatterns := make([]*regexp.Regexp, 0) + + enableDenyPatterns := true + if config != nil { + execConfig := config.Tools.Exec + enableDenyPatterns = execConfig.EnableDenyPatterns + if enableDenyPatterns { + if len(execConfig.CustomDenyPatterns) > 0 { + fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns) + for _, pattern := range execConfig.CustomDenyPatterns { + re, err := regexp.Compile(pattern) + if err != nil { + fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err) + continue + } + denyPatterns = append(denyPatterns, re) + } + } else { + denyPatterns = append(denyPatterns, defaultDenyPatterns...) + } + } else { + // If deny patterns are disabled, we won't add any patterns, allowing all commands. + fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.") + } + } else { + denyPatterns = append(denyPatterns, defaultDenyPatterns...) 
} return &ExecTool{ diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 42dd36a33..f01372467 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -6,10 +6,11 @@ import ( ) type SpawnTool struct { - manager *SubagentManager - originChannel string - originChatID string - callback AsyncCallback // For async completion notification + manager *SubagentManager + originChannel string + originChatID string + allowlistCheck func(targetAgentID string) bool + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} { "type": "string", "description": "Optional short label for the task (for display)", }, + "agent_id": map[string]interface{}{ + "type": "string", + "description": "Optional target agent ID to delegate the task to", + }, }, "required": []string{"task"}, } @@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } +func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) { + t.allowlistCheck = check +} + func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { task, ok := args["task"].(string) if !ok { @@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T } label, _ := args["label"].(string) + agentID, _ := args["agent_id"].(string) + + // Check allowlist if targeting a specific agent + if agentID != "" && t.allowlistCheck != nil { + if !t.allowlistCheck(agentID) { + return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID)) + } + } if t.manager == nil { return ErrorResult("Subagent manager not configured") } // Pass callback to manager for async completion notification - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) + result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, 
t.callback) if err != nil { return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index efa1d33aa..2fc7162d0 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -14,6 +14,7 @@ type SubagentTask struct { ID string Task string Label string + AgentID string OriginChannel string OriginChatID string Status string @@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) { sm.tools.Register(tool) } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel ID: taskID, Task: task, Label: label, + AgentID: agentID, OriginChannel: originChannel, OriginChatID: originChatID, Status: "running", diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go index 0109c3447..917b4a378 100644 --- a/pkg/tools/toolloop.go +++ b/pkg/tools/toolloop.go @@ -116,6 +116,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider Name: tc.Name, Arguments: string(argumentsJSON), }, + Name: tc.Name, }) } messages = append(messages, assistantMsg) diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 6a6d40ecf..1f5c58ea5 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -492,8 +492,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string { result = strings.TrimSpace(result) - re = regexp.MustCompile(`\s+`) - result = re.ReplaceAllLiteralString(result, " ") + re = regexp.MustCompile(`[^\S\n]+`) + result = re.ReplaceAllString(result, " ") + re = regexp.MustCompile(`\n{3,}`) + result = re.ReplaceAllString(result, "\n\n") lines := strings.Split(result, "\n") var cleanLines []string diff 
--git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index a526ea34a..7e6d62213 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -234,6 +234,80 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { } } +// TestWebFetchTool_extractText verifies text extraction preserves newlines +func TestWebFetchTool_extractText(t *testing.T) { + tool := &WebFetchTool{} + + tests := []struct { + name string + input string + wantFunc func(t *testing.T, got string) + }{ + { + name: "preserves newlines between block elements", + input: "

Title

\n

Paragraph 1

\n

Paragraph 2

", + wantFunc: func(t *testing.T, got string) { + lines := strings.Split(got, "\n") + if len(lines) < 2 { + t.Errorf("Expected multiple lines, got %d: %q", len(lines), got) + } + if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") { + t.Errorf("Missing expected text: %q", got) + } + }, + }, + { + name: "removes script and style tags", + input: "

Keep this

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, "alert") || strings.Contains(got, "body{}") { + t.Errorf("Expected script/style content removed, got: %q", got) + } + if !strings.Contains(got, "Keep this") { + t.Errorf("Expected 'Keep this' to remain, got: %q", got) + } + }, + }, + { + name: "collapses excessive blank lines", + input: "

A

\n\n\n\n\n

B

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, "\n\n\n") { + t.Errorf("Expected excessive blank lines collapsed, got: %q", got) + } + }, + }, + { + name: "collapses horizontal whitespace", + input: "

hello world

", + wantFunc: func(t *testing.T, got string) { + if strings.Contains(got, " ") { + t.Errorf("Expected spaces collapsed, got: %q", got) + } + if !strings.Contains(got, "hello world") { + t.Errorf("Expected 'hello world', got: %q", got) + } + }, + }, + { + name: "empty input", + input: "", + wantFunc: func(t *testing.T, got string) { + if got != "" { + t.Errorf("Expected empty string, got: %q", got) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tool.extractText(tt.input) + tt.wantFunc(t, got) + }) + } +} + // TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain func TestWebTool_WebFetch_MissingDomain(t *testing.T) { tool := NewWebFetchTool(50000) diff --git a/pkg/utils/message.go b/pkg/utils/message.go new file mode 100644 index 000000000..1d05950d9 --- /dev/null +++ b/pkg/utils/message.go @@ -0,0 +1,179 @@ +package utils + +import ( + "strings" +) + +// SplitMessage splits long messages into chunks, preserving code block integrity. +// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, +// but may extend to maxLen when needed. +// Call SplitMessage with the full text content and the maximum allowed length of a single message; +// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. 
+func SplitMessage(content string, maxLen int) []string { + var messages []string + + // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible + codeBlockBuffer := maxLen / 10 + if codeBlockBuffer < 50 { + codeBlockBuffer = 50 + } + if codeBlockBuffer > maxLen/2 { + codeBlockBuffer = maxLen / 2 + } + + for len(content) > 0 { + if len(content) <= maxLen { + messages = append(messages, content) + break + } + + // Effective split point: maxLen minus buffer, to leave room for code blocks + effectiveLimit := maxLen - codeBlockBuffer + if effectiveLimit < maxLen/2 { + effectiveLimit = maxLen / 2 + } + + // Find natural split point within the effective limit + msgEnd := findLastNewline(content[:effectiveLimit], 200) + if msgEnd <= 0 { + msgEnd = findLastSpace(content[:effectiveLimit], 100) + } + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + // Check if this would end with an incomplete code block + candidate := content[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlock(candidate) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend up to maxLen to include the closing ``` + if len(content) > msgEnd { + closingIdx := findNextClosingCodeBlock(content, msgEnd) + if closingIdx > 0 && closingIdx <= maxLen { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Code block is too long to fit in one chunk or missing closing fence. + // Try to split inside by injecting closing and reopening fences. 
// findLastUnclosedCodeBlock reports the byte offset of the opening "```" of a
// code block that is still open at the end of text, or -1 when every fence is
// paired.
//
// Every non-overlapping occurrence of "```" (scanning left to right) is
// treated as a fence and toggles the open/closed state, regardless of where
// on a line it appears.
func findLastUnclosedCodeBlock(text string) int {
	const fence = "```"

	openIdx := -1
	inBlock := false
	for pos := 0; ; {
		rel := strings.Index(text[pos:], fence)
		if rel < 0 {
			break
		}
		idx := pos + rel
		if !inBlock {
			// Entering a block: remember where this opening fence starts.
			openIdx = idx
		}
		inBlock = !inBlock
		// Resume after the fence so its backticks are not counted twice.
		pos = idx + len(fence)
	}

	if inBlock {
		return openIdx
	}
	return -1
}

// findNextClosingCodeBlock finds the first "```" in text at or after byte
// offset startIdx and returns the offset just past it, or -1 if there is none.
//
// A negative startIdx is treated as 0 instead of indexing out of range; a
// startIdx at or beyond len(text) yields -1.
func findNextClosingCodeBlock(text string, startIdx int) int {
	const fence = "```"

	if startIdx < 0 {
		startIdx = 0
	}
	if startIdx >= len(text) {
		return -1
	}
	rel := strings.Index(text[startIdx:], fence)
	if rel < 0 {
		return -1
	}
	return startIdx + rel + len(fence)
}

// findLastNewline returns the byte offset of the last '\n' found within the
// final searchWindow bytes of s, or -1 if that window contains none.
// A window larger than s covers the whole string; a window of zero or less
// never matches.
func findLastNewline(s string, searchWindow int) int {
	if searchWindow <= 0 {
		return -1
	}
	start := len(s) - searchWindow
	if start < 0 {
		start = 0
	}
	if rel := strings.LastIndexByte(s[start:], '\n'); rel >= 0 {
		return start + rel
	}
	return -1
}

// findLastSpace returns the byte offset of the last space or tab found within
// the final searchWindow bytes of s, or -1 if that window contains neither.
// A window larger than s covers the whole string; a window of zero or less
// never matches.
func findLastSpace(s string, searchWindow int) int {
	if searchWindow <= 0 {
		return -1
	}
	start := len(s) - searchWindow
	if start < 0 {
		start = 0
	}
	if rel := strings.LastIndexAny(s[start:], " \t"); rel >= 0 {
		return start + rel
	}
	return -1
}
content: "Hello world", + maxLen: 2000, + expectChunks: 1, + }, + { + name: "Simple split regular text", + content: longText, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) > 2000 { + t.Errorf("Chunk 0 too large: %d", len(chunks[0])) + } + if len(chunks[0])+len(chunks[1]) != len(longText) { + t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) + } + }, + }, + { + name: "Split at newline", + // 1750 chars then newline, then more chars. + // Dynamic buffer: 2000 / 10 = 200. + // Effective limit: 2000 - 200 = 1800. + // Split should happen at newline because it's at 1750 (< 1800). + // Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051. + content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300), + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + if len(chunks[0]) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) + } + if chunks[1] != strings.Repeat("b", 300) { + t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) + } + }, + }, + { + name: "Long code block split", + content: "Prefix\n" + longCode, + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Check that first chunk ends with closing fence + if !strings.HasSuffix(chunks[0], "\n```") { + t.Error("First chunk should end with injected closing fence") + } + // Check that second chunk starts with execution header + if !strings.HasPrefix(chunks[1], "```go") { + t.Error("Second chunk should start with injected code block header") + } + }, + }, + { + name: "Preserve Unicode characters", + content: strings.Repeat("\u4e16", 1000), // 3000 bytes + maxLen: 2000, + expectChunks: 2, + checkContent: func(t *testing.T, chunks []string) { + // Just verify we didn't panic and got valid strings. 
+ // Go strings are UTF-8, if we split mid-rune it would be bad, + // but standard slicing might do that. + // Let's assume standard behavior is acceptable or check if it produces invalid rune? + if !strings.Contains(chunks[0], "\u4e16") { + t.Error("Chunk should contain unicode characters") + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := SplitMessage(tc.content, tc.maxLen) + + if tc.expectChunks == 0 { + if len(got) != 0 { + t.Errorf("Expected 0 chunks, got %d", len(got)) + } + return + } + + if len(got) != tc.expectChunks { + t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got)) + // Log sizes for debugging + for i, c := range got { + t.Logf("Chunk %d length: %d", i, len(c)) + } + return // Stop further checks if count assumes specific split + } + + if tc.checkContent != nil { + tc.checkContent(t, got) + } + }) + } +} + +func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { + // Focused test for the core requirement: splitting inside a code block preserves syntax highlighting + + // 60 chars total approximately + content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```" + maxLen := 40 + + chunks := SplitMessage(content, maxLen) + + if len(chunks) != 2 { + t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks) + } + + // First chunk must end with "\n```" + if !strings.HasSuffix(chunks[0], "\n```") { + t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0]) + } + + // Second chunk must start with the header "```go" + if !strings.HasPrefix(chunks[1], "```go") { + t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1]) + } + + // First chunk should contain meaningful content + if len(chunks[0]) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0])) + } +}