mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
chore: merge main branch into mcp-tools-support
Resolved conflicts in: - config/config.example.json: Added empty MCP config block - pkg/config/config.go: Added MCP config structures to new ToolsConfig - pkg/agent/loop.go: Integrated MCP tools with new AgentRegistry architecture MCP tools now register to all agents in the registry during startup.
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Report a bug or unexpected behavior
|
||||
title: "[BUG]"
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Quick Summary
|
||||
|
||||
## Environment & Tools
|
||||
- **PicoClaw Version:** (e.g., v0.1.2 or commit hash)
|
||||
- **Go Version:** (e.g., go 1.22)
|
||||
- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow)
|
||||
- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux)
|
||||
- **Channels:** (e.g., Discord, Telegram, Feishu, ...)
|
||||
|
||||
## 📸 Steps to Reproduce
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## ❌ Actual Behavior
|
||||
|
||||
## ✅ Expected Behavior
|
||||
|
||||
## 💬 Additional Context
|
||||
@@ -0,0 +1,23 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest a new idea or improvement
|
||||
title: "[Feature]"
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 🎯 The Goal / Use Case
|
||||
|
||||
## 💡 Proposed Solution
|
||||
|
||||
## 🛠 Potential Implementation (Optional)
|
||||
|
||||
## 🚦 Impact & Roadmap Alignment
|
||||
- [ ] This is a Core Feature
|
||||
- [ ] This is a Nice-to-Have / Enhancement
|
||||
- [ ] This aligns with the current Roadmap
|
||||
|
||||
## 🔄 Alternatives Considered
|
||||
|
||||
## 💬 Additional Context
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
name: General Task / Todo
|
||||
about: A specific piece of work like doc, refactoring, or maintenance.
|
||||
title: "[Task]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 📝 Objective
|
||||
|
||||
## 📋 To-Do List
|
||||
- [ ] Step 1
|
||||
- [ ] Step 2
|
||||
- [ ] Step 3
|
||||
|
||||
## 🎯 Definition of Done (Acceptance Criteria)
|
||||
- [ ] Documentation is updated in the README/docs folder.
|
||||
- [ ] Code follows project linting standards.
|
||||
- [ ] (If applicable) Basic tests pass.
|
||||
|
||||
## 💡 Context / Motivation
|
||||
|
||||
## 🔗 Related Issues / PRs
|
||||
- Fixes #
|
||||
- Relates to #
|
||||
@@ -0,0 +1,43 @@
|
||||
## 📝 Description
|
||||
|
||||
<!-- Please briefly describe the changes and purpose of this PR -->
|
||||
|
||||
## 🗣️ Type of Change
|
||||
- [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] ✨ New feature (non-breaking change which adds functionality)
|
||||
- [ ] 📖 Documentation update
|
||||
- [ ] ⚡ Code refactoring (no functional changes, no api changes)
|
||||
|
||||
## 🤖 AI Code Generation
|
||||
- [ ] 🤖 Fully AI-generated (100% AI, 0% Human)
|
||||
- [ ] 🛠️ Mostly AI-generated (AI draft, Human verified/modified)
|
||||
- [ ] 👨💻 Mostly Human-written (Human lead, AI assisted or none)
|
||||
|
||||
|
||||
## 🔗 Related Issue
|
||||
|
||||
<!-- Please link the related issue(s) (e.g., Fixes #123, Closes #456) -->
|
||||
|
||||
## 📚 Technical Context (Skip for Docs)
|
||||
- **Reference URL:**
|
||||
- **Reasoning:**
|
||||
|
||||
## 🧪 Test Environment
|
||||
- **Hardware:** <!-- e.g. Raspberry Pi 5, Orange Pi, PC-->
|
||||
- **OS:** <!-- e.g. Debian 12, Ubuntu 22.04 -->
|
||||
- **Model/Provider:** <!-- e.g. OpenAI GPT-4o, Kimi k2, DeepSeek-V3 -->
|
||||
- **Channels:** <!-- e.g. Discord, Telegram, Feishu, ... -->
|
||||
|
||||
|
||||
## 📸 Evidence (Optional)
|
||||
<details>
|
||||
<summary>Click to view Logs/Screenshots</summary>
|
||||
|
||||
<!-- Please paste relevant screenshots or logs here -->
|
||||
|
||||
</details>
|
||||
|
||||
## ☑️ Checklist
|
||||
- [ ] My code/docs follow the style of this project.
|
||||
- [ ] I have performed a self-review of my own changes.
|
||||
- [ ] I have updated the documentation accordingly.
|
||||
@@ -9,10 +9,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
steps:
|
||||
# ── Checkout ──────────────────────────────
|
||||
- name: 📥 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ inputs.tag }}
|
||||
|
||||
|
||||
+34
-10
@@ -1,17 +1,39 @@
|
||||
name: pr-check
|
||||
name: PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
pull_request: { }
|
||||
|
||||
jobs:
|
||||
fmt-check:
|
||||
lint:
|
||||
name: Linter
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run go generate
|
||||
run: go generate ./...
|
||||
|
||||
- name: Golangci Lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.10.1
|
||||
|
||||
# TODO: Remove once linter is properly configured
|
||||
fmt-check:
|
||||
name: Formatting
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
@@ -20,15 +42,17 @@ jobs:
|
||||
make fmt
|
||||
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
|
||||
|
||||
# TODO: Remove once linter is properly configured
|
||||
vet:
|
||||
name: Vet
|
||||
runs-on: ubuntu-latest
|
||||
needs: fmt-check
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
@@ -39,14 +63,15 @@ jobs:
|
||||
run: go vet ./...
|
||||
|
||||
test:
|
||||
name: Tests
|
||||
runs-on: ubuntu-latest
|
||||
needs: fmt-check
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
@@ -55,4 +80,3 @@ jobs:
|
||||
|
||||
- name: Run go test
|
||||
run: go test ./...
|
||||
|
||||
|
||||
@@ -26,17 +26,19 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Create and push tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ inputs.tag }}
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
|
||||
git push origin "${{ inputs.tag }}"
|
||||
git tag -a "$RELEASE_TAG" -m "Release $RELEASE_TAG"
|
||||
git push origin "$RELEASE_TAG"
|
||||
|
||||
release:
|
||||
name: GoReleaser Release
|
||||
@@ -47,13 +49,14 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Checkout tag
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.tag }}
|
||||
|
||||
- name: Setup Go from go.mod
|
||||
uses: actions/setup-go@v5
|
||||
id: setup-go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
@@ -87,6 +90,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
GOVERSION: ${{ steps.setup-go.outputs.go-version }}
|
||||
|
||||
- name: Apply release flags
|
||||
shell: bash
|
||||
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
default: all
|
||||
disable:
|
||||
# TODO: Tweak for current project needs
|
||||
- containedctx
|
||||
- cyclop
|
||||
- depguard
|
||||
- dupl
|
||||
- dupword
|
||||
- err113
|
||||
- exhaustruct
|
||||
- funcorder
|
||||
- gochecknoglobals
|
||||
- godot
|
||||
- intrange
|
||||
- ireturn
|
||||
- nlreturn
|
||||
- noctx
|
||||
- noinlineerr
|
||||
- nonamedreturns
|
||||
- tagliatelle
|
||||
- testpackage
|
||||
- varnamelen
|
||||
- wrapcheck
|
||||
- wsl
|
||||
- wsl_v5
|
||||
|
||||
# TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
|
||||
- bodyclose
|
||||
- contextcheck
|
||||
- dogsled
|
||||
- embeddedstructfieldcheck
|
||||
- errcheck
|
||||
- errchkjson
|
||||
- errorlint
|
||||
- exhaustive
|
||||
- forbidigo
|
||||
- forcetypeassert
|
||||
- funlen
|
||||
- gochecknoinits
|
||||
- gocognit
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- godox
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- govet
|
||||
- ineffassign
|
||||
- lll
|
||||
- maintidx
|
||||
- misspell
|
||||
- mnd
|
||||
- modernize
|
||||
- nakedret
|
||||
- nestif
|
||||
- nilnil
|
||||
- paralleltest
|
||||
- perfsprint
|
||||
- prealloc
|
||||
- predeclared
|
||||
- revive
|
||||
- staticcheck
|
||||
- tagalign
|
||||
- testifylint
|
||||
- thelper
|
||||
- unparam
|
||||
- unused
|
||||
- usestdlibvars
|
||||
- usetesting
|
||||
- wastedassign
|
||||
- whitespace
|
||||
settings:
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: true
|
||||
exhaustive:
|
||||
default-signifies-exhaustive: true
|
||||
funlen:
|
||||
lines: 120
|
||||
statements: 40
|
||||
gocognit:
|
||||
min-complexity: 25
|
||||
gocyclo:
|
||||
min-complexity: 20
|
||||
govet:
|
||||
enable-all: true
|
||||
disable:
|
||||
- fieldalignment
|
||||
lll:
|
||||
line-length: 120
|
||||
tab-width: 4
|
||||
misspell:
|
||||
locale: US
|
||||
mnd:
|
||||
checks:
|
||||
- argument
|
||||
- assign
|
||||
- case
|
||||
- condition
|
||||
- operation
|
||||
- return
|
||||
nakedret:
|
||||
max-func-lines: 3
|
||||
revive:
|
||||
enable-all-rules: true
|
||||
rules:
|
||||
- name: add-constant
|
||||
disabled: true
|
||||
- name: argument-limit
|
||||
arguments:
|
||||
- 7
|
||||
severity: warning
|
||||
- name: banned-characters
|
||||
disabled: true
|
||||
- name: cognitive-complexity
|
||||
disabled: true
|
||||
- name: comment-spacings
|
||||
arguments:
|
||||
- nolint
|
||||
severity: warning
|
||||
- name: cyclomatic
|
||||
disabled: true
|
||||
- name: file-header
|
||||
disabled: true
|
||||
- name: function-result-limit
|
||||
arguments:
|
||||
- 3
|
||||
severity: warning
|
||||
- name: function-length
|
||||
disabled: true
|
||||
- name: line-length-limit
|
||||
disabled: true
|
||||
- name: max-public-structs
|
||||
disabled: true
|
||||
- name: modifies-value-receiver
|
||||
disabled: true
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
- name: unused-receiver
|
||||
disabled: true
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
- linters:
|
||||
- lll
|
||||
source: '^//go:generate '
|
||||
- linters:
|
||||
- funlen
|
||||
- maintidx
|
||||
- gocognit
|
||||
- gocyclo
|
||||
path: _test\.go$
|
||||
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- goimports
|
||||
# TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
|
||||
# - gci
|
||||
# - gofmt
|
||||
# - gofumpt
|
||||
# - golines
|
||||
settings:
|
||||
gci:
|
||||
sections:
|
||||
- standard
|
||||
- default
|
||||
- localmodule
|
||||
custom-order: true
|
||||
gofmt:
|
||||
simplify: true
|
||||
rewrite-rules:
|
||||
- pattern: "interface{}"
|
||||
replacement: "any"
|
||||
- pattern: "a[b:len(a)]"
|
||||
replacement: "a[b:]"
|
||||
golines:
|
||||
max-len: 120
|
||||
@@ -11,6 +11,14 @@ builds:
|
||||
- id: picoclaw
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
tags:
|
||||
- stdjson
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.version={{ .Version }}
|
||||
- -X main.gitCommit={{ .ShortCommit }}
|
||||
- -X main.buildTime={{ .Date }}
|
||||
- -X main.goVersion={{ .Env.GOVERSION }}
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
|
||||
+8
-1
@@ -29,7 +29,14 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
# Copy binary
|
||||
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
|
||||
|
||||
# Create picoclaw home directory
|
||||
# Create non-root user and group
|
||||
RUN addgroup -g 1000 picoclaw && \
|
||||
adduser -D -u 1000 -G picoclaw picoclaw
|
||||
|
||||
# Switch to non-root user
|
||||
USER picoclaw
|
||||
|
||||
# Run onboard to create initial directories and config
|
||||
RUN /usr/local/bin/picoclaw onboard
|
||||
|
||||
ENTRYPOINT ["picoclaw"]
|
||||
|
||||
@@ -11,11 +11,11 @@ VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
|
||||
BUILD_TIME=$(shell date +%FT%T%z)
|
||||
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
|
||||
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
|
||||
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w"
|
||||
|
||||
# Go variables
|
||||
GO?=go
|
||||
GOFLAGS?=-v
|
||||
GOFLAGS?=-v -tags stdjson
|
||||
|
||||
# Installation
|
||||
INSTALL_PREFIX?=$(HOME)/.local
|
||||
@@ -39,6 +39,8 @@ ifeq ($(UNAME_S),Linux)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),aarch64)
|
||||
ARCH=arm64
|
||||
else ifeq ($(UNAME_M),loongarch64)
|
||||
ARCH=loong64
|
||||
else ifeq ($(UNAME_M),riscv64)
|
||||
ARCH=riscv64
|
||||
else
|
||||
@@ -84,6 +86,7 @@ build-all: generate
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
|
||||
+14
-8
@@ -3,7 +3,7 @@
|
||||
|
||||
<h1>PicoClaw: Go で書かれた超効率 AI アシスタント</h1>
|
||||
|
||||
<h3>$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走!</h3>
|
||||
<h3>$10 ハードウェア · 10MB RAM · 1秒起動 · 行くぜ、シャコ!</h3>
|
||||
<h3></h3>
|
||||
|
||||
<p>
|
||||
@@ -12,7 +12,7 @@
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
</p>
|
||||
|
||||
**日本語** | [English](README.md)
|
||||
[中文](README.zh.md) | **日本語** | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
</table>
|
||||
|
||||
## 📢 ニュース
|
||||
2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走!
|
||||
2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 行くぜ、シャコ!
|
||||
|
||||
## ✨ 特徴
|
||||
|
||||
@@ -195,6 +195,9 @@ picoclaw onboard
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
@@ -250,7 +253,7 @@ Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -290,7 +293,7 @@ picoclaw gateway
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -673,7 +676,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "123456:ABC...",
|
||||
"allowFrom": ["123456789"]
|
||||
"allow_from": ["123456789"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
@@ -689,7 +692,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
"appSecret": "xxx",
|
||||
"encryptKey": "",
|
||||
"verificationToken": "",
|
||||
"allowFrom": []
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
@@ -697,6 +700,9 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
"search": {
|
||||
"apiKey": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
@@ -729,7 +735,7 @@ Discord: https://discord.gg/V4sAZ9XWpN
|
||||
|
||||
## 🐛 トラブルシューティング
|
||||
|
||||
### Web 検索で「API 配置问题」と表示される
|
||||
### Web 検索で「API 設定の問題」と表示される
|
||||
|
||||
検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
[中文](README.zh.md) | [日本語](README.ja.md) | **English**
|
||||
[中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | **English**
|
||||
</div>
|
||||
|
||||
---
|
||||
@@ -45,8 +45,11 @@
|
||||
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
|
||||
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
|
||||
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
|
||||
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (10–20MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
|
||||
|
||||
|
||||
## 📢 News
|
||||
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we can’t wait to have you on board!
|
||||
|
||||
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
|
||||
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
|
||||
@@ -96,6 +99,20 @@
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 📱 Run on old Android Phones
|
||||
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start:
|
||||
1. **Install Termux** (Available on F-Droid or Google Play).
|
||||
2. **Execute cmds**
|
||||
```bash
|
||||
# Note: Replace v0.1.1 with the latest version from the Releases page
|
||||
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
|
||||
chmod +x picoclaw-linux-arm64
|
||||
pkg install proot
|
||||
termux-chroot ./picoclaw-linux-arm64 onboard
|
||||
```
|
||||
And then follow the instructions in the "Quick Start" section to complete the configuration!
|
||||
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
### 🐜 Innovative Low-Footprint Deploy
|
||||
|
||||
PicoClaw can be deployed on almost any Linux device!
|
||||
@@ -266,7 +283,7 @@ Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -309,7 +326,7 @@ picoclaw gateway
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -662,6 +679,16 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
|
||||
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
|
||||
### Provider Architecture
|
||||
|
||||
PicoClaw routes providers by protocol family:
|
||||
|
||||
- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints.
|
||||
- Anthropic protocol: Claude-native API behavior.
|
||||
- Codex/OAuth path: OpenAI OAuth/token authentication route.
|
||||
|
||||
This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`).
|
||||
|
||||
<details>
|
||||
<summary><b>Zhipu</b></summary>
|
||||
|
||||
@@ -757,6 +784,9 @@ picoclaw agent -m "Hello"
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
+882
@@ -0,0 +1,882 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
<h1>PicoClaw: Assistente de IA Ultra-Eficiente em Go</h1>
|
||||
|
||||
<h3>Hardware de $10 · 10MB de RAM · Boot em 1s · 皮皮虾,我们走!</h3>
|
||||
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
|
||||
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
<br>
|
||||
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
[中文](README.zh.md) | [日本語](README.ja.md) | [English](README.md) | **Português**
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
|
||||
|
||||
⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini!
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/picoclaw_mem.gif" width="360" height="240">
|
||||
</p>
|
||||
</td>
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/licheervnano.png" width="400" height="240">
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
> [!CAUTION]
|
||||
> **🚨 DECLARAÇÃO DE SEGURANÇA & CANAIS OFICIAIS**
|
||||
>
|
||||
> * **SEM CRIPTOMOEDAS:** O PicoClaw **NÃO** possui nenhum token/moeda oficial. Todas as alegações no `pump.fun` ou outras plataformas de negociação são **GOLPES**.
|
||||
> * **DOMÍNIO OFICIAL:** O **ÚNICO** site oficial é o **[picoclaw.io](https://picoclaw.io)**, e o site da empresa é o **[sipeed.com](https://sipeed.com)**.
|
||||
> * **Aviso:** Muitos domínios `.ai/.org/.com/.net/...` foram registrados por terceiros, não são nossos.
|
||||
> * **Aviso:** O PicoClaw está em fase inicial de desenvolvimento e pode ter problemas de segurança de rede não resolvidos. Não implante em ambientes de produção antes da versão v1.0.
|
||||
> * **Nota:** O PicoClaw recentemente fez merge de muitos PRs, o que pode resultar em maior consumo de memória (10-20MB) nas versões mais recentes. Planejamos priorizar a otimização de recursos assim que o conjunto de funcionalidades estiver estável.
|
||||
|
||||
|
||||
## 📢 Novidades
|
||||
|
||||
2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/picoclaw_community_roadmap_260216.md) — estamos ansiosos para ter você a bordo!
|
||||
|
||||
2026-02-13 🎉 PicoClaw atingiu 5000 stars em 4 dias! Obrigado à comunidade! Estamos finalizando o **Roadmap do Projeto** e configurando o **Grupo de Desenvolvedores** para acelerar o desenvolvimento do PicoClaw.
|
||||
|
||||
🚀 **Chamada para Ação:** Envie suas solicitações de funcionalidades nas GitHub Discussions. Revisaremos e priorizaremos na próxima reunião semanal.
|
||||
|
||||
2026-02-09 🎉 PicoClaw lançado oficialmente! Construído em 1 dia para trazer Agentes de IA para hardware de $10 com <10MB de RAM. 🦐 PicoClaw, Partiu!
|
||||
|
||||
## ✨ Funcionalidades
|
||||
|
||||
🪶 **Ultra-Leve**: Consumo de memória <10MB — 99% menor que o Clawdbot para funcionalidades essenciais.
|
||||
|
||||
💰 **Custo Mínimo**: Eficiente o suficiente para rodar em hardware de $10 — 98% mais barato que um Mac mini.
|
||||
|
||||
⚡️ **Inicialização Relámpago**: Tempo de inicialização 400X mais rápido, boot em 1 segundo mesmo em CPU single-core de 0.6GHz.
|
||||
|
||||
🌍 **Portabilidade Real**: Um único binário auto-contido para RISC-V, ARM e x86. Um clique e já era!
|
||||
|
||||
🤖 **Auto-Construído por IA**: Implementação nativa em Go de forma autônoma — 95% do núcleo gerado pelo Agente com refinamento humano no loop.
|
||||
|
||||
| | OpenClaw | NanoBot | **PicoClaw** |
|
||||
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
|
||||
| **Linguagem** | TypeScript | Python | **Go** |
|
||||
| **RAM** | >1GB | >100MB | **< 10MB** |
|
||||
| **Inicialização**</br>(CPU 0.8GHz) | >500s | >30s | **<1s** |
|
||||
| **Custo** | Mac Mini $599 | Maioria dos SBC Linux </br>~$50 | **Qualquer placa Linux**</br>**A partir de $10** |
|
||||
|
||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
## 🦾 Demonstração
|
||||
|
||||
### 🛠️ Fluxos de Trabalho Padrão do Assistente
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<th><p align="center">🧩 Engenharia Full-Stack</p></th>
|
||||
<th><p align="center">🗂️ Gerenciamento de Logs & Planejamento</p></th>
|
||||
<th><p align="center">🔎 Busca Web & Aprendizado</p></th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Desenvolver • Implantar • Escalar</td>
|
||||
<td align="center">Agendar • Automatizar • Memorizar</td>
|
||||
<td align="center">Descobrir • Analisar • Tendências</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 📱 Rode em celulares Android antigos
|
||||
|
||||
Dê uma segunda vida ao seu celular de dez anos atrás! Transforme-o em um assistente de IA inteligente com o PicoClaw. Início rápido:
|
||||
|
||||
1. **Instale o Termux** (Disponível no F-Droid ou Google Play).
|
||||
2. **Execute os comandos**
|
||||
|
||||
```bash
|
||||
# Nota: Substitua v0.1.1 pela versao mais recente da pagina de Releases
|
||||
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
|
||||
chmod +x picoclaw-linux-arm64
|
||||
pkg install proot
|
||||
termux-chroot ./picoclaw-linux-arm64 onboard
|
||||
```
|
||||
|
||||
Depois siga as instruções na seção "Início Rápido" para completar a configuração!
|
||||
|
||||
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
### 🐜 Implantação Inovadora com Baixo Consumo
|
||||
|
||||
O PicoClaw pode ser implantado em praticamente qualquer dispositivo Linux!
|
||||
|
||||
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) versão E (Ethernet) ou W (WiFi6), para Assistente Doméstico Minimalista
|
||||
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), ou $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) para Manutenção Automatizada de Servidores
|
||||
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) ou $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) para Monitoramento Inteligente
|
||||
|
||||
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
|
||||
|
||||
🌟 Mais cenários de implantação aguardam você!
|
||||
|
||||
## 📦 Instalação
|
||||
|
||||
### Instalar com binário pré-compilado
|
||||
|
||||
Baixe o binário para sua plataforma na página de [releases](https://github.com/sipeed/picoclaw/releases).
|
||||
|
||||
### Instalar a partir do código-fonte (funcionalidades mais recentes, recomendado para desenvolvimento)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
|
||||
cd picoclaw
|
||||
make deps
|
||||
|
||||
# Build, sem necessidade de instalar
|
||||
make build
|
||||
|
||||
# Build para multiplas plataformas
|
||||
make build-all
|
||||
|
||||
# Build e Instalar
|
||||
make install
|
||||
```
|
||||
|
||||
## 🐳 Docker Compose
|
||||
|
||||
Você tambêm pode rodar o PicoClaw usando Docker Compose sem instalar nada localmente.
|
||||
|
||||
```bash
|
||||
# 1. Clone este repositorio
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
cd picoclaw
|
||||
|
||||
# 2. Configure suas API keys
|
||||
cp config/config.example.json config/config.json
|
||||
vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc.
|
||||
|
||||
# 3. Build & Iniciar
|
||||
docker compose --profile gateway up -d
|
||||
|
||||
# 4. Ver logs
|
||||
docker compose logs -f picoclaw-gateway
|
||||
|
||||
# 5. Parar
|
||||
docker compose --profile gateway down
|
||||
```
|
||||
|
||||
### Modo Agente (Execução única)
|
||||
|
||||
```bash
|
||||
# Fazer uma pergunta
|
||||
docker compose run --rm picoclaw-agent -m "Quanto e 2+2?"
|
||||
|
||||
# Modo interativo
|
||||
docker compose run --rm picoclaw-agent
|
||||
```
|
||||
|
||||
### Rebuild
|
||||
|
||||
```bash
|
||||
docker compose --profile gateway build --no-cache
|
||||
docker compose --profile gateway up -d
|
||||
```
|
||||
|
||||
### 🚀 Início Rápido
|
||||
|
||||
> [!TIP]
|
||||
> Configure sua API key em `~/.picoclaw/config.json`.
|
||||
> Obtenha API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
|
||||
> Busca web e **opcional** — obtenha a [Brave Search API](https://brave.com/search/api) gratuita (2000 consultas grátis/mês) ou use o fallback automático integrado.
|
||||
|
||||
**1. Inicializar**
|
||||
|
||||
```bash
|
||||
picoclaw onboard
|
||||
```
|
||||
|
||||
**2. Configurar** (`~/.picoclaw/config.json`)
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "xxx",
|
||||
"api_base": "https://openrouter.ai/api/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Obter API Keys**
|
||||
|
||||
* **Provedor de LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
* **Busca Web** (opcional): [Brave Search](https://brave.com/search/api) - Plano gratuito disponível (2000 consultas/mês)
|
||||
|
||||
> **Nota**: Veja `config.example.json` para um modelo de configuração completo.
|
||||
|
||||
**4. Conversar**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Quanto e 2+2?"
|
||||
```
|
||||
|
||||
Pronto! Você tem um assistente de IA funcionando em 2 minutos.
|
||||
|
||||
---
|
||||
|
||||
## 💬 Integração com Apps de Chat
|
||||
|
||||
Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE.
|
||||
|
||||
| Canal | Nível de Configuração |
|
||||
| --- | --- |
|
||||
| **Telegram** | Fácil (apenas um token) |
|
||||
| **Discord** | Fácil (bot token + intents) |
|
||||
| **QQ** | Fácil (AppID + AppSecret) |
|
||||
| **DingTalk** | Médio (credenciais do app) |
|
||||
| **LINE** | Médio (credenciais + webhook URL) |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (Recomendado)</summary>
|
||||
|
||||
**1. Criar o bot**
|
||||
|
||||
* Abra o Telegram, busque `@BotFather`
|
||||
* Envie `/newbot`, siga as instruções
|
||||
* Copie o token
|
||||
|
||||
**2. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Obtenha seu User ID pelo `@userinfobot` no Telegram.
|
||||
|
||||
**3. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Discord</b></summary>
|
||||
|
||||
**1. Criar o bot**
|
||||
|
||||
* Acesse <https://discord.com/developers/applications>
|
||||
* Crie um aplicativo → Bot → Add Bot
|
||||
* Copie o token do bot
|
||||
|
||||
**2. Habilitar Intents**
|
||||
|
||||
* Nas configurações do Bot, habilite **MESSAGE CONTENT INTENT**
|
||||
* (Opcional) Habilite **SERVER MEMBERS INTENT** se quiser usar lista de permissões baseada em dados dos membros
|
||||
|
||||
**3. Obter seu User ID**
|
||||
|
||||
* Configurações do Discord → Avançado → habilite **Modo Desenvolvedor**
|
||||
* Clique com botão direito no seu avatar → **Copiar ID do Usuário**
|
||||
|
||||
**4. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**5. Convidar o bot**
|
||||
|
||||
* OAuth2 → URL Generator
|
||||
* Scopes: `bot`
|
||||
* Bot Permissions: `Send Messages`, `Read Message History`
|
||||
* Abra a URL de convite gerada e adicione o bot ao seu servidor
|
||||
|
||||
**6. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>QQ</b></summary>
|
||||
|
||||
**1. Criar o bot**
|
||||
|
||||
- Acesse a [QQ Open Platform](https://q.qq.com/#)
|
||||
- Crie um aplicativo → Obtenha **AppID** e **AppSecret**
|
||||
|
||||
**2. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": true,
|
||||
"app_id": "YOUR_APP_ID",
|
||||
"app_secret": "YOUR_APP_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique números QQ para restringir o acesso.
|
||||
|
||||
**3. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>DingTalk</b></summary>
|
||||
|
||||
**1. Criar o bot**
|
||||
|
||||
* Acesse a [Open Platform](https://open.dingtalk.com/)
|
||||
* Crie um app interno
|
||||
* Copie o Client ID e Client Secret
|
||||
|
||||
**2. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"dingtalk": {
|
||||
"enabled": true,
|
||||
"client_id": "YOUR_CLIENT_ID",
|
||||
"client_secret": "YOUR_CLIENT_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Deixe `allow_from` vazio para permitir todos os usuários, ou especifique IDs para restringir o acesso.
|
||||
|
||||
**3. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>LINE</b></summary>
|
||||
|
||||
**1. Criar uma Conta Oficial LINE**
|
||||
|
||||
- Acesse o [LINE Developers Console](https://developers.line.biz/)
|
||||
- Crie um provider → Crie um canal Messaging API
|
||||
- Copie o **Channel Secret** e o **Channel Access Token**
|
||||
|
||||
**2. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"line": {
|
||||
"enabled": true,
|
||||
"channel_secret": "YOUR_CHANNEL_SECRET",
|
||||
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
|
||||
"webhook_host": "0.0.0.0",
|
||||
"webhook_port": 18791,
|
||||
"webhook_path": "/webhook/line",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Configurar URL do Webhook**
|
||||
|
||||
O LINE requer HTTPS para webhooks. Use um reverse proxy ou tunnel:
|
||||
|
||||
```bash
|
||||
# Exemplo com ngrok
|
||||
ngrok http 18791
|
||||
```
|
||||
|
||||
Em seguida, configure a Webhook URL no LINE Developers Console para `https://seu-dominio/webhook/line` e habilite **Use webhook**.
|
||||
|
||||
**4. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
> Em chats de grupo, o bot responde apenas quando mencionado com @. As respostas citam a mensagem original.
|
||||
|
||||
> **Docker Compose**: Adicione `ports: ["18791:18791"]` ao serviço `picoclaw-gateway` para expor a porta do webhook.
|
||||
|
||||
</details>
|
||||
|
||||
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Junte-se a Rede Social de Agentes
|
||||
|
||||
Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única mensagem via CLI ou qualquer App de Chat integrado.
|
||||
|
||||
**Leia `https://clawdchat.ai/skill.md` e siga as instruções para entrar no [ClawdChat.ai](https://clawdchat.ai)**
|
||||
|
||||
## ⚙️ Configuração Detalhada
|
||||
|
||||
Arquivo de configuração: `~/.picoclaw/config.json`
|
||||
|
||||
### Estrutura do Workspace
|
||||
|
||||
O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/workspace`):
|
||||
|
||||
```
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # Sessoes de conversa e historico
|
||||
├── memory/ # Memoria de longo prazo (MEMORY.md)
|
||||
├── state/ # Estado persistente (ultimo canal, etc.)
|
||||
├── cron/ # Banco de dados de tarefas agendadas
|
||||
├── skills/ # Skills personalizadas
|
||||
├── AGENTS.md # Guia de comportamento do Agente
|
||||
├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min)
|
||||
├── IDENTITY.md # Identidade do Agente
|
||||
├── SOUL.md # Alma do Agente
|
||||
├── TOOLS.md # Descrição das ferramentas
|
||||
└── USER.md # Preferencias do usuario
|
||||
```
|
||||
|
||||
### 🔒 Sandbox de Segurança
|
||||
|
||||
O PicoClaw roda em um ambiente sandbox por padrão. O agente so pode acessar arquivos e executar comandos dentro do workspace configurado.
|
||||
|
||||
#### Configuração Padrão
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"restrict_to_workspace": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Opção | Padrão | Descrição |
|
||||
|-------|--------|-----------|
|
||||
| `workspace` | `~/.picoclaw/workspace` | Diretório de trabalho do agente |
|
||||
| `restrict_to_workspace` | `true` | Restringir acesso de arquivos/comandos ao workspace |
|
||||
|
||||
#### Ferramentas Protegidas
|
||||
|
||||
Quando `restrict_to_workspace: true`, as seguintes ferramentas são restritas ao sandbox:
|
||||
|
||||
| Ferramenta | Função | Restrição |
|
||||
|------------|--------|-----------|
|
||||
| `read_file` | Ler arquivos | Apenas arquivos dentro do workspace |
|
||||
| `write_file` | Escrever arquivos | Apenas arquivos dentro do workspace |
|
||||
| `list_dir` | Listar diretorios | Apenas diretorios dentro do workspace |
|
||||
| `edit_file` | Editar arquivos | Apenas arquivos dentro do workspace |
|
||||
| `append_file` | Adicionar a arquivos | Apenas arquivos dentro do workspace |
|
||||
| `exec` | Executar comandos | Caminhos dos comandos devem estar dentro do workspace |
|
||||
|
||||
#### Proteção Adicional do Exec
|
||||
|
||||
Mesmo com `restrict_to_workspace: false`, a ferramenta `exec` bloqueia estes comandos perigosos:
|
||||
|
||||
* `rm -rf`, `del /f`, `rmdir /s` — Exclusão em massa
|
||||
* `format`, `mkfs`, `diskpart` — Formatação de disco
|
||||
* `dd if=` — Criação de imagem de disco
|
||||
* Escrita em `/dev/sd[a-z]` — Escrita direta no disco
|
||||
* `shutdown`, `reboot`, `poweroff` — Desligamento do sistema
|
||||
* Fork bomb `:(){ :|:& };:`
|
||||
|
||||
#### Exemplos de Erro
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
|
||||
```
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
|
||||
```
|
||||
|
||||
#### Desabilitar Restrições (Risco de Segurança)
|
||||
|
||||
Se você precisa que o agente acesse caminhos fora do workspace:
|
||||
|
||||
**Método 1: Arquivo de configuração**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"restrict_to_workspace": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Método 2: Variável de ambiente**
|
||||
|
||||
```bash
|
||||
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
|
||||
```
|
||||
|
||||
> ⚠️ **Aviso**: Desabilitar esta restrição permite que o agente acesse qualquer caminho no seu sistema. Use com cuidado apenas em ambientes controlados.
|
||||
|
||||
#### Consistência do Limite de Segurança
|
||||
|
||||
A configuração `restrict_to_workspace` se aplica consistentemente em todos os caminhos de execução:
|
||||
|
||||
| Caminho de Execução | Limite de Segurança |
|
||||
|----------------------|---------------------|
|
||||
| Agente Principal | `restrict_to_workspace` ✅ |
|
||||
| Subagente / Spawn | Herda a mesma restrição ✅ |
|
||||
| Tarefas Heartbeat | Herda a mesma restrição ✅ |
|
||||
|
||||
Todos os caminhos compartilham a mesma restrição de workspace — nao há como contornar o limite de segurança por meio de subagentes ou tarefas agendadas.
|
||||
|
||||
### Heartbeat (Tarefas Periódicas)
|
||||
|
||||
O PicoClaw pode executar tarefas periódicas automaticamente. Crie um arquivo `HEARTBEAT.md` no seu workspace:
|
||||
|
||||
```markdown
|
||||
# Tarefas Periodicas
|
||||
|
||||
- Verificar meu email para mensagens importantes
|
||||
- Revisar minha agenda para proximos eventos
|
||||
- Verificar a previsao do tempo
|
||||
```
|
||||
|
||||
O agente lerá este arquivo a cada 30 minutos (configurável) e executará as tarefas usando as ferramentas disponíveis.
|
||||
|
||||
#### Tarefas Assincronas com Spawn
|
||||
|
||||
Para tarefas de longa duração (busca web, chamadas de API), use a ferramenta `spawn` para criar um **subagente**:
|
||||
|
||||
```markdown
|
||||
# Tarefas Periódicas
|
||||
|
||||
## Tarefas Rápidas (resposta direta)
|
||||
- Informar hora atual
|
||||
|
||||
## Tarefas Longas (usar spawn para async)
|
||||
- Buscar notícias de IA na web e resumir
|
||||
- Verificar email e reportar mensagens importantes
|
||||
```
|
||||
|
||||
**Comportamentos principais:**
|
||||
|
||||
| Funcionalidade | Descrição |
|
||||
|----------------|-----------|
|
||||
| **spawn** | Cria subagente assíncrono, não bloqueia o heartbeat |
|
||||
| **Contexto independente** | Subagente tem seu próprio contexto, sem histórico de sessão |
|
||||
| **Ferramenta message** | Subagente se comunica diretamente com o usuário via ferramenta message |
|
||||
| **Não-bloqueante** | Após o spawn, o heartbeat continua para a próxima tarefa |
|
||||
|
||||
#### Como Funciona a Comunicação do Subagente
|
||||
|
||||
```
|
||||
Heartbeat dispara
|
||||
↓
|
||||
Agente lê HEARTBEAT.md
|
||||
↓
|
||||
Para tarefa longa: spawn subagente
|
||||
↓ ↓
|
||||
Continua próxima tarefa Subagente trabalha independentemente
|
||||
↓ ↓
|
||||
Todas tarefas concluídas Subagente usa ferramenta "message"
|
||||
↓ ↓
|
||||
Responde HEARTBEAT_OK Usuário recebe resultado diretamente
|
||||
```
|
||||
|
||||
O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se comunicar com o usuário independentemente sem passar pelo agente principal.
|
||||
|
||||
**Configuração:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Opção | Padrão | Descrição |
|
||||
|-------|--------|-----------|
|
||||
| `enabled` | `true` | Habilitar/desabilitar heartbeat |
|
||||
| `interval` | `30` | Intervalo de verificação em minutos (min: 5) |
|
||||
|
||||
**Variáveis de ambiente:**
|
||||
|
||||
* `PICOCLAW_HEARTBEAT_ENABLED=false` para desabilitar
|
||||
* `PICOCLAW_HEARTBEAT_INTERVAL=60` para alterar o intervalo
|
||||
|
||||
### Provedores
|
||||
|
||||
> [!NOTE]
|
||||
> O Groq fornece transcrição de voz gratuita via Whisper. Se configurado, mensagens de voz do Telegram serão automaticamente transcritas.
|
||||
|
||||
| Provedor | Finalidade | Obter API Key |
|
||||
| --- | --- | --- |
|
||||
| `gemini` | LLM (Gemini direto) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `zhipu` | LLM (Zhipu direto) | [bigmodel.cn](bigmodel.cn) |
|
||||
| `openrouter` (Em teste) | LLM (recomendado, acesso a todos os modelos) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic` (Em teste) | LLM (Claude direto) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai` (Em teste) | LLM (GPT direto) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek` (Em teste) | LLM (DeepSeek direto) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Transcrição de voz** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
|
||||
<details>
|
||||
<summary><b>Configuração Zhipu</b></summary>
|
||||
|
||||
**1. Obter API key**
|
||||
|
||||
* Obtenha a [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||
|
||||
**2. Configurar**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"zhipu": {
|
||||
"api_key": "Sua API Key",
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Executar**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Ola, como vai?"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Exemplo de configuraçao completa</b></summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "anthropic/claude-opus-4-5"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx"
|
||||
},
|
||||
"groq": {
|
||||
"api_key": "gsk_xxx"
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "123456:ABC...",
|
||||
"allow_from": ["123456789"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "",
|
||||
"allow_from": [""]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": false
|
||||
},
|
||||
"feishu": {
|
||||
"enabled": false,
|
||||
"app_id": "cli_xxx",
|
||||
"app_secret": "xxx",
|
||||
"encrypt_key": "",
|
||||
"verification_token": "",
|
||||
"allow_from": []
|
||||
},
|
||||
"qq": {
|
||||
"enabled": false,
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "BSA...",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Referência CLI
|
||||
|
||||
| Comando | Descrição |
|
||||
| --- | --- |
|
||||
| `picoclaw onboard` | Inicializar configuração & workspace |
|
||||
| `picoclaw agent -m "..."` | Conversar com o agente |
|
||||
| `picoclaw agent` | Modo de chat interativo |
|
||||
| `picoclaw gateway` | Iniciar o gateway (para bots de chat) |
|
||||
| `picoclaw status` | Mostrar status |
|
||||
| `picoclaw cron list` | Listar todas as tarefas agendadas |
|
||||
| `picoclaw cron add ...` | Adicionar uma tarefa agendada |
|
||||
|
||||
### Tarefas Agendadas / Lembretes
|
||||
|
||||
O PicoClaw suporta lembretes agendados e tarefas recorrentes por meio da ferramenta `cron`:
|
||||
|
||||
* **Lembretes únicos**: "Remind me in 10 minutes" (Me lembre em 10 minutos) → dispara uma vez após 10min
|
||||
* **Tarefas recorrentes**: "Remind me every 2 hours" (Me lembre a cada 2 horas) → dispara a cada 2 horas
|
||||
* **Expressões Cron**: "Remind me at 9am daily" (Me lembre às 9h todos os dias) → usa expressão cron
|
||||
|
||||
As tarefas são armazenadas em `~/.picoclaw/workspace/cron/` e processadas automaticamente.
|
||||
|
||||
## 🤝 Contribuir & Roadmap
|
||||
|
||||
PRs são bem-vindos! O código-fonte é intencionalmente pequeno e legível. 🤗
|
||||
|
||||
Roadmap em breve...
|
||||
|
||||
Grupo de desenvolvedores em formação. Requisito de entrada: Pelo menos 1 PR com merge.
|
||||
|
||||
Grupos de usuários:
|
||||
|
||||
Discord: <https://discord.gg/V4sAZ9XWpN>
|
||||
|
||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||
|
||||
## 🐛 Solução de Problemas
|
||||
|
||||
### Busca web mostra "API 配置问题"
|
||||
|
||||
Isso é normal se você ainda não configurou uma API key de busca. O PicoClaw fornecerá links úteis para busca manual.
|
||||
|
||||
Para habilitar a busca web:
|
||||
|
||||
1. **Opção 1 (Recomendado)**: Obtenha uma API key gratuita em [https://brave.com/search/api](https://brave.com/search/api) (2000 consultas grátis/mês) para os melhores resultados.
|
||||
2. **Opção 2 (Sem Cartão de Crédito)**: Se você não tem uma key, o sistema automaticamente usa o **DuckDuckGo** como fallback (sem necessidade de key).
|
||||
|
||||
Adicione a key em `~/.picoclaw/config.json` se usar o Brave:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": true,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Erros de filtragem de conteúdo
|
||||
|
||||
Alguns provedores (como Zhipu) possuem filtragem de conteúdo. Tente reformular sua pergunta ou use um modelo diferente.
|
||||
|
||||
### Bot do Telegram diz "Conflict: terminated by other getUpdates"
|
||||
|
||||
Isso acontece quando outra instância do bot está em execução. Certifique-se de que apenas um `picoclaw gateway` esteja rodando por vez.
|
||||
|
||||
---
|
||||
|
||||
## 📝 Comparação de API Keys
|
||||
|
||||
| Serviço | Plano Gratuito | Caso de Uso |
|
||||
| --- | --- | --- |
|
||||
| **OpenRouter** | 200K tokens/mês | Múltiplos modelos (Claude, GPT-4, etc.) |
|
||||
| **Zhipu** | 200K tokens/mês | Melhor para usuários chineses |
|
||||
| **Brave Search** | 2000 consultas/mês | Funcionalidade de busca web |
|
||||
| **Groq** | Plano gratuito disponível | Inferência ultra-rápida (Llama, Mixtral) |
|
||||
+859
@@ -0,0 +1,859 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
<h1>PicoClaw: Trợ lý AI Siêu Nhẹ viết bằng Go</h1>
|
||||
|
||||
<h3>Phần cứng $10 · RAM 10MB · Khởi động 1 giây · 皮皮虾,我们走!</h3>
|
||||
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
|
||||
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
<br>
|
||||
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
**Tiếng Việt** | [中文](README.zh.md) | [日本語](README.ja.md) | [English](README.md)
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
|
||||
|
||||
⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini!
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/picoclaw_mem.gif" width="360" height="240">
|
||||
</p>
|
||||
</td>
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/licheervnano.png" width="400" height="240">
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
> [!CAUTION]
|
||||
> **🚨 TUYÊN BỐ BẢO MẬT & KÊNH CHÍNH THỨC**
|
||||
>
|
||||
> * **KHÔNG CÓ CRYPTO:** PicoClaw **KHÔNG** có bất kỳ token/coin chính thức nào. Mọi thông tin trên `pump.fun` hoặc các sàn giao dịch khác đều là **LỪA ĐẢO**.
|
||||
> * **DOMAIN CHÍNH THỨC:** Website chính thức **DUY NHẤT** là **[picoclaw.io](https://picoclaw.io)**, website công ty là **[sipeed.com](https://sipeed.com)**.
|
||||
> * **Cảnh báo:** Nhiều tên miền `.ai/.org/.com/.net/...` đã bị bên thứ ba đăng ký, không phải của chúng tôi.
|
||||
> * **Cảnh báo:** PicoClaw đang trong giai đoạn phát triển sớm và có thể còn các vấn đề bảo mật mạng chưa được giải quyết. Không nên triển khai lên môi trường production trước phiên bản v1.0.
|
||||
> * **Lưu ý:** PicoClaw gần đây đã merge nhiều PR, dẫn đến bộ nhớ sử dụng có thể lớn hơn (10–20MB) ở các phiên bản mới nhất. Chúng tôi sẽ ưu tiên tối ưu tài nguyên khi bộ tính năng đã ổn định.
|
||||
|
||||
|
||||
## 📢 Tin tức
|
||||
|
||||
2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/picoclaw_community_roadmap_260216.md) — rất mong đón nhận sự tham gia của bạn!
|
||||
|
||||
2026-02-13 🎉 PicoClaw đạt 5000 stars trong 4 ngày! Cảm ơn cộng đồng! Chúng tôi đang hoàn thiện **Lộ trình dự án (Roadmap)** và thiết lập **Nhóm phát triển** để đẩy nhanh tốc độ phát triển PicoClaw.
|
||||
🚀 **Kêu gọi hành động:** Vui lòng gửi yêu cầu tính năng tại GitHub Discussions. Chúng tôi sẽ xem xét và ưu tiên trong cuộc họp hàng tuần.
|
||||
|
||||
2026-02-09 🎉 PicoClaw chính thức ra mắt! Được xây dựng trong 1 ngày để mang AI Agent đến phần cứng $10 với RAM <10MB. 🦐 PicoClaw, Lên Đường!
|
||||
|
||||
## ✨ Tính năng nổi bật
|
||||
|
||||
🪶 **Siêu nhẹ**: Bộ nhớ sử dụng <10MB — nhỏ hơn 99% so với Clawdbot (chức năng cốt lõi).
|
||||
|
||||
💰 **Chi phí tối thiểu**: Đủ hiệu quả để chạy trên phần cứng $10 — rẻ hơn 98% so với Mac mini.
|
||||
|
||||
⚡️ **Khởi động siêu nhanh**: Nhanh gấp 400 lần, khởi động trong 1 giây ngay cả trên CPU đơn nhân 0.6GHz.
|
||||
|
||||
🌍 **Di động thực sự**: Một file binary duy nhất chạy trên RISC-V, ARM và x86. Một click là chạy!
|
||||
|
||||
🤖 **AI tự xây dựng**: Triển khai Go-native tự động — 95% mã nguồn cốt lõi được Agent tạo ra, với sự tinh chỉnh của con người.
|
||||
|
||||
| | OpenClaw | NanoBot | **PicoClaw** |
|
||||
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
|
||||
| **Ngôn ngữ** | TypeScript | Python | **Go** |
|
||||
| **RAM** | >1GB | >100MB | **< 10MB** |
|
||||
| **Thời gian khởi động**</br>(CPU 0.8GHz) | >500s | >30s | **<1s** |
|
||||
| **Chi phí** | Mac Mini $599 | Hầu hết SBC Linux ~$50 | **Mọi bo mạch Linux**</br>**Chỉ từ $10** |
|
||||
|
||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
## 🦾 Demo
|
||||
|
||||
### 🛠️ Quy trình trợ lý tiêu chuẩn
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<th><p align="center">🧩 Lập trình Full-Stack</p></th>
|
||||
<th><p align="center">🗂️ Quản lý Nhật ký & Kế hoạch</p></th>
|
||||
<th><p align="center">🔎 Tìm kiếm Web & Học hỏi</p></th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Phát triển • Triển khai • Mở rộng</td>
|
||||
<td align="center">Lên lịch • Tự động hóa • Ghi nhớ</td>
|
||||
<td align="center">Khám phá • Phân tích • Xu hướng</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 🐜 Triển khai sáng tạo trên phần cứng tối thiểu
|
||||
|
||||
PicoClaw có thể triển khai trên hầu hết mọi thiết bị Linux!
|
||||
|
||||
* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) phiên bản E (Ethernet) hoặc W (WiFi6), dùng làm Trợ lý Gia đình tối giản.
|
||||
* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), hoặc $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html), dùng cho quản trị Server tự động.
|
||||
* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) hoặc $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera), dùng cho Giám sát thông minh.
|
||||
|
||||
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
|
||||
|
||||
🌟 Nhiều hình thức triển khai hơn đang chờ bạn khám phá!
|
||||
|
||||
## 📦 Cài đặt
|
||||
|
||||
### Cài đặt bằng binary biên dịch sẵn
|
||||
|
||||
Tải file binary cho nền tảng của bạn từ [trang Release](https://github.com/sipeed/picoclaw/releases).
|
||||
|
||||
### Cài đặt từ mã nguồn (có tính năng mới nhất, khuyên dùng cho phát triển)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
|
||||
cd picoclaw
|
||||
make deps
|
||||
|
||||
# Build (không cần cài đặt)
|
||||
make build
|
||||
|
||||
# Build cho nhiều nền tảng
|
||||
make build-all
|
||||
|
||||
# Build và cài đặt
|
||||
make install
|
||||
```
|
||||
|
||||
## 🐳 Docker Compose
|
||||
|
||||
Bạn cũng có thể chạy PicoClaw bằng Docker Compose mà không cần cài đặt gì trên máy.
|
||||
|
||||
```bash
|
||||
# 1. Clone repo
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
cd picoclaw
|
||||
|
||||
# 2. Thiết lập API Key
|
||||
cp config/config.example.json config/config.json
|
||||
vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v.
|
||||
|
||||
# 3. Build & Khởi động
|
||||
docker compose --profile gateway up -d
|
||||
|
||||
# 4. Xem logs
|
||||
docker compose logs -f picoclaw-gateway
|
||||
|
||||
# 5. Dừng
|
||||
docker compose --profile gateway down
|
||||
```
|
||||
|
||||
### Chế độ Agent (chạy một lần)
|
||||
|
||||
```bash
|
||||
# Đặt câu hỏi
|
||||
docker compose run --rm picoclaw-agent -m "2+2 bằng mấy?"
|
||||
|
||||
# Chế độ tương tác
|
||||
docker compose run --rm picoclaw-agent
|
||||
```
|
||||
|
||||
### Build lại
|
||||
|
||||
```bash
|
||||
docker compose --profile gateway build --no-cache
|
||||
docker compose --profile gateway up -d
|
||||
```
|
||||
|
||||
### 🚀 Bắt đầu nhanh
|
||||
|
||||
> [!TIP]
|
||||
> Thiết lập API key trong `~/.picoclaw/config.json`.
|
||||
> Lấy API key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
|
||||
> Tìm kiếm web là **tùy chọn** — lấy [Brave Search API](https://brave.com/search/api) miễn phí (2000 truy vấn/tháng) hoặc dùng tính năng auto fallback tích hợp sẵn.
|
||||
|
||||
**1. Khởi tạo**
|
||||
|
||||
```bash
|
||||
picoclaw onboard
|
||||
```
|
||||
|
||||
**2. Cấu hình** (`~/.picoclaw/config.json`)
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "xxx",
|
||||
"api_base": "https://openrouter.ai/api/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Lấy API Key**
|
||||
|
||||
* **Nhà cung cấp LLM**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
* **Tìm kiếm Web** (tùy chọn): [Brave Search](https://brave.com/search/api) — Có gói miễn phí (2000 truy vấn/tháng)
|
||||
|
||||
> **Lưu ý**: Xem `config.example.json` để có mẫu cấu hình đầy đủ.
|
||||
|
||||
**4. Trò chuyện**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Xin chào, bạn là ai?"
|
||||
```
|
||||
|
||||
Vậy là xong! Bạn đã có một trợ lý AI hoạt động chỉ trong 2 phút.
|
||||
|
||||
---
|
||||
|
||||
## 💬 Tích hợp ứng dụng Chat
|
||||
|
||||
Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk hoặc LINE.
|
||||
|
||||
| Kênh | Mức độ thiết lập |
|
||||
| --- | --- |
|
||||
| **Telegram** | Dễ (chỉ cần token) |
|
||||
| **Discord** | Dễ (bot token + intents) |
|
||||
| **QQ** | Dễ (AppID + AppSecret) |
|
||||
| **DingTalk** | Trung bình (app credentials) |
|
||||
| **LINE** | Trung bình (credentials + webhook URL) |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (Khuyên dùng)</summary>
|
||||
|
||||
**1. Tạo bot**
|
||||
|
||||
* Mở Telegram, tìm `@BotFather`
|
||||
* Gửi `/newbot`, làm theo hướng dẫn
|
||||
* Sao chép token
|
||||
|
||||
**2. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Lấy User ID từ `@userinfobot` trên Telegram.
|
||||
|
||||
**3. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Discord</b></summary>
|
||||
|
||||
**1. Tạo bot**
|
||||
|
||||
* Truy cập <https://discord.com/developers/applications>
|
||||
* Create an application → Bot → Add Bot
|
||||
* Sao chép bot token
|
||||
|
||||
**2. Bật Intents**
|
||||
|
||||
* Trong phần Bot settings, bật **MESSAGE CONTENT INTENT**
|
||||
* (Tùy chọn) Bật **SERVER MEMBERS INTENT** nếu muốn dùng danh sách cho phép theo thông tin thành viên
|
||||
|
||||
**3. Lấy User ID**
|
||||
|
||||
* Discord Settings → Advanced → bật **Developer Mode**
|
||||
* Click chuột phải vào avatar → **Copy User ID**
|
||||
|
||||
**4. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**5. Mời bot vào server**
|
||||
|
||||
* OAuth2 → URL Generator
|
||||
* Scopes: `bot`
|
||||
* Bot Permissions: `Send Messages`, `Read Message History`
|
||||
* Mở URL mời được tạo và thêm bot vào server của bạn
|
||||
|
||||
**6. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>QQ</b></summary>
|
||||
|
||||
**1. Tạo bot**
|
||||
|
||||
* Truy cập [QQ Open Platform](https://q.qq.com/#)
|
||||
* Tạo ứng dụng → Lấy **AppID** và **AppSecret**
|
||||
|
||||
**2. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": true,
|
||||
"app_id": "YOUR_APP_ID",
|
||||
"app_secret": "YOUR_APP_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định số QQ để giới hạn quyền truy cập.
|
||||
|
||||
**3. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>DingTalk</b></summary>
|
||||
|
||||
**1. Tạo bot**
|
||||
|
||||
* Truy cập [Open Platform](https://open.dingtalk.com/)
|
||||
* Tạo ứng dụng nội bộ
|
||||
* Sao chép Client ID và Client Secret
|
||||
|
||||
**2. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"dingtalk": {
|
||||
"enabled": true,
|
||||
"client_id": "YOUR_CLIENT_ID",
|
||||
"client_secret": "YOUR_CLIENT_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Để `allow_from` trống để cho phép tất cả người dùng, hoặc chỉ định ID để giới hạn quyền truy cập.
|
||||
|
||||
**3. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>LINE</b></summary>
|
||||
|
||||
**1. Tạo tài khoản LINE Official**
|
||||
|
||||
- Truy cập [LINE Developers Console](https://developers.line.biz/)
|
||||
- Tạo provider → Tạo Messaging API channel
|
||||
- Sao chép **Channel Secret** và **Channel Access Token**
|
||||
|
||||
**2. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"line": {
|
||||
"enabled": true,
|
||||
"channel_secret": "YOUR_CHANNEL_SECRET",
|
||||
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
|
||||
"webhook_host": "0.0.0.0",
|
||||
"webhook_port": 18791,
|
||||
"webhook_path": "/webhook/line",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Thiết lập Webhook URL**
|
||||
|
||||
LINE yêu cầu HTTPS cho webhook. Sử dụng reverse proxy hoặc tunnel:
|
||||
|
||||
```bash
|
||||
# Ví dụ với ngrok
|
||||
ngrok http 18791
|
||||
```
|
||||
|
||||
Sau đó cài đặt Webhook URL trong LINE Developers Console thành `https://your-domain/webhook/line` và bật **Use webhook**.
|
||||
|
||||
**4. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
> Trong nhóm chat, bot chỉ phản hồi khi được @mention. Các câu trả lời sẽ trích dẫn tin nhắn gốc.
|
||||
|
||||
> **Docker Compose**: Thêm `ports: ["18791:18791"]` vào service `picoclaw-gateway` để mở port webhook.
|
||||
|
||||
</details>
|
||||
|
||||
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Tham gia Mạng xã hội Agent
|
||||
|
||||
Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một tin nhắn qua CLI hoặc bất kỳ ứng dụng Chat nào đã tích hợp.
|
||||
|
||||
**Đọc `https://clawdchat.ai/skill.md` và làm theo hướng dẫn để tham gia [ClawdChat.ai](https://clawdchat.ai)**
|
||||
|
||||
## ⚙️ Cấu hình chi tiết
|
||||
|
||||
File cấu hình: `~/.picoclaw/config.json`
|
||||
|
||||
### Cấu trúc Workspace
|
||||
|
||||
PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định: `~/.picoclaw/workspace`):
|
||||
|
||||
```
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # Phiên hội thoại và lịch sử
|
||||
├── memory/ # Bộ nhớ dài hạn (MEMORY.md)
|
||||
├── state/ # Trạng thái lưu trữ (kênh cuối cùng, v.v.)
|
||||
├── cron/ # Cơ sở dữ liệu tác vụ định kỳ
|
||||
├── skills/ # Kỹ năng tùy chỉnh
|
||||
├── AGENTS.md # Hướng dẫn hành vi Agent
|
||||
├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút)
|
||||
├── IDENTITY.md # Danh tính Agent
|
||||
├── SOUL.md # Tâm hồn/Tính cách Agent
|
||||
├── TOOLS.md # Mô tả công cụ
|
||||
└── USER.md # Tùy chọn người dùng
|
||||
```
|
||||
|
||||
### 🔒 Hộp cát bảo mật (Security Sandbox)
|
||||
|
||||
PicoClaw chạy trong môi trường sandbox theo mặc định. Agent chỉ có thể truy cập file và thực thi lệnh trong phạm vi workspace.
|
||||
|
||||
#### Cấu hình mặc định
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"restrict_to_workspace": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Tùy chọn | Mặc định | Mô tả |
|
||||
|----------|---------|-------|
|
||||
| `workspace` | `~/.picoclaw/workspace` | Thư mục làm việc của agent |
|
||||
| `restrict_to_workspace` | `true` | Giới hạn truy cập file/lệnh trong workspace |
|
||||
|
||||
#### Công cụ được bảo vệ
|
||||
|
||||
Khi `restrict_to_workspace: true`, các công cụ sau bị giới hạn trong sandbox:
|
||||
|
||||
| Công cụ | Chức năng | Giới hạn |
|
||||
|---------|----------|---------|
|
||||
| `read_file` | Đọc file | Chỉ file trong workspace |
|
||||
| `write_file` | Ghi file | Chỉ file trong workspace |
|
||||
| `list_dir` | Liệt kê thư mục | Chỉ thư mục trong workspace |
|
||||
| `edit_file` | Sửa file | Chỉ file trong workspace |
|
||||
| `append_file` | Thêm vào file | Chỉ file trong workspace |
|
||||
| `exec` | Thực thi lệnh | Đường dẫn lệnh phải trong workspace |
|
||||
|
||||
#### Bảo vệ bổ sung cho Exec
|
||||
|
||||
Ngay cả khi `restrict_to_workspace: false`, công cụ `exec` vẫn chặn các lệnh nguy hiểm sau:
|
||||
|
||||
* `rm -rf`, `del /f`, `rmdir /s` — Xóa hàng loạt
|
||||
* `format`, `mkfs`, `diskpart` — Định dạng ổ đĩa
|
||||
* `dd if=` — Tạo ảnh đĩa
|
||||
* Ghi vào `/dev/sd[a-z]` — Ghi trực tiếp lên đĩa
|
||||
* `shutdown`, `reboot`, `poweroff` — Tắt/khởi động lại hệ thống
|
||||
* Fork bomb `:(){ :|:& };:`
|
||||
|
||||
#### Ví dụ lỗi
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
|
||||
```
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
|
||||
```
|
||||
|
||||
#### Tắt giới hạn (Rủi ro bảo mật)
|
||||
|
||||
Nếu bạn cần agent truy cập đường dẫn ngoài workspace:
|
||||
|
||||
**Cách 1: File cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"restrict_to_workspace": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Cách 2: Biến môi trường**
|
||||
|
||||
```bash
|
||||
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
|
||||
```
|
||||
|
||||
> ⚠️ **Cảnh báo**: Tắt giới hạn này cho phép agent truy cập mọi đường dẫn trên hệ thống. Chỉ sử dụng cẩn thận trong môi trường được kiểm soát.
|
||||
|
||||
#### Tính nhất quán của ranh giới bảo mật
|
||||
|
||||
Cài đặt `restrict_to_workspace` áp dụng nhất quán trên mọi đường thực thi:
|
||||
|
||||
| Đường thực thi | Ranh giới bảo mật |
|
||||
|----------------|-------------------|
|
||||
| Agent chính | `restrict_to_workspace` ✅ |
|
||||
| Subagent / Spawn | Kế thừa cùng giới hạn ✅ |
|
||||
| Tác vụ Heartbeat | Kế thừa cùng giới hạn ✅ |
|
||||
|
||||
Tất cả đường thực thi chia sẻ cùng giới hạn workspace — không có cách nào vượt qua ranh giới bảo mật thông qua subagent hoặc tác vụ định kỳ.
|
||||
|
||||
### Heartbeat (Tác vụ định kỳ)
|
||||
|
||||
PicoClaw có thể tự động thực hiện các tác vụ định kỳ. Tạo file `HEARTBEAT.md` trong workspace:
|
||||
|
||||
```markdown
|
||||
# Tác vụ định kỳ
|
||||
|
||||
- Kiểm tra email xem có tin nhắn quan trọng không
|
||||
- Xem lại lịch cho các sự kiện sắp tới
|
||||
- Kiểm tra dự báo thời tiết
|
||||
```
|
||||
|
||||
Agent sẽ đọc file này mỗi 30 phút (có thể cấu hình) và thực hiện các tác vụ bằng công cụ có sẵn.
|
||||
|
||||
#### Tác vụ bất đồng bộ với Spawn
|
||||
|
||||
Đối với các tác vụ chạy lâu (tìm kiếm web, gọi API), sử dụng công cụ `spawn` để tạo **subagent**:
|
||||
|
||||
```markdown
|
||||
# Tác vụ định kỳ
|
||||
|
||||
## Tác vụ nhanh (trả lời trực tiếp)
|
||||
- Báo cáo thời gian hiện tại
|
||||
|
||||
## Tác vụ lâu (dùng spawn cho async)
|
||||
- Tìm kiếm tin tức AI trên web và tóm tắt
|
||||
- Kiểm tra email và báo cáo tin nhắn quan trọng
|
||||
```
|
||||
|
||||
**Hành vi chính:**
|
||||
|
||||
| Tính năng | Mô tả |
|
||||
|-----------|-------|
|
||||
| **spawn** | Tạo subagent bất đồng bộ, không chặn heartbeat |
|
||||
| **Context độc lập** | Subagent có context riêng, không có lịch sử phiên |
|
||||
| **message tool** | Subagent giao tiếp trực tiếp với người dùng qua công cụ message |
|
||||
| **Không chặn** | Sau khi spawn, heartbeat tiếp tục tác vụ tiếp theo |
|
||||
|
||||
#### Cách Subagent giao tiếp
|
||||
|
||||
```
|
||||
Heartbeat kích hoạt
|
||||
↓
|
||||
Agent đọc HEARTBEAT.md
|
||||
↓
|
||||
Tác vụ lâu: spawn subagent
|
||||
↓ ↓
|
||||
Tiếp tục tác vụ tiếp theo Subagent làm việc độc lập
|
||||
↓ ↓
|
||||
Tất cả tác vụ hoàn thành Subagent dùng công cụ "message"
|
||||
↓ ↓
|
||||
Phản hồi HEARTBEAT_OK Người dùng nhận kết quả trực tiếp
|
||||
```
|
||||
|
||||
Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và có thể giao tiếp với người dùng một cách độc lập mà không cần thông qua agent chính.
|
||||
|
||||
**Cấu hình:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Tùy chọn | Mặc định | Mô tả |
|
||||
|----------|---------|-------|
|
||||
| `enabled` | `true` | Bật/tắt heartbeat |
|
||||
| `interval` | `30` | Khoảng thời gian kiểm tra (phút, tối thiểu: 5) |
|
||||
|
||||
**Biến môi trường:**
|
||||
|
||||
* `PICOCLAW_HEARTBEAT_ENABLED=false` để tắt
|
||||
* `PICOCLAW_HEARTBEAT_INTERVAL=60` để thay đổi khoảng thời gian
|
||||
|
||||
### Nhà cung cấp (Providers)
|
||||
|
||||
> [!NOTE]
|
||||
> Groq cung cấp dịch vụ chuyển giọng nói thành văn bản miễn phí qua Whisper. Nếu đã cấu hình Groq, tin nhắn thoại trên Telegram sẽ được tự động chuyển thành văn bản.
|
||||
|
||||
| Nhà cung cấp | Mục đích | Lấy API Key |
|
||||
| --- | --- | --- |
|
||||
| `gemini` | LLM (Gemini trực tiếp) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `zhipu` | LLM (Zhipu trực tiếp) | [bigmodel.cn](bigmodel.cn) |
|
||||
| `openrouter` (Đang thử nghiệm) | LLM (khuyên dùng, truy cập mọi model) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic` (Đang thử nghiệm) | LLM (Claude trực tiếp) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai` (Đang thử nghiệm) | LLM (GPT trực tiếp) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek` (Đang thử nghiệm) | LLM (DeepSeek trực tiếp) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Chuyển giọng nói** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
|
||||
<details>
|
||||
<summary><b>Cấu hình Zhipu</b></summary>
|
||||
|
||||
**1. Lấy API key**
|
||||
|
||||
* Lấy [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||
|
||||
**2. Cấu hình**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"zhipu": {
|
||||
"api_key": "Your API Key",
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Chạy**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Xin chào"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Ví dụ cấu hình đầy đủ</b></summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "anthropic/claude-opus-4-5"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx"
|
||||
},
|
||||
"groq": {
|
||||
"api_key": "gsk_xxx"
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "123456:ABC...",
|
||||
"allow_from": ["123456789"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "",
|
||||
"allow_from": [""]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": false
|
||||
},
|
||||
"feishu": {
|
||||
"enabled": false,
|
||||
"app_id": "cli_xxx",
|
||||
"app_secret": "xxx",
|
||||
"encrypt_key": "",
|
||||
"verification_token": "",
|
||||
"allow_from": []
|
||||
},
|
||||
"qq": {
|
||||
"enabled": false,
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "BSA...",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Tham chiếu CLI
|
||||
|
||||
| Lệnh | Mô tả |
|
||||
| --- | --- |
|
||||
| `picoclaw onboard` | Khởi tạo cấu hình & workspace |
|
||||
| `picoclaw agent -m "..."` | Trò chuyện với agent |
|
||||
| `picoclaw agent` | Chế độ chat tương tác |
|
||||
| `picoclaw gateway` | Khởi động gateway (cho bot chat) |
|
||||
| `picoclaw status` | Hiển thị trạng thái |
|
||||
| `picoclaw cron list` | Liệt kê tất cả tác vụ định kỳ |
|
||||
| `picoclaw cron add ...` | Thêm tác vụ định kỳ |
|
||||
|
||||
### Tác vụ định kỳ / Nhắc nhở
|
||||
|
||||
PicoClaw hỗ trợ nhắc nhở theo lịch và tác vụ lặp lại thông qua công cụ `cron`:
|
||||
|
||||
* **Nhắc nhở một lần**: "Remind me in 10 minutes" (Nhắc tôi sau 10 phút) → kích hoạt một lần sau 10 phút
|
||||
* **Tác vụ lặp lại**: "Remind me every 2 hours" (Nhắc tôi mỗi 2 giờ) → kích hoạt mỗi 2 giờ
|
||||
* **Biểu thức Cron**: "Remind me at 9am daily" (Nhắc tôi lúc 9 giờ sáng mỗi ngày) → sử dụng biểu thức cron
|
||||
|
||||
Các tác vụ được lưu trong `~/.picoclaw/workspace/cron/` và được xử lý tự động.
|
||||
|
||||
## 🤝 Đóng góp & Lộ trình
|
||||
|
||||
Chào đón mọi PR! Mã nguồn được thiết kế nhỏ gọn và dễ đọc. 🤗
|
||||
|
||||
Lộ trình sắp được công bố...
|
||||
|
||||
Nhóm phát triển đang được xây dựng. Điều kiện tham gia: Ít nhất 1 PR đã được merge.
|
||||
|
||||
Nhóm người dùng:
|
||||
|
||||
Discord: <https://discord.gg/V4sAZ9XWpN>
|
||||
|
||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||
|
||||
## 🐛 Xử lý sự cố
|
||||
|
||||
### Tìm kiếm web hiện "API 配置问题"
|
||||
|
||||
Điều này là bình thường nếu bạn chưa cấu hình API key cho tìm kiếm. PicoClaw sẽ cung cấp các liên kết hữu ích để tìm kiếm thủ công.
|
||||
|
||||
Để bật tìm kiếm web:
|
||||
|
||||
1. **Tùy chọn 1 (Khuyên dùng)**: Lấy API key miễn phí tại [https://brave.com/search/api](https://brave.com/search/api) (2000 truy vấn miễn phí/tháng) để có kết quả tốt nhất.
|
||||
2. **Tùy chọn 2 (Không cần thẻ tín dụng)**: Nếu không có key, hệ thống tự động chuyển sang dùng **DuckDuckGo** (không cần key).
|
||||
|
||||
Thêm key vào `~/.picoclaw/config.json` nếu dùng Brave:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": true,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Gặp lỗi lọc nội dung (Content Filtering)
|
||||
|
||||
Một số nhà cung cấp (như Zhipu) có bộ lọc nội dung nghiêm ngặt. Thử diễn đạt lại câu hỏi hoặc sử dụng model khác.
|
||||
|
||||
### Telegram bot báo "Conflict: terminated by other getUpdates"
|
||||
|
||||
Điều này xảy ra khi có một instance bot khác đang chạy. Đảm bảo chỉ có một tiến trình `picoclaw gateway` chạy tại một thời điểm.
|
||||
|
||||
---
|
||||
|
||||
## 📝 So sánh API Key
|
||||
|
||||
| Dịch vụ | Gói miễn phí | Trường hợp sử dụng |
|
||||
| --- | --- | --- |
|
||||
| **OpenRouter** | 200K tokens/tháng | Đa model (Claude, GPT-4, v.v.) |
|
||||
| **Zhipu** | 200K tokens/tháng | Tốt nhất cho người dùng Trung Quốc |
|
||||
| **Brave Search** | 2000 truy vấn/tháng | Chức năng tìm kiếm web |
|
||||
| **Groq** | Có gói miễn phí | Suy luận siêu nhanh (Llama, Mixtral) |
|
||||
+28
-3
@@ -14,7 +14,7 @@
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
**中文** | [日本語](README.ja.md) | [English](README.md)
|
||||
**中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [English](README.md)
|
||||
</div>
|
||||
|
||||
---
|
||||
@@ -46,9 +46,11 @@
|
||||
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
|
||||
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
|
||||
> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
|
||||
> * **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
|
||||
|
||||
|
||||
## 📢 新闻 (News)
|
||||
2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与!
|
||||
|
||||
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
|
||||
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
|
||||
@@ -98,6 +100,23 @@
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 📱 在手机上轻松运行
|
||||
picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南:
|
||||
1. 先去应用商店下载安装Termux
|
||||
2. 打开后执行指令
|
||||
```bash
|
||||
# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本
|
||||
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
|
||||
chmod +x picoclaw-linux-arm64
|
||||
pkg install proot
|
||||
termux-chroot ./picoclaw-linux-arm64 onboard
|
||||
```
|
||||
然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
|
||||
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
|
||||
|
||||
|
||||
### 🐜 创新的低占用部署
|
||||
|
||||
PicoClaw 几乎可以部署在任何 Linux 设备上!
|
||||
@@ -217,6 +236,9 @@ picoclaw onboard
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -269,7 +291,7 @@ picoclaw agent -m "2+2 等于几?"
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -314,7 +336,7 @@ picoclaw gateway
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -625,6 +647,9 @@ picoclaw agent -m "你好"
|
||||
"search": {
|
||||
"api_key": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
+116
@@ -0,0 +1,116 @@
|
||||
|
||||
# 🦐 PicoClaw Roadmap
|
||||
|
||||
> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure.automate the mundane, unleash your creativity
|
||||
|
||||
---
|
||||
|
||||
## 🚀 1. Core Optimization: Extreme Lightweight
|
||||
|
||||
*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
|
||||
|
||||
* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
|
||||
* **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
|
||||
* **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
|
||||
* **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
|
||||
|
||||
|
||||
## 🛡️ 2. Security Hardening: Defense in Depth
|
||||
|
||||
*Paying off early technical debt. We invite security experts to help build a "Secure-by-Default" agent.*
|
||||
|
||||
* **Input Defense & Permission Control**
|
||||
* **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation.
|
||||
* **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries.
|
||||
* **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services).
|
||||
|
||||
|
||||
* **Sandboxing & Isolation**
|
||||
* **Filesystem Sandbox**: Restrict file R/W operations to specific directories only.
|
||||
* **Context Isolation**: Prevent data leakage between different user sessions or channels.
|
||||
* **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs.
|
||||
|
||||
|
||||
* **Authentication & Secrets**
|
||||
* **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage.
|
||||
* **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows.
|
||||
|
||||
|
||||
|
||||
## 🔌 3. Connectivity: Protocol-First Architecture
|
||||
|
||||
*Connect every model, reach every platform.*
|
||||
|
||||
* **Provider**
|
||||
* [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)*
|
||||
* **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference).
|
||||
* **Online Models**: Continued support for frontier closed-source models.
|
||||
|
||||
|
||||
* **Channel**
|
||||
* **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ...
|
||||
* **Standards**: Support for the **OneBot** protocol.
|
||||
* [**attachment**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments.
|
||||
|
||||
|
||||
* **Skill Marketplace**
|
||||
* [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries.
|
||||
|
||||
|
||||
|
||||
## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI
|
||||
|
||||
*Beyond conversation—focusing on action and collaboration.*
|
||||
|
||||
* **Operations**
|
||||
* [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**.
|
||||
* [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook.
|
||||
* [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop).
|
||||
|
||||
|
||||
* **Multi-Agent Collaboration**
|
||||
* [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294) implement
|
||||
* [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart).
|
||||
* [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network.
|
||||
* [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms.
|
||||
|
||||
|
||||
|
||||
## 📚 5. Developer Experience (DevEx) & Documentation
|
||||
|
||||
*Lowering the barrier to entry so anyone can deploy in minutes.*
|
||||
|
||||
* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350)
|
||||
* Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step.
|
||||
|
||||
|
||||
* **Comprehensive Documentation**
|
||||
* **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android.
|
||||
* **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels.
|
||||
* **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations).
|
||||
|
||||
|
||||
|
||||
## 🤖 6. Engineering: AI-Powered Open Source
|
||||
|
||||
*Born from Vibe Coding, we continue to use AI to accelerate development.*
|
||||
|
||||
* **AI-Enhanced CI/CD**
|
||||
* Integrate AI for automated Code Review, Linting, and PR Labeling.
|
||||
* **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean.
|
||||
* **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes.
|
||||
|
||||
|
||||
|
||||
## 🎨 7. Brand & Community
|
||||
|
||||
* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design!
|
||||
* *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes."
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### 🤝 Call for Contributions
|
||||
|
||||
We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together!
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 97 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 140 KiB After Width: | Height: | Size: 141 KiB |
+10
-3
@@ -562,7 +562,8 @@ func gatewayCmd() {
|
||||
})
|
||||
|
||||
// Setup cron tool and service
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace)
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout, cfg)
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
@@ -622,6 +623,12 @@ func gatewayCmd() {
|
||||
logger.InfoC("voice", "Groq transcription attached to Slack channel")
|
||||
}
|
||||
}
|
||||
if onebotChannel, ok := channelManager.GetChannel("onebot"); ok {
|
||||
if oc, ok := onebotChannel.(*channels.OneBotChannel); ok {
|
||||
oc.SetTranscriber(transcriber)
|
||||
logger.InfoC("voice", "Groq transcription attached to OneBot channel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enabledChannels := channelManager.GetEnabledChannels()
|
||||
@@ -987,14 +994,14 @@ func getConfigPath() string {
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
}
|
||||
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService {
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *cron.CronService {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict)
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, config)
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
|
||||
+13
-55
@@ -79,7 +79,8 @@
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
"api_base": "",
|
||||
"web_search": true
|
||||
},
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx",
|
||||
@@ -117,66 +118,23 @@
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "pplx-xxx",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_TOKEN"
|
||||
},
|
||||
"envFile": ".env"
|
||||
},
|
||||
"brave-search": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-brave-search"
|
||||
],
|
||||
"env": {
|
||||
"BRAVE_API_KEY": "YOUR_BRAVE_API_KEY"
|
||||
}
|
||||
},
|
||||
"postgres": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-postgres",
|
||||
"postgresql://user:password@localhost/dbname"
|
||||
]
|
||||
},
|
||||
"remote-http-example": {
|
||||
"enabled": false,
|
||||
"url": "https://mcp-server.example.com/stream",
|
||||
"type": "sse",
|
||||
"headers": {
|
||||
"Authorization": "Bearer YOUR_TOKEN",
|
||||
"X-Custom-Header": "custom-value"
|
||||
}
|
||||
}
|
||||
}
|
||||
"servers": {}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
|
||||
+4
-4
@@ -11,8 +11,8 @@ services:
|
||||
profiles:
|
||||
- agent
|
||||
volumes:
|
||||
- ./config/config.json:/root/.picoclaw/config.json:ro
|
||||
- picoclaw-workspace:/root/.picoclaw/workspace
|
||||
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
|
||||
- picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
|
||||
entrypoint: ["picoclaw", "agent"]
|
||||
stdin_open: true
|
||||
tty: true
|
||||
@@ -31,9 +31,9 @@ services:
|
||||
- gateway
|
||||
volumes:
|
||||
# Configuration file
|
||||
- ./config/config.json:/root/.picoclaw/config.json:ro
|
||||
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
|
||||
# Persistent workspace (sessions, memory, logs)
|
||||
- picoclaw-workspace:/root/.picoclaw/workspace
|
||||
- picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
|
||||
command: ["gateway"]
|
||||
|
||||
volumes:
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
|
||||
|
||||
**Hello, PicoClaw Community!**
|
||||
|
||||
First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, we’ve seen a non-stop stream of high-quality PRs!
|
||||
|
||||
PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
|
||||
|
||||
This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
|
||||
|
||||
### 🎁 Community Perks
|
||||
|
||||
To show our appreciation, developers who officially join our community operations will receive:
|
||||
|
||||
* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
|
||||
* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
|
||||
|
||||
### 🎥 Calling All Content Creators!
|
||||
|
||||
Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
|
||||
|
||||
* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
|
||||
* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
|
||||
We will be rewarding high-quality content creators with the same perks as our community developers!
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Urgent Volunteer Roles
|
||||
|
||||
We are looking for experts in the following areas:
|
||||
|
||||
1. **Issue/PR Reviewers**
|
||||
* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
|
||||
* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
|
||||
|
||||
|
||||
2. **Resource Optimization Experts**
|
||||
* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
|
||||
* **Focus:** Analyzing resource growth between releases and trimming redundancy.
|
||||
* **Priority:** **RAM usage optimization** > Binary size reduction.
|
||||
|
||||
|
||||
3. **Security Audit & Bug Fixes**
|
||||
* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
|
||||
* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
|
||||
|
||||
|
||||
4. **Documentation & DX (Developer Experience)**
|
||||
* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
|
||||
* **Focus:** Creating clear, user-friendly documentation for both setup and development.
|
||||
|
||||
|
||||
5. **AI-Powered CI/CD Optimization**
|
||||
* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
|
||||
* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
|
||||
|
||||
**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
|
||||
Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
|
||||
|
||||
---
|
||||
|
||||
## 📍 The Roadmap
|
||||
|
||||
Interested in a specific feature? You can "claim" these tasks and start building:
|
||||
|
||||
###
|
||||
* **Provider:**
|
||||
* **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
|
||||
* You can still submit code; Daming will merge it into the new implementation.
|
||||
* **Channels:**
|
||||
* Support for OneBot, additional platforms
|
||||
* attachments (images, audio, video, files).
|
||||
* **Skills:**
|
||||
* Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
|
||||
* **Operations:** * MCP Support.
|
||||
* Android operations (e.g., botdrop).
|
||||
* Browser automation via CDP or ActionBook.
|
||||
|
||||
|
||||
* **Multi-Agent Ecosystem:**
|
||||
* **Basic Model-Agnet** S
|
||||
* **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
|
||||
* **Swarm Mode.**
|
||||
* **AIEOS Integration.**
|
||||
|
||||
|
||||
* **Branding:**
|
||||
* **Logo**: We need a cute logo! We’re leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
|
||||
|
||||
|
||||
We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
|
||||
This list will be updated continuously as we progress.
|
||||
If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
|
||||
|
||||
---
|
||||
|
||||
## 🤝 How to Join
|
||||
|
||||
**Everything is open to your creativity!** If you have a wild idea, just PR it.
|
||||
|
||||
1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
|
||||
2. **The Application Track:** If you haven’t submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
|
||||
> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
|
||||
> Include the role you're interested in and any evidence of your development experience.
|
||||
|
||||
|
||||
|
||||
### Looking Ahead
|
||||
|
||||
Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
|
||||
|
||||
**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
|
||||
@@ -0,0 +1,122 @@
|
||||
# Tools Configuration
|
||||
|
||||
PicoClaw's tools configuration is located in the `tools` field of `config.json`.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": { ... },
|
||||
"exec": { ... },
|
||||
"approval": { ... },
|
||||
"cron": { ... }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Web Tools
|
||||
|
||||
Web tools are used for web search and fetching.
|
||||
|
||||
### Brave
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | false | Enable Brave search |
|
||||
| `api_key` | string | - | Brave Search API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
### DuckDuckGo
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | true | Enable DuckDuckGo search |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
### Perplexity
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | false | Enable Perplexity search |
|
||||
| `api_key` | string | - | Perplexity API key |
|
||||
| `max_results` | int | 5 | Maximum number of results |
|
||||
|
||||
## Exec Tool
|
||||
|
||||
The exec tool is used to execute shell commands.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enable_deny_patterns` | bool | true | Enable default dangerous command blocking |
|
||||
| `custom_deny_patterns` | array | [] | Custom deny patterns (regular expressions) |
|
||||
|
||||
### Functionality
|
||||
|
||||
- **`enable_deny_patterns`**: Set to `false` to completely disable the default dangerous command blocking patterns
|
||||
- **`custom_deny_patterns`**: Add custom deny regex patterns; commands matching these will be blocked
|
||||
|
||||
### Default Blocked Command Patterns
|
||||
|
||||
By default, PicoClaw blocks the following dangerous commands:
|
||||
|
||||
- Delete commands: `rm -rf`, `del /f/q`, `rmdir /s`
|
||||
- Disk operations: `format`, `mkfs`, `diskpart`, `dd if=`, writing to `/dev/sd*`
|
||||
- System operations: `shutdown`, `reboot`, `poweroff`
|
||||
- Command substitution: `$()`, `${}`, backticks
|
||||
- Pipe to shell: `| sh`, `| bash`
|
||||
- Privilege escalation: `sudo`, `chmod`, `chown`
|
||||
- Process control: `pkill`, `killall`, `kill -9`
|
||||
- Remote operations: `curl | sh`, `wget | sh`, `ssh`
|
||||
- Package management: `apt`, `yum`, `dnf`, `npm install -g`, `pip install --user`
|
||||
- Containers: `docker run`, `docker exec`
|
||||
- Git: `git push`, `git force`
|
||||
- Other: `eval`, `source *.sh`
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enable_deny_patterns": true,
|
||||
"custom_deny_patterns": [
|
||||
"\\brm\\s+-r\\b",
|
||||
"\\bkillall\\s+python"
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Approval Tool
|
||||
|
||||
The approval tool controls permissions for dangerous operations.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `enabled` | bool | true | Enable approval functionality |
|
||||
| `write_file` | bool | true | Require approval for file writes |
|
||||
| `edit_file` | bool | true | Require approval for file edits |
|
||||
| `append_file` | bool | true | Require approval for file appends |
|
||||
| `exec` | bool | true | Require approval for command execution |
|
||||
| `timeout_minutes` | int | 5 | Approval timeout in minutes |
|
||||
|
||||
## Cron Tool
|
||||
|
||||
The cron tool is used for scheduling periodic tasks.
|
||||
|
||||
| Config | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS_<SECTION>_<KEY>`:
|
||||
|
||||
For example:
|
||||
- `PICOCLAW_TOOLS_WEB_BRAVE_ENABLED=true`
|
||||
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
|
||||
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
|
||||
|
||||
Note: Array-type environment variables are not currently supported and must be set via the config file.
|
||||
@@ -0,0 +1,145 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// AgentInstance represents a fully configured agent with its own workspace,
|
||||
// session manager, context builder, and tool registry.
|
||||
type AgentInstance struct {
|
||||
ID string
|
||||
Name string
|
||||
Model string
|
||||
Fallbacks []string
|
||||
Workspace string
|
||||
MaxIterations int
|
||||
ContextWindow int
|
||||
Provider providers.LLMProvider
|
||||
Sessions *session.SessionManager
|
||||
ContextBuilder *ContextBuilder
|
||||
Tools *tools.ToolRegistry
|
||||
Subagents *config.SubagentsConfig
|
||||
SkillsFilter []string
|
||||
Candidates []providers.FallbackCandidate
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
func NewAgentInstance(
|
||||
agentCfg *config.AgentConfig,
|
||||
defaults *config.AgentDefaults,
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentInstance {
|
||||
workspace := resolveAgentWorkspace(agentCfg, defaults)
|
||||
os.MkdirAll(workspace, 0755)
|
||||
|
||||
model := resolveAgentModel(agentCfg, defaults)
|
||||
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
|
||||
|
||||
restrict := defaults.RestrictToWorkspace
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewExecToolWithConfig(workspace, restrict, cfg))
|
||||
toolsRegistry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
sessionsDir := filepath.Join(workspace, "sessions")
|
||||
sessionsManager := session.NewSessionManager(sessionsDir)
|
||||
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||
|
||||
agentID := routing.DefaultAgentID
|
||||
agentName := ""
|
||||
var subagents *config.SubagentsConfig
|
||||
var skillsFilter []string
|
||||
|
||||
if agentCfg != nil {
|
||||
agentID = routing.NormalizeAgentID(agentCfg.ID)
|
||||
agentName = agentCfg.Name
|
||||
subagents = agentCfg.Subagents
|
||||
skillsFilter = agentCfg.Skills
|
||||
}
|
||||
|
||||
maxIter := defaults.MaxToolIterations
|
||||
if maxIter == 0 {
|
||||
maxIter = 20
|
||||
}
|
||||
|
||||
// Resolve fallback candidates
|
||||
modelCfg := providers.ModelConfig{
|
||||
Primary: model,
|
||||
Fallbacks: fallbacks,
|
||||
}
|
||||
candidates := providers.ResolveCandidates(modelCfg, defaults.Provider)
|
||||
|
||||
return &AgentInstance{
|
||||
ID: agentID,
|
||||
Name: agentName,
|
||||
Model: model,
|
||||
Fallbacks: fallbacks,
|
||||
Workspace: workspace,
|
||||
MaxIterations: maxIter,
|
||||
ContextWindow: defaults.MaxTokens,
|
||||
Provider: provider,
|
||||
Sessions: sessionsManager,
|
||||
ContextBuilder: contextBuilder,
|
||||
Tools: toolsRegistry,
|
||||
Subagents: subagents,
|
||||
SkillsFilter: skillsFilter,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAgentWorkspace determines the workspace directory for an agent.
|
||||
func resolveAgentWorkspace(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
|
||||
if agentCfg != nil && strings.TrimSpace(agentCfg.Workspace) != "" {
|
||||
return expandHome(strings.TrimSpace(agentCfg.Workspace))
|
||||
}
|
||||
if agentCfg == nil || agentCfg.Default || agentCfg.ID == "" || routing.NormalizeAgentID(agentCfg.ID) == "main" {
|
||||
return expandHome(defaults.Workspace)
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
id := routing.NormalizeAgentID(agentCfg.ID)
|
||||
return filepath.Join(home, ".picoclaw", "workspace-"+id)
|
||||
}
|
||||
|
||||
// resolveAgentModel resolves the primary model for an agent.
|
||||
func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) string {
|
||||
if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" {
|
||||
return strings.TrimSpace(agentCfg.Model.Primary)
|
||||
}
|
||||
return defaults.Model
|
||||
}
|
||||
|
||||
// resolveAgentFallbacks resolves the fallback models for an agent.
|
||||
func resolveAgentFallbacks(agentCfg *config.AgentConfig, defaults *config.AgentDefaults) []string {
|
||||
if agentCfg != nil && agentCfg.Model != nil && agentCfg.Model.Fallbacks != nil {
|
||||
return agentCfg.Model.Fallbacks
|
||||
}
|
||||
return defaults.ModelFallbacks
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
}
|
||||
if path[0] == '~' {
|
||||
home, _ := os.UserHomeDir()
|
||||
if len(path) > 1 && path[1] == '/' {
|
||||
return home + path[1:]
|
||||
}
|
||||
return home
|
||||
}
|
||||
return path
|
||||
}
|
||||
+335
-341
File diff suppressed because it is too large
Load Diff
@@ -594,7 +594,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
}
|
||||
al.sessions.SetHistory(sessionKey, history)
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
@@ -614,7 +618,7 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check final history length
|
||||
finalHistory := al.sessions.GetHistory(sessionKey)
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
// AgentRegistry manages multiple agent instances and routes messages to them.
|
||||
type AgentRegistry struct {
|
||||
agents map[string]*AgentInstance
|
||||
resolver *routing.RouteResolver
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAgentRegistry creates a registry from config, instantiating all agents.
|
||||
func NewAgentRegistry(
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentRegistry {
|
||||
registry := &AgentRegistry{
|
||||
agents: make(map[string]*AgentInstance),
|
||||
resolver: routing.NewRouteResolver(cfg),
|
||||
}
|
||||
|
||||
agentConfigs := cfg.Agents.List
|
||||
if len(agentConfigs) == 0 {
|
||||
implicitAgent := &config.AgentConfig{
|
||||
ID: "main",
|
||||
Default: true,
|
||||
}
|
||||
instance := NewAgentInstance(implicitAgent, &cfg.Agents.Defaults, cfg, provider)
|
||||
registry.agents["main"] = instance
|
||||
logger.InfoCF("agent", "Created implicit main agent (no agents.list configured)", nil)
|
||||
} else {
|
||||
for i := range agentConfigs {
|
||||
ac := &agentConfigs[i]
|
||||
id := routing.NormalizeAgentID(ac.ID)
|
||||
instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider)
|
||||
registry.agents[id] = instance
|
||||
logger.InfoCF("agent", "Registered agent",
|
||||
map[string]interface{}{
|
||||
"agent_id": id,
|
||||
"name": ac.Name,
|
||||
"workspace": instance.Workspace,
|
||||
"model": instance.Model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
// GetAgent returns the agent instance for a given ID.
|
||||
func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
id := routing.NormalizeAgentID(agentID)
|
||||
agent, ok := r.agents[id]
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message.
|
||||
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(input)
|
||||
}
|
||||
|
||||
// ListAgentIDs returns all registered agent IDs.
|
||||
func (r *AgentRegistry) ListAgentIDs() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
ids := make([]string, 0, len(r.agents))
|
||||
for id := range r.agents {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// CanSpawnSubagent checks if parentAgentID is allowed to spawn targetAgentID.
|
||||
func (r *AgentRegistry) CanSpawnSubagent(parentAgentID, targetAgentID string) bool {
|
||||
parent, ok := r.GetAgent(parentAgentID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if parent.Subagents == nil || parent.Subagents.AllowAgents == nil {
|
||||
return false
|
||||
}
|
||||
targetNorm := routing.NormalizeAgentID(targetAgentID)
|
||||
for _, allowed := range parent.Subagents.AllowAgents {
|
||||
if allowed == "*" {
|
||||
return true
|
||||
}
|
||||
if routing.NormalizeAgentID(allowed) == targetNorm {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetDefaultAgent returns the default agent instance.
|
||||
func (r *AgentRegistry) GetDefaultAgent() *AgentInstance {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if agent, ok := r.agents["main"]; ok {
|
||||
return agent
|
||||
}
|
||||
for _, agent := range r.agents {
|
||||
return agent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type mockRegistryProvider struct{}
|
||||
|
||||
func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
func (m *mockRegistryProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func testCfg(agents []config.AgentConfig) *config.Config {
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: "/tmp/picoclaw-test-registry",
|
||||
Model: "gpt-4",
|
||||
MaxTokens: 8192,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
List: agents,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentRegistry_ImplicitMain(t *testing.T) {
|
||||
cfg := testCfg(nil)
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
ids := registry.ListAgentIDs()
|
||||
if len(ids) != 1 || ids[0] != "main" {
|
||||
t.Errorf("expected implicit main agent, got %v", ids)
|
||||
}
|
||||
|
||||
agent, ok := registry.GetAgent("main")
|
||||
if !ok || agent == nil {
|
||||
t.Fatal("expected to find 'main' agent")
|
||||
}
|
||||
if agent.ID != "main" {
|
||||
t.Errorf("agent.ID = %q, want 'main'", agent.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentRegistry_ExplicitAgents(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "sales", Default: true, Name: "Sales Bot"},
|
||||
{ID: "support", Name: "Support Bot"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
ids := registry.ListAgentIDs()
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected 2 agents, got %d: %v", len(ids), ids)
|
||||
}
|
||||
|
||||
sales, ok := registry.GetAgent("sales")
|
||||
if !ok || sales == nil {
|
||||
t.Fatal("expected to find 'sales' agent")
|
||||
}
|
||||
if sales.Name != "Sales Bot" {
|
||||
t.Errorf("sales.Name = %q, want 'Sales Bot'", sales.Name)
|
||||
}
|
||||
|
||||
support, ok := registry.GetAgent("support")
|
||||
if !ok || support == nil {
|
||||
t.Fatal("expected to find 'support' agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_GetAgent_Normalize(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "my-agent", Default: true},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, ok := registry.GetAgent("My-Agent")
|
||||
if !ok || agent == nil {
|
||||
t.Fatal("expected to find agent with normalized ID")
|
||||
}
|
||||
if agent.ID != "my-agent" {
|
||||
t.Errorf("agent.ID = %q, want 'my-agent'", agent.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_GetDefaultAgent(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta", Default: true},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
// GetDefaultAgent first checks for "main", then returns any
|
||||
agent := registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected a default agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_CanSpawnSubagent(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{
|
||||
ID: "parent",
|
||||
Default: true,
|
||||
Subagents: &config.SubagentsConfig{
|
||||
AllowAgents: []string{"child1", "child2"},
|
||||
},
|
||||
},
|
||||
{ID: "child1"},
|
||||
{ID: "child2"},
|
||||
{ID: "restricted"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
if !registry.CanSpawnSubagent("parent", "child1") {
|
||||
t.Error("expected parent to be allowed to spawn child1")
|
||||
}
|
||||
if !registry.CanSpawnSubagent("parent", "child2") {
|
||||
t.Error("expected parent to be allowed to spawn child2")
|
||||
}
|
||||
if registry.CanSpawnSubagent("parent", "restricted") {
|
||||
t.Error("expected parent to NOT be allowed to spawn restricted")
|
||||
}
|
||||
if registry.CanSpawnSubagent("child1", "child2") {
|
||||
t.Error("expected child1 to NOT be allowed to spawn (no subagents config)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentRegistry_CanSpawnSubagent_Wildcard(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{
|
||||
ID: "admin",
|
||||
Default: true,
|
||||
Subagents: &config.SubagentsConfig{
|
||||
AllowAgents: []string{"*"},
|
||||
},
|
||||
},
|
||||
{ID: "any-agent"},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
if !registry.CanSpawnSubagent("admin", "any-agent") {
|
||||
t.Error("expected wildcard to allow spawning any agent")
|
||||
}
|
||||
if !registry.CanSpawnSubagent("admin", "nonexistent") {
|
||||
t.Error("expected wildcard to allow spawning even nonexistent agents")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_Model(t *testing.T) {
|
||||
model := &config.AgentModelConfig{Primary: "claude-opus"}
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "custom", Default: true, Model: model},
|
||||
})
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("custom")
|
||||
if agent.Model != "claude-opus" {
|
||||
t.Errorf("agent.Model = %q, want 'claude-opus'", agent.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_FallbackInheritance(t *testing.T) {
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "inherit", Default: true},
|
||||
})
|
||||
cfg.Agents.Defaults.ModelFallbacks = []string{"openai/gpt-4o-mini", "anthropic/haiku"}
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("inherit")
|
||||
if len(agent.Fallbacks) != 2 {
|
||||
t.Errorf("expected 2 fallbacks inherited from defaults, got %d", len(agent.Fallbacks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentInstance_FallbackExplicitEmpty(t *testing.T) {
|
||||
model := &config.AgentModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: []string{}, // explicitly empty = disable
|
||||
}
|
||||
cfg := testCfg([]config.AgentConfig{
|
||||
{ID: "no-fallback", Default: true, Model: model},
|
||||
})
|
||||
cfg.Agents.Defaults.ModelFallbacks = []string{"should-not-inherit"}
|
||||
registry := NewAgentRegistry(cfg, &mockRegistryProvider{})
|
||||
|
||||
agent, _ := registry.GetAgent("no-fallback")
|
||||
if len(agent.Fallbacks) != 0 {
|
||||
t.Errorf("expected 0 fallbacks (explicit empty), got %d: %v", len(agent.Fallbacks), agent.Fallbacks)
|
||||
}
|
||||
}
|
||||
+6
-11
@@ -2,7 +2,6 @@ package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -87,17 +86,13 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
|
||||
return
|
||||
}
|
||||
|
||||
// Build session key: channel:chatID
|
||||
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: senderID,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
SessionKey: sessionKey,
|
||||
Metadata: metadata,
|
||||
Channel: c.name,
|
||||
SenderID: senderID,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
Media: media,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
c.bus.PublishInbound(msg)
|
||||
|
||||
+10
-128
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
@@ -106,7 +105,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
|
||||
chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
|
||||
@@ -117,132 +116,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitMessage splits long messages into chunks, preserving code block integrity
|
||||
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
|
||||
func splitMessage(content string, limit int) []string {
|
||||
var messages []string
|
||||
|
||||
for len(content) > 0 {
|
||||
if len(content) <= limit {
|
||||
messages = append(messages, content)
|
||||
break
|
||||
}
|
||||
|
||||
msgEnd := limit
|
||||
|
||||
// Find natural split point within the limit
|
||||
msgEnd = findLastNewline(content[:limit], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:limit], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
// Check if this would end with an incomplete code block
|
||||
candidate := content[:msgEnd]
|
||||
unclosedIdx := findLastUnclosedCodeBlock(candidate)
|
||||
|
||||
if unclosedIdx >= 0 {
|
||||
// Message would end with incomplete code block
|
||||
// Try to extend to include the closing ``` (with some buffer)
|
||||
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
|
||||
if len(content) > extendedLimit {
|
||||
closingIdx := findNextClosingCodeBlock(content, msgEnd)
|
||||
if closingIdx > 0 && closingIdx <= extendedLimit {
|
||||
// Extend to include the closing ```
|
||||
msgEnd = closingIdx
|
||||
} else {
|
||||
// Can't find closing, split before the code block
|
||||
msgEnd = findLastNewline(content[:unclosedIdx], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:unclosedIdx], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = unclosedIdx
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Remaining content fits within extended limit
|
||||
msgEnd = len(content)
|
||||
}
|
||||
}
|
||||
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
messages = append(messages, content[:msgEnd])
|
||||
content = strings.TrimSpace(content[msgEnd:])
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
|
||||
// Returns the position of the opening ``` or -1 if all code blocks are complete
|
||||
func findLastUnclosedCodeBlock(text string) int {
|
||||
count := 0
|
||||
lastOpenIdx := -1
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
if count == 0 {
|
||||
lastOpenIdx = i
|
||||
}
|
||||
count++
|
||||
i += 2
|
||||
}
|
||||
}
|
||||
|
||||
// If odd number of ``` markers, last one is unclosed
|
||||
if count%2 == 1 {
|
||||
return lastOpenIdx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findNextClosingCodeBlock finds the next closing ``` starting from a position
|
||||
// Returns the position after the closing ``` or -1 if not found
|
||||
func findNextClosingCodeBlock(text string, startIdx int) int {
|
||||
for i := startIdx; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
return i + 3
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastNewline finds the last newline character within the last N characters
|
||||
// Returns the position of the newline or -1 if not found
|
||||
func findLastNewline(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == '\n' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastSpace finds the last space character within the last N characters
|
||||
// Returns the position of the space or -1 if not found
|
||||
func findLastSpace(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == ' ' || s[i] == '\t' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
|
||||
// 使用传入的 ctx 进行超时控制
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
@@ -376,6 +249,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": m.ID,
|
||||
"user_id": senderID,
|
||||
@@ -384,6 +264,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
"guild_id": m.GuildID,
|
||||
"channel_id": m.ChannelID,
|
||||
"is_dm": fmt.Sprintf("%t", m.GuildID == ""),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||
|
||||
@@ -18,7 +18,6 @@ type MaixCamChannel struct {
|
||||
listener net.Listener
|
||||
clients map[net.Conn]bool
|
||||
clientsMux sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type MaixCamMessage struct {
|
||||
@@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
clients: make(map[net.Conn]bool),
|
||||
running: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+498
-209
@@ -4,9 +4,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -14,20 +16,28 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type OneBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
selfID int64
|
||||
pending map[string]chan json.RawMessage
|
||||
pendingMu sync.Mutex
|
||||
transcriber *voice.GroqTranscriber
|
||||
lastMessageID sync.Map
|
||||
pendingEmojiMsg sync.Map
|
||||
}
|
||||
|
||||
type oneBotRawEvent struct {
|
||||
@@ -43,9 +53,11 @@ type oneBotRawEvent struct {
|
||||
SelfID json.RawMessage `json:"self_id"`
|
||||
Time json.RawMessage `json:"time"`
|
||||
MetaEventType string `json:"meta_event_type"`
|
||||
NoticeType string `json:"notice_type"`
|
||||
Echo string `json:"echo"`
|
||||
RetCode json.RawMessage `json:"retcode"`
|
||||
Status BotStatus `json:"status"`
|
||||
Status json.RawMessage `json:"status"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type BotStatus struct {
|
||||
@@ -53,42 +65,36 @@ type BotStatus struct {
|
||||
Good bool `json:"good"`
|
||||
}
|
||||
|
||||
func isAPIResponse(raw json.RawMessage) bool {
|
||||
if len(raw) == 0 {
|
||||
return false
|
||||
}
|
||||
var s string
|
||||
if json.Unmarshal(raw, &s) == nil {
|
||||
return s == "ok" || s == "failed"
|
||||
}
|
||||
var bs BotStatus
|
||||
if json.Unmarshal(raw, &bs) == nil {
|
||||
return bs.Online || bs.Good
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type oneBotSender struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Card string `json:"card"`
|
||||
}
|
||||
|
||||
type oneBotEvent struct {
|
||||
PostType string
|
||||
MessageType string
|
||||
SubType string
|
||||
MessageID string
|
||||
UserID int64
|
||||
GroupID int64
|
||||
Content string
|
||||
RawContent string
|
||||
IsBotMentioned bool
|
||||
Sender oneBotSender
|
||||
SelfID int64
|
||||
Time int64
|
||||
MetaEventType string
|
||||
}
|
||||
|
||||
type oneBotAPIRequest struct {
|
||||
Action string `json:"action"`
|
||||
Params interface{} `json:"params"`
|
||||
Echo string `json:"echo,omitempty"`
|
||||
}
|
||||
|
||||
type oneBotSendPrivateMsgParams struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type oneBotSendGroupMsgParams struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
Message string `json:"message"`
|
||||
type oneBotMessageSegment struct {
|
||||
Type string `json:"type"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
@@ -101,9 +107,30 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One
|
||||
dedup: make(map[string]struct{}, dedupSize),
|
||||
dedupRing: make([]string, dedupSize),
|
||||
dedupIdx: 0,
|
||||
pending: make(map[string]chan json.RawMessage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
|
||||
go func() {
|
||||
_, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{
|
||||
"message_id": messageID,
|
||||
"emoji_id": emojiID,
|
||||
"set": set,
|
||||
}, 5*time.Second)
|
||||
if err != nil {
|
||||
logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{
|
||||
"message_id": messageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
if c.config.WSUrl == "" {
|
||||
return fmt.Errorf("OneBot ws_url not configured")
|
||||
@@ -121,12 +148,12 @@ func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
c.fetchSelfID()
|
||||
}
|
||||
|
||||
if c.config.ReconnectInterval > 0 {
|
||||
go c.reconnectLoop()
|
||||
} else {
|
||||
// If reconnect is disabled but initial connection failed, we cannot recover
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
|
||||
}
|
||||
@@ -152,14 +179,141 @@ func (c *OneBotChannel) connect() error {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetPongHandler(func(appData string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.pinger(conn)
|
||||
|
||||
logger.InfoC("onebot", "WebSocket connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) pinger(conn *websocket.Conn) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.writeMu.Lock()
|
||||
err := conn.WriteMessage(websocket.PingMessage, nil)
|
||||
c.writeMu.Unlock()
|
||||
if err != nil {
|
||||
logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) fetchSelfID() {
|
||||
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type loginInfo struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
}
|
||||
for _, extract := range []func() (*loginInfo, error){
|
||||
func() (*loginInfo, error) {
|
||||
var w struct {
|
||||
Data loginInfo `json:"data"`
|
||||
}
|
||||
err := json.Unmarshal(resp, &w)
|
||||
return &w.Data, err
|
||||
},
|
||||
func() (*loginInfo, error) {
|
||||
var f loginInfo
|
||||
err := json.Unmarshal(resp, &f)
|
||||
return &f, err
|
||||
},
|
||||
} {
|
||||
info, err := extract()
|
||||
if err != nil || len(info.UserID) == 0 {
|
||||
continue
|
||||
}
|
||||
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
|
||||
atomic.StoreInt64(&c.selfID, uid)
|
||||
logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{
|
||||
"self_id": uid,
|
||||
"nickname": info.Nickname,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{
|
||||
"response": string(resp),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("WebSocket not connected")
|
||||
}
|
||||
|
||||
echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1))
|
||||
|
||||
ch := make(chan json.RawMessage, 1)
|
||||
c.pendingMu.Lock()
|
||||
c.pending[echo] = ch
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, echo)
|
||||
c.pendingMu.Unlock()
|
||||
}()
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal API request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write API request: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
return resp, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
|
||||
case <-c.ctx.Done():
|
||||
return nil, fmt.Errorf("context cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
@@ -183,6 +337,7 @@ func (c *OneBotChannel) reconnectLoop() {
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
c.fetchSelfID()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,6 +352,13 @@ func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.pendingMu.Lock()
|
||||
for echo, ch := range c.pending {
|
||||
close(ch)
|
||||
delete(c.pending, echo)
|
||||
}
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
@@ -225,10 +387,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
|
||||
return err
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
c.echoCounter++
|
||||
echo := fmt.Sprintf("send_%d", c.echoCounter)
|
||||
c.writeMu.Unlock()
|
||||
echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1))
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
@@ -252,67 +411,78 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
|
||||
return err
|
||||
}
|
||||
|
||||
if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok {
|
||||
if mid, ok := msgID.(string); ok && mid != "" {
|
||||
c.setMsgEmojiLike(mid, 289, false)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment {
|
||||
var segments []oneBotMessageSegment
|
||||
|
||||
if lastMsgID, ok := c.lastMessageID.Load(chatID); ok {
|
||||
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
|
||||
segments = append(segments, oneBotMessageSegment{
|
||||
Type: "reply",
|
||||
Data: map[string]interface{}{"id": msgID},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
segments = append(segments, oneBotMessageSegment{
|
||||
Type: "text",
|
||||
Data: map[string]interface{}{"text": content},
|
||||
})
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
|
||||
chatID := msg.ChatID
|
||||
segments := c.buildMessageSegments(chatID, msg.Content)
|
||||
|
||||
if len(chatID) > 6 && chatID[:6] == "group:" {
|
||||
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_group_msg", oneBotSendGroupMsgParams{
|
||||
GroupID: groupID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
var action, idKey string
|
||||
var rawID string
|
||||
if rest, ok := strings.CutPrefix(chatID, "group:"); ok {
|
||||
action, idKey, rawID = "send_group_msg", "group_id", rest
|
||||
} else if rest, ok := strings.CutPrefix(chatID, "private:"); ok {
|
||||
action, idKey, rawID = "send_private_msg", "user_id", rest
|
||||
} else {
|
||||
action, idKey, rawID = "send_private_msg", "user_id", chatID
|
||||
}
|
||||
|
||||
if len(chatID) > 8 && chatID[:8] == "private:" {
|
||||
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(chatID, 10, 64)
|
||||
id, err := strconv.ParseInt(rawID, 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
|
||||
return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID)
|
||||
}
|
||||
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
return action, map[string]interface{}{idKey: id, "message": segments}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) listen() {
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
if c.conn == conn {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
@@ -320,10 +490,7 @@ func (c *OneBotChannel) listen() {
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
|
||||
"length": len(message),
|
||||
"payload": string(message),
|
||||
})
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
var raw oneBotRawEvent
|
||||
if err := json.Unmarshal(message, &raw); err != nil {
|
||||
@@ -334,20 +501,37 @@ func (c *OneBotChannel) listen() {
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
|
||||
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{
|
||||
"length": len(message),
|
||||
"post_type": raw.PostType,
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
|
||||
if raw.Echo != "" {
|
||||
c.pendingMu.Lock()
|
||||
ch, ok := c.pending[raw.Echo]
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
if ok {
|
||||
select {
|
||||
case ch <- message:
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": string(raw.Status),
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
"message_type": raw.MessageType,
|
||||
"sub_type": raw.SubType,
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
if isAPIResponse(raw.Status) {
|
||||
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{
|
||||
"status": string(raw.Status),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
c.handleRawEvent(&raw)
|
||||
}
|
||||
@@ -386,9 +570,12 @@ func parseJSONString(raw json.RawMessage) string {
|
||||
type parseMessageResult struct {
|
||||
Text string
|
||||
IsBotMentioned bool
|
||||
Media []string
|
||||
LocalFiles []string
|
||||
ReplyTo string
|
||||
}
|
||||
|
||||
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
if len(raw) == 0 {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
@@ -408,60 +595,155 @@ func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult
|
||||
}
|
||||
|
||||
var segments []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &segments); err == nil {
|
||||
var text string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
if err := json.Unmarshal(raw, &segments); err != nil {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var textParts []string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
var media []string
|
||||
var localFiles []string
|
||||
var replyTo string
|
||||
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
}
|
||||
|
||||
case "image", "video", "file":
|
||||
if data != nil {
|
||||
url, _ := data["url"].(string)
|
||||
if url != "" {
|
||||
defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"}
|
||||
filename := defaults[segType]
|
||||
if f, ok := data["file"].(string); ok && f != "" {
|
||||
filename = f
|
||||
} else if n, ok := data["name"].(string); ok && n != "" {
|
||||
filename = n
|
||||
}
|
||||
localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "onebot",
|
||||
})
|
||||
if localPath != "" {
|
||||
media = append(media, localPath)
|
||||
localFiles = append(localFiles, localPath)
|
||||
textParts = append(textParts, fmt.Sprintf("[%s]", segType))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "record":
|
||||
if data != nil {
|
||||
url, _ := data["url"].(string)
|
||||
if url != "" {
|
||||
localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{
|
||||
LoggerPrefix: "onebot",
|
||||
})
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
result, err := c.transcriber.Transcribe(tctx, localPath)
|
||||
tcancel()
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
textParts = append(textParts, "[voice (transcription failed)]")
|
||||
media = append(media, localPath)
|
||||
} else {
|
||||
textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text))
|
||||
}
|
||||
} else {
|
||||
textParts = append(textParts, "[voice]")
|
||||
media = append(media, localPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "reply":
|
||||
if data != nil {
|
||||
if id, ok := data["id"]; ok {
|
||||
replyTo = fmt.Sprintf("%v", id)
|
||||
}
|
||||
}
|
||||
|
||||
case "face":
|
||||
if data != nil {
|
||||
faceID, _ := data["id"]
|
||||
textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID))
|
||||
}
|
||||
|
||||
case "forward":
|
||||
textParts = append(textParts, "[forward message]")
|
||||
|
||||
default:
|
||||
|
||||
}
|
||||
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
|
||||
}
|
||||
return parseMessageResult{}
|
||||
|
||||
return parseMessageResult{
|
||||
Text: strings.TrimSpace(strings.Join(textParts, "")),
|
||||
IsBotMentioned: mentioned,
|
||||
Media: media,
|
||||
LocalFiles: localFiles,
|
||||
ReplyTo: replyTo,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
switch raw.PostType {
|
||||
case "message":
|
||||
evt, err := c.normalizeMessageEvent(raw)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
|
||||
if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
|
||||
logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.handleMessage(evt)
|
||||
c.handleMessage(raw)
|
||||
|
||||
case "message_sent":
|
||||
logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{
|
||||
"message_type": raw.MessageType,
|
||||
"message_id": parseJSONString(raw.MessageID),
|
||||
})
|
||||
|
||||
case "meta_event":
|
||||
c.handleMetaEvent(raw)
|
||||
|
||||
case "notice":
|
||||
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
c.handleNoticeEvent(raw)
|
||||
|
||||
case "request":
|
||||
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
|
||||
case "":
|
||||
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
@@ -469,18 +751,51 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
if raw.MetaEventType == "lifecycle" {
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType})
|
||||
} else if raw.MetaEventType != "heartbeat" {
|
||||
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
|
||||
fields := map[string]interface{}{
|
||||
"notice_type": raw.NoticeType,
|
||||
"sub_type": raw.SubType,
|
||||
"group_id": parseJSONString(raw.GroupID),
|
||||
"user_id": parseJSONString(raw.UserID),
|
||||
"message_id": parseJSONString(raw.MessageID),
|
||||
}
|
||||
switch raw.NoticeType {
|
||||
case "group_recall", "group_increase", "group_decrease",
|
||||
"friend_add", "group_admin", "group_ban":
|
||||
logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields)
|
||||
default:
|
||||
logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
// Parse fields from raw event
|
||||
userID, err := parseJSONInt64(raw.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
|
||||
logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"raw": string(raw.UserID),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
groupID, _ := parseJSONInt64(raw.GroupID)
|
||||
selfID, _ := parseJSONInt64(raw.SelfID)
|
||||
ts, _ := parseJSONInt64(raw.Time)
|
||||
messageID := parseJSONString(raw.MessageID)
|
||||
|
||||
parsed := parseMessageContentEx(raw.Message, selfID)
|
||||
if selfID == 0 {
|
||||
selfID = atomic.LoadInt64(&c.selfID)
|
||||
}
|
||||
|
||||
parsed := c.parseMessageSegments(raw.Message, selfID)
|
||||
isBotMentioned := parsed.IsBotMentioned
|
||||
|
||||
content := raw.RawMessage
|
||||
@@ -495,6 +810,10 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") {
|
||||
content = parsed.Text
|
||||
}
|
||||
|
||||
var sender oneBotSender
|
||||
if len(raw.Sender) > 0 {
|
||||
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
|
||||
@@ -505,137 +824,107 @@ func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent
|
||||
}
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
|
||||
"message_type": raw.MessageType,
|
||||
"user_id": userID,
|
||||
"group_id": groupID,
|
||||
"message_id": messageID,
|
||||
"content_len": len(content),
|
||||
"nickname": sender.Nickname,
|
||||
})
|
||||
|
||||
return &oneBotEvent{
|
||||
PostType: raw.PostType,
|
||||
MessageType: raw.MessageType,
|
||||
SubType: raw.SubType,
|
||||
MessageID: messageID,
|
||||
UserID: userID,
|
||||
GroupID: groupID,
|
||||
Content: content,
|
||||
RawContent: raw.RawMessage,
|
||||
IsBotMentioned: isBotMentioned,
|
||||
Sender: sender,
|
||||
SelfID: selfID,
|
||||
Time: ts,
|
||||
MetaEventType: raw.MetaEventType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
switch raw.MetaEventType {
|
||||
case "lifecycle":
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "heartbeat":
|
||||
logger.DebugC("onebot", "Heartbeat received")
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
// Clean up temp files when done
|
||||
if len(parsed.LocalFiles) > 0 {
|
||||
defer func() {
|
||||
for _, f := range parsed.LocalFiles {
|
||||
if err := os.Remove(f); err != nil {
|
||||
logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{
|
||||
"path": f,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
|
||||
if c.isDuplicate(evt.MessageID) {
|
||||
if c.isDuplicate(messageID) {
|
||||
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
"message_id": messageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
content := evt.Content
|
||||
if content == "" {
|
||||
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
"message_id": messageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := strconv.FormatInt(evt.UserID, 10)
|
||||
senderID := strconv.FormatInt(userID, 10)
|
||||
var chatID string
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": evt.MessageID,
|
||||
"message_id": messageID,
|
||||
}
|
||||
|
||||
switch evt.MessageType {
|
||||
if parsed.ReplyTo != "" {
|
||||
metadata["reply_to_message_id"] = parsed.ReplyTo
|
||||
}
|
||||
|
||||
switch raw.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"message_id": evt.MessageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
|
||||
groupIDStr := strconv.FormatInt(groupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
|
||||
senderUserID, _ := parseJSONInt64(sender.UserID)
|
||||
if senderUserID > 0 {
|
||||
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
|
||||
}
|
||||
|
||||
if evt.Sender.Card != "" {
|
||||
metadata["sender_name"] = evt.Sender.Card
|
||||
} else if evt.Sender.Nickname != "" {
|
||||
metadata["sender_name"] = evt.Sender.Nickname
|
||||
if sender.Card != "" {
|
||||
metadata["sender_name"] = sender.Card
|
||||
} else if sender.Nickname != "" {
|
||||
metadata["sender_name"] = sender.Nickname
|
||||
}
|
||||
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
|
||||
if !triggered {
|
||||
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"is_mentioned": isBotMentioned,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
return
|
||||
}
|
||||
content = strippedContent
|
||||
|
||||
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"message_id": evt.MessageID,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
default:
|
||||
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
|
||||
"type": evt.MessageType,
|
||||
"message_id": evt.MessageID,
|
||||
"user_id": evt.UserID,
|
||||
"type": raw.MessageType,
|
||||
"message_id": messageID,
|
||||
"user_id": userID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Sender.Nickname != "" {
|
||||
metadata["nickname"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"content": truncate(content, 100),
|
||||
logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
"media_count": len(parsed.Media),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
|
||||
if sender.Nickname != "" {
|
||||
metadata["nickname"] = sender.Nickname
|
||||
}
|
||||
|
||||
c.lastMessageID.Store(chatID, messageID)
|
||||
|
||||
if raw.MessageType == "group" && messageID != "" && messageID != "0" {
|
||||
c.setMsgEmojiLike(messageID, 289, true)
|
||||
c.pendingEmojiMsg.Store(chatID, messageID)
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, parsed.Media, metadata)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
|
||||
@@ -25,6 +25,7 @@ type SlackChannel struct {
|
||||
api *slack.Client
|
||||
socketClient *socketmode.Client
|
||||
botUserID string
|
||||
teamID string
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -72,6 +73,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("slack auth test failed: %w", err)
|
||||
}
|
||||
c.botUserID = authResp.UserID
|
||||
c.teamID = authResp.TeamID
|
||||
|
||||
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
|
||||
"bot_user_id": c.botUserID,
|
||||
@@ -274,11 +276,21 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
@@ -331,12 +343,22 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"is_mention": "true",
|
||||
"peer_kind": mentionPeerKind,
|
||||
"peer_id": mentionPeerID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
@@ -373,6 +395,9 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
"platform": "slack",
|
||||
"is_command": "true",
|
||||
"trigger_id": cmd.TriggerID,
|
||||
"peer_kind": "channel",
|
||||
"peer_id": channelID,
|
||||
"team_id": c.teamID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
|
||||
|
||||
@@ -347,12 +347,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = fmt.Sprintf("%d", chatID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": fmt.Sprintf("%d", message.MessageID),
|
||||
"user_id": fmt.Sprintf("%d", user.ID),
|
||||
"username": user.Username,
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
"peer_kind": peerKind,
|
||||
"peer_id": peerID,
|
||||
}
|
||||
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
|
||||
+166
-23
@@ -45,6 +45,8 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
|
||||
|
||||
type Config struct {
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty"`
|
||||
Session SessionConfig `json:"session,omitempty"`
|
||||
Channels ChannelsConfig `json:"channels"`
|
||||
Providers ProvidersConfig `json:"providers"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
@@ -56,16 +58,97 @@ type Config struct {
|
||||
|
||||
type AgentsConfig struct {
|
||||
Defaults AgentDefaults `json:"defaults"`
|
||||
List []AgentConfig `json:"list,omitempty"`
|
||||
}
|
||||
|
||||
// AgentModelConfig supports both string and structured model config.
|
||||
// String format: "gpt-4" (just primary, no fallbacks)
|
||||
// Object format: {"primary": "gpt-4", "fallbacks": ["claude-haiku"]}
|
||||
type AgentModelConfig struct {
|
||||
Primary string `json:"primary,omitempty"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty"`
|
||||
}
|
||||
|
||||
func (m *AgentModelConfig) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
m.Primary = s
|
||||
m.Fallbacks = nil
|
||||
return nil
|
||||
}
|
||||
type raw struct {
|
||||
Primary string `json:"primary"`
|
||||
Fallbacks []string `json:"fallbacks"`
|
||||
}
|
||||
var r raw
|
||||
if err := json.Unmarshal(data, &r); err != nil {
|
||||
return err
|
||||
}
|
||||
m.Primary = r.Primary
|
||||
m.Fallbacks = r.Fallbacks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m AgentModelConfig) MarshalJSON() ([]byte, error) {
|
||||
if len(m.Fallbacks) == 0 && m.Primary != "" {
|
||||
return json.Marshal(m.Primary)
|
||||
}
|
||||
type raw struct {
|
||||
Primary string `json:"primary,omitempty"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty"`
|
||||
}
|
||||
return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks})
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
ID string `json:"id"`
|
||||
Default bool `json:"default,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
Model *AgentModelConfig `json:"model,omitempty"`
|
||||
Skills []string `json:"skills,omitempty"`
|
||||
Subagents *SubagentsConfig `json:"subagents,omitempty"`
|
||||
}
|
||||
|
||||
type SubagentsConfig struct {
|
||||
AllowAgents []string `json:"allow_agents,omitempty"`
|
||||
Model *AgentModelConfig `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
type PeerMatch struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
type BindingMatch struct {
|
||||
Channel string `json:"channel"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
Peer *PeerMatch `json:"peer,omitempty"`
|
||||
GuildID string `json:"guild_id,omitempty"`
|
||||
TeamID string `json:"team_id,omitempty"`
|
||||
}
|
||||
|
||||
type AgentBinding struct {
|
||||
AgentID string `json:"agent_id"`
|
||||
Match BindingMatch `json:"match"`
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
DMScope string `json:"dm_scope,omitempty"`
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
|
||||
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
|
||||
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
@@ -167,19 +250,19 @@ type DevicesConfig struct {
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI OpenAIProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -190,6 +273,11 @@ type ProviderConfig struct {
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type OpenAIProviderConfig struct {
|
||||
ProviderConfig
|
||||
WebSearch bool `json:"web_search" env:"PICOCLAW_PROVIDERS_OPENAI_WEB_SEARCH"`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
@@ -206,14 +294,32 @@ type DuckDuckGoConfig struct {
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Brave BraveConfig `json:"brave"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
EnableDenyPatterns bool `json:"enable_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS"`
|
||||
CustomDenyPatterns []string `json:"custom_deny_patterns" env:"PICOCLAW_TOOLS_EXEC_CUSTOM_DENY_PATTERNS"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
Exec ExecConfig `json:"exec"`
|
||||
MCP MCPConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
// MCPServerConfig defines configuration for a single MCP server
|
||||
@@ -325,7 +431,7 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenAI: OpenAIProviderConfig{WebSearch: true},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
@@ -350,6 +456,17 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
EnableDenyPatterns: true,
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: false,
|
||||
@@ -460,6 +577,32 @@ func (c *Config) GetAPIBase() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ModelConfig holds primary model and fallback list.
|
||||
type ModelConfig struct {
|
||||
Primary string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
// GetModelConfig returns the text model configuration with fallbacks.
|
||||
func (c *Config) GetModelConfig() ModelConfig {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return ModelConfig{
|
||||
Primary: c.Agents.Defaults.Model,
|
||||
Fallbacks: c.Agents.Defaults.ModelFallbacks,
|
||||
}
|
||||
}
|
||||
|
||||
// GetImageModelConfig returns the image model configuration with fallbacks.
|
||||
func (c *Config) GetImageModelConfig() ModelConfig {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return ModelConfig{
|
||||
Primary: c.Agents.Defaults.ImageModel,
|
||||
Fallbacks: c.Agents.Defaults.ImageModelFallbacks,
|
||||
}
|
||||
}
|
||||
|
||||
func expandHome(path string) string {
|
||||
if path == "" {
|
||||
return path
|
||||
|
||||
+220
-32
@@ -1,12 +1,193 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
t.Fatalf("unmarshal string: %v", err)
|
||||
}
|
||||
if m.Primary != "gpt-4" {
|
||||
t.Errorf("Primary = %q, want 'gpt-4'", m.Primary)
|
||||
}
|
||||
if m.Fallbacks != nil {
|
||||
t.Errorf("Fallbacks = %v, want nil", m.Fallbacks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalObject(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
data := `{"primary": "claude-opus", "fallbacks": ["gpt-4o-mini", "haiku"]}`
|
||||
if err := json.Unmarshal([]byte(data), &m); err != nil {
|
||||
t.Fatalf("unmarshal object: %v", err)
|
||||
}
|
||||
if m.Primary != "claude-opus" {
|
||||
t.Errorf("Primary = %q, want 'claude-opus'", m.Primary)
|
||||
}
|
||||
if len(m.Fallbacks) != 2 {
|
||||
t.Fatalf("Fallbacks len = %d, want 2", len(m.Fallbacks))
|
||||
}
|
||||
if m.Fallbacks[0] != "gpt-4o-mini" || m.Fallbacks[1] != "haiku" {
|
||||
t.Errorf("Fallbacks = %v", m.Fallbacks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_MarshalString(t *testing.T) {
|
||||
m := AgentModelConfig{Primary: "gpt-4"}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
if string(data) != `"gpt-4"` {
|
||||
t.Errorf("marshal = %s, want '\"gpt-4\"'", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_MarshalObject(t *testing.T) {
|
||||
m := AgentModelConfig{Primary: "claude-opus", Fallbacks: []string{"haiku"}}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal(data, &result)
|
||||
if result["primary"] != "claude-opus" {
|
||||
t.Errorf("primary = %v", result["primary"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentConfig_FullParse(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"max_tool_iterations": 20
|
||||
},
|
||||
"list": [
|
||||
{
|
||||
"id": "sales",
|
||||
"default": true,
|
||||
"name": "Sales Bot",
|
||||
"model": "gpt-4"
|
||||
},
|
||||
{
|
||||
"id": "support",
|
||||
"name": "Support Bot",
|
||||
"model": {
|
||||
"primary": "claude-opus",
|
||||
"fallbacks": ["haiku"]
|
||||
},
|
||||
"subagents": {
|
||||
"allow_agents": ["sales"]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"bindings": [
|
||||
{
|
||||
"agent_id": "support",
|
||||
"match": {
|
||||
"channel": "telegram",
|
||||
"account_id": "*",
|
||||
"peer": {"kind": "direct", "id": "user123"}
|
||||
}
|
||||
}
|
||||
],
|
||||
"session": {
|
||||
"dm_scope": "per-peer",
|
||||
"identity_links": {
|
||||
"john": ["telegram:123", "discord:john#1234"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Agents.List) != 2 {
|
||||
t.Fatalf("agents.list len = %d, want 2", len(cfg.Agents.List))
|
||||
}
|
||||
|
||||
sales := cfg.Agents.List[0]
|
||||
if sales.ID != "sales" || !sales.Default || sales.Name != "Sales Bot" {
|
||||
t.Errorf("sales = %+v", sales)
|
||||
}
|
||||
if sales.Model == nil || sales.Model.Primary != "gpt-4" {
|
||||
t.Errorf("sales.Model = %+v", sales.Model)
|
||||
}
|
||||
|
||||
support := cfg.Agents.List[1]
|
||||
if support.ID != "support" || support.Name != "Support Bot" {
|
||||
t.Errorf("support = %+v", support)
|
||||
}
|
||||
if support.Model == nil || support.Model.Primary != "claude-opus" {
|
||||
t.Errorf("support.Model = %+v", support.Model)
|
||||
}
|
||||
if len(support.Model.Fallbacks) != 1 || support.Model.Fallbacks[0] != "haiku" {
|
||||
t.Errorf("support.Model.Fallbacks = %v", support.Model.Fallbacks)
|
||||
}
|
||||
if support.Subagents == nil || len(support.Subagents.AllowAgents) != 1 {
|
||||
t.Errorf("support.Subagents = %+v", support.Subagents)
|
||||
}
|
||||
|
||||
if len(cfg.Bindings) != 1 {
|
||||
t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings))
|
||||
}
|
||||
binding := cfg.Bindings[0]
|
||||
if binding.AgentID != "support" || binding.Match.Channel != "telegram" {
|
||||
t.Errorf("binding = %+v", binding)
|
||||
}
|
||||
if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" {
|
||||
t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer)
|
||||
}
|
||||
|
||||
if cfg.Session.DMScope != "per-peer" {
|
||||
t.Errorf("Session.DMScope = %q", cfg.Session.DMScope)
|
||||
}
|
||||
if len(cfg.Session.IdentityLinks) != 1 {
|
||||
t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks)
|
||||
}
|
||||
links := cfg.Session.IdentityLinks["john"]
|
||||
if len(links) != 2 {
|
||||
t.Errorf("john links = %v", links)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := json.Unmarshal([]byte(jsonData), cfg); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Agents.List) != 0 {
|
||||
t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List))
|
||||
}
|
||||
if len(cfg.Bindings) != 0 {
|
||||
t.Errorf("bindings should be empty, got %d", len(cfg.Bindings))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
|
||||
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
@@ -20,8 +201,6 @@ func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
func TestDefaultConfig_WorkspacePath(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Just verify the workspace is set, don't compare exact paths
|
||||
// since expandHome behavior may differ based on environment
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
@@ -79,7 +258,6 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
func TestDefaultConfig_Providers(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all providers are empty by default
|
||||
if cfg.Providers.Anthropic.APIKey != "" {
|
||||
t.Error("Anthropic API key should be empty by default")
|
||||
}
|
||||
@@ -89,46 +267,18 @@ func TestDefaultConfig_Providers(t *testing.T) {
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
t.Error("OpenRouter API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
t.Error("Groq API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
t.Error("Zhipu API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.VLLM.APIKey != "" {
|
||||
t.Error("VLLM API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
t.Error("Gemini API key should be empty by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Channels verifies channels are disabled by default
|
||||
func TestDefaultConfig_Channels(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all channels are disabled by default
|
||||
if cfg.Channels.WhatsApp.Enabled {
|
||||
t.Error("WhatsApp should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Telegram.Enabled {
|
||||
t.Error("Telegram should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Feishu.Enabled {
|
||||
t.Error("Feishu should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Discord.Enabled {
|
||||
t.Error("Discord should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.MaixCam.Enabled {
|
||||
t.Error("MaixCam should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.QQ.Enabled {
|
||||
t.Error("QQ should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.DingTalk.Enabled {
|
||||
t.Error("DingTalk should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Slack.Enabled {
|
||||
t.Error("Slack should be disabled by default")
|
||||
}
|
||||
@@ -178,7 +328,6 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify complete config structure
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
@@ -204,3 +353,42 @@ func TestConfig_Complete(t *testing.T) {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_OpenAIWebSearchEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("DefaultConfig().Providers.OpenAI.WebSearch should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"api_base":""}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("OpenAI codex web search should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"providers":{"openai":{"web_search":false}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if cfg.Providers.OpenAI.WebSearch {
|
||||
t.Fatal("OpenAI codex web search should be false when disabled in config file")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
// Package constants provides shared constants across the codebase.
|
||||
package constants
|
||||
|
||||
// InternalChannels defines channels that are used for internal communication
|
||||
// internalChannels defines channels that are used for internal communication
|
||||
// and should not be exposed to external users or recorded as last active channel.
|
||||
var InternalChannels = map[string]bool{
|
||||
"cli": true,
|
||||
"system": true,
|
||||
"subagent": true,
|
||||
var internalChannels = map[string]struct{}{
|
||||
"cli": {},
|
||||
"system": {},
|
||||
"subagent": {},
|
||||
}
|
||||
|
||||
// IsInternalChannel returns true if the channel is an internal channel.
|
||||
func IsInternalChannel(channel string) bool {
|
||||
return InternalChannels[channel]
|
||||
_, found := internalChannels[channel]
|
||||
return found
|
||||
}
|
||||
|
||||
+11
-1
@@ -108,7 +108,10 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
|
||||
case "anthropic":
|
||||
cfg.Providers.Anthropic = pc
|
||||
case "openai":
|
||||
cfg.Providers.OpenAI = pc
|
||||
cfg.Providers.OpenAI = config.OpenAIProviderConfig{
|
||||
ProviderConfig: pc,
|
||||
WebSearch: getBoolOrDefault(pMap, "web_search", true),
|
||||
}
|
||||
case "openrouter":
|
||||
cfg.Providers.OpenRouter = pc
|
||||
case "groq":
|
||||
@@ -363,6 +366,13 @@ func getBool(data map[string]interface{}, key string) (bool, bool) {
|
||||
return b, ok
|
||||
}
|
||||
|
||||
func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool {
|
||||
if v, ok := getBool(data, key); ok {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func getStringSlice(data map[string]interface{}, key string) []string {
|
||||
v, ok := data[key]
|
||||
if !ok {
|
||||
|
||||
@@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSupportedProvidersCompatibility(t *testing.T) {
|
||||
expected := []string{
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"groq",
|
||||
"zhipu",
|
||||
"vllm",
|
||||
"gemini",
|
||||
}
|
||||
|
||||
for _, provider := range expected {
|
||||
if !supportedProviders[provider] {
|
||||
t.Fatalf("supportedProviders missing expected key %q", provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeConfig(t *testing.T) {
|
||||
t.Run("fills empty fields", func(t *testing.T) {
|
||||
existing := config.DefaultConfig()
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
package anthropicprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
const defaultBaseURL = "https://api.anthropic.com"
|
||||
|
||||
type Provider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewProvider(token string) *Provider {
|
||||
return NewProviderWithBaseURL(token, "")
|
||||
}
|
||||
|
||||
func NewProviderWithBaseURL(token, apiBase string) *Provider {
|
||||
baseURL := normalizeBaseURL(apiBase)
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
return &Provider{
|
||||
client: &client,
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithClient(client *anthropic.Client) *Provider {
|
||||
return &Provider{
|
||||
client: client,
|
||||
baseURL: defaultBaseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider {
|
||||
return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "")
|
||||
}
|
||||
|
||||
func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider {
|
||||
p := NewProviderWithBaseURL(token, apiBase)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
}
|
||||
|
||||
params, err := buildParams(messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func (p *Provider) BaseURL() string {
|
||||
return p.baseURL
|
||||
}
|
||||
|
||||
func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
var blocks []anthropic.ContentBlockParamUnion
|
||||
if msg.Content != "" {
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "tool":
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(4096)
|
||||
if mt, ok := options["max_tokens"].(int); ok {
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
params.System = system
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = anthropic.Float(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateTools(tools)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool := anthropic.ToolParam{
|
||||
Name: t.Function.Name,
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: t.Function.Parameters["properties"],
|
||||
},
|
||||
}
|
||||
if desc := t.Function.Description; desc != "" {
|
||||
tool.Description = anthropic.String(desc)
|
||||
}
|
||||
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(req))
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
tool.InputSchema.Required = required
|
||||
}
|
||||
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tu.ID,
|
||||
Name: tu.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
switch resp.StopReason {
|
||||
case anthropic.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case anthropic.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case anthropic.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeBaseURL(apiBase string) string {
|
||||
base := strings.TrimSpace(apiBase)
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
base = strings.TrimRight(base, "/")
|
||||
if strings.HasSuffix(base, "/v1") {
|
||||
base = strings.TrimSuffix(base, "/v1")
|
||||
}
|
||||
if base == "" {
|
||||
return defaultBaseURL
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package anthropicprovider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||
)
|
||||
|
||||
func TestBuildParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_SystemMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.System) != 1 {
|
||||
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||
}
|
||||
if params.System[0].Text != "You are helpful" {
|
||||
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_ToolCallMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]interface{}{"city": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a city",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []interface{}{"city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildParams() error: %v", err)
|
||||
}
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_TextOnly(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
},
|
||||
}
|
||||
result := parseResponse(resp)
|
||||
if result.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||
}
|
||||
if result.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason anthropic.StopReason
|
||||
want string
|
||||
}{
|
||||
{anthropic.StopReasonEndTurn, "stop"},
|
||||
{anthropic.StopReasonMaxTokens, "length"},
|
||||
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
resp := &anthropic.Message{
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
result := parseResponse(resp)
|
||||
if result.FinishReason != tt.want {
|
||||
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello! How can I help you?"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 15,
|
||||
"output_tokens": 8,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello! How can I help you?" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 15 {
|
||||
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewProvider("test-token")
|
||||
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) {
|
||||
p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/")
|
||||
if got := p.BaseURL(); got != "https://api.anthropic.com" {
|
||||
t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ChatUsesTokenSource(t *testing.T) {
|
||||
var requests int32
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
atomic.AddInt32(&requests, 1)
|
||||
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "ok"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 1,
|
||||
"output_tokens": 1,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) {
|
||||
return "refreshed-token", nil
|
||||
}, server.URL)
|
||||
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&requests); got != 1 {
|
||||
t.Fatalf("requests = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||
c := anthropic.NewClient(
|
||||
anthropicoption.WithAuthToken(token),
|
||||
anthropicoption.WithBaseURL(baseURL),
|
||||
)
|
||||
return &c
|
||||
}
|
||||
@@ -2,200 +2,58 @@ package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
|
||||
)
|
||||
|
||||
type ClaudeProvider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
delegate *anthropicprovider.Provider
|
||||
}
|
||||
|
||||
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL("https://api.anthropic.com"),
|
||||
)
|
||||
return &ClaudeProvider{client: &client}
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProvider(token),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||
p := NewClaudeProvider(token)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
|
||||
return &ClaudeProvider{
|
||||
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
|
||||
}
|
||||
}
|
||||
|
||||
func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider {
|
||||
return &ClaudeProvider{delegate: delegate}
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
}
|
||||
|
||||
params, err := buildClaudeParams(messages, tools, model, options)
|
||||
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseClaudeResponse(resp), nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
var blocks []anthropic.ContentBlockParamUnion
|
||||
if msg.Content != "" {
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "tool":
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(4096)
|
||||
if mt, ok := options["max_tokens"].(int); ok {
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
params.System = system
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = anthropic.Float(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForClaude(tools)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool := anthropic.ToolParam{
|
||||
Name: t.Function.Name,
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: t.Function.Parameters["properties"],
|
||||
},
|
||||
}
|
||||
if desc := t.Function.Description; desc != "" {
|
||||
tool.Description = anthropic.String(desc)
|
||||
}
|
||||
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(req))
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
tool.InputSchema.Required = required
|
||||
}
|
||||
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tu.ID,
|
||||
Name: tu.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
switch resp.StopReason {
|
||||
case anthropic.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case anthropic.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case anthropic.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||
},
|
||||
}
|
||||
return p.delegate.GetDefaultModel()
|
||||
}
|
||||
|
||||
func createClaudeTokenSource() func() (string, error) {
|
||||
return func() (string, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,140 +8,9 @@ import (
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
|
||||
)
|
||||
|
||||
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.System) != 1 {
|
||||
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||
}
|
||||
if params.System[0].Text != "You are helpful" {
|
||||
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]interface{}{"city": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a city",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []interface{}{"city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
},
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||
}
|
||||
if result.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason anthropic.StopReason
|
||||
want string
|
||||
}{
|
||||
{anthropic.StopReasonEndTurn, "stop"},
|
||||
{anthropic.StopReasonMaxTokens, "length"},
|
||||
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
resp := &anthropic.Message{
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.FinishReason != tt.want {
|
||||
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
@@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewClaudeProvider("test-token")
|
||||
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||
delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
|
||||
provider := newClaudeProviderWithDelegate(delegate)
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
//go:build integration
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealCodexCLI tests the CodexCliProvider with a real codex CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealCodexCLI(t *testing.T) {
|
||||
path, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using codex CLI at: %s", path)
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewCodexCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealCodexCLI_ParsesRealJSONL(t *testing.T) {
|
||||
if _, err := exec.LookPath("codex"); err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run codex directly and verify our parser handles real output
|
||||
cmd := exec.Command("codex", "exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
"-C", t.TempDir(),
|
||||
"-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// codex may write diagnostic noise to stderr but still produce valid output
|
||||
if len(output) == 0 {
|
||||
t.Fatalf("codex CLI failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output (first 500 chars): %s", string(output[:min(len(output), 500)]))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewCodexCliProvider("")
|
||||
resp, err := p.parseJSONLEvents(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -18,9 +18,10 @@ const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
@@ -37,8 +38,9 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
}
|
||||
client := openai.NewClient(opts...)
|
||||
return &CodexProvider{
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
enableWebSearch: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +80,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
|
||||
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
@@ -182,7 +184,7 @@ func resolveCodexModel(model string) (string, string) {
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
|
||||
var inputItems responses.ResponseInputParam
|
||||
var instructions string
|
||||
|
||||
@@ -217,12 +219,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
name, args, ok := resolveCodexToolCall(tc)
|
||||
if !ok {
|
||||
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
|
||||
"call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: string(argsJSON),
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -260,20 +268,50 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
if len(tools) > 0 || enableWebSearch {
|
||||
params.Tools = translateToolsForCodex(tools, enableWebSearch)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition, enableWebSearch bool) []responses.ToolUnionParam {
|
||||
capHint := len(tools)
|
||||
if enableWebSearch {
|
||||
capHint++
|
||||
}
|
||||
result := make([]responses.ToolUnionParam, 0, capHint)
|
||||
for _, t := range tools {
|
||||
if t.Type != "function" {
|
||||
continue
|
||||
}
|
||||
if enableWebSearch && strings.EqualFold(t.Function.Name, "web_search") {
|
||||
continue
|
||||
}
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
@@ -284,6 +322,9 @@ func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
if enableWebSearch {
|
||||
result = append(result, responses.ToolParamOfWebSearch(responses.WebSearchToolTypeWebSearch))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
}, true)
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
@@ -29,6 +29,9 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
if params.MaxOutputTokens.Valid() {
|
||||
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
@@ -36,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
@@ -56,7 +59,7 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
@@ -65,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Read a file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
if len(params.Input.OfInputItemList) != 3 {
|
||||
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||
}
|
||||
|
||||
fc := params.Input.OfInputItemList[1].OfFunctionCall
|
||||
if fc == nil {
|
||||
t.Fatal("assistant tool call should be converted to function_call input item")
|
||||
}
|
||||
if fc.Name != "read_file" {
|
||||
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
|
||||
}
|
||||
if fc.Arguments != `{"path":"README.md"}` {
|
||||
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
@@ -81,7 +123,7 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
@@ -94,12 +136,61 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
|
||||
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||
t.Error("Store should be explicitly set to false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfWebSearch == nil {
|
||||
t.Fatal("Tool should include built-in web_search")
|
||||
}
|
||||
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
|
||||
t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "web_search",
|
||||
Description: "local web search",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: "read file",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
|
||||
if len(params.Tools) != 2 {
|
||||
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfFunction == nil || params.Tools[0].OfFunction.Name != "read_file" {
|
||||
t.Fatalf("first tool should be function read_file, got %#v", params.Tools[0])
|
||||
}
|
||||
if params.Tools[1].OfWebSearch == nil {
|
||||
t.Fatalf("second tool should be built-in web_search, got %#v", params.Tools[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||
respJSON := `{
|
||||
"id": "resp_test",
|
||||
@@ -214,6 +305,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolsAny, ok := reqBody["tools"].([]interface{})
|
||||
if !ok || len(toolsAny) != 1 {
|
||||
http.Error(w, "missing default web search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
toolObj, ok := toolsAny[0].(map[string]interface{})
|
||||
if !ok || toolObj["type"] != "web_search" {
|
||||
http.Error(w, "expected web_search tool", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
@@ -261,6 +366,64 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["tools"]; ok {
|
||||
http.Error(w, "tools should be absent when web search disabled", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 4,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 7,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.enableWebSearch = false
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
@@ -293,6 +456,10 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFailureWindow = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CooldownTracker manages per-provider cooldown state for the fallback chain.
|
||||
// Thread-safe via sync.RWMutex. In-memory only (resets on restart).
|
||||
type CooldownTracker struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]*cooldownEntry
|
||||
failureWindow time.Duration
|
||||
nowFunc func() time.Time // for testing
|
||||
}
|
||||
|
||||
type cooldownEntry struct {
|
||||
ErrorCount int
|
||||
FailureCounts map[FailoverReason]int
|
||||
CooldownEnd time.Time // standard cooldown expiry
|
||||
DisabledUntil time.Time // billing-specific disable expiry
|
||||
DisabledReason FailoverReason // reason for disable (billing)
|
||||
LastFailure time.Time
|
||||
}
|
||||
|
||||
// NewCooldownTracker creates a tracker with default 24h failure window.
|
||||
func NewCooldownTracker() *CooldownTracker {
|
||||
return &CooldownTracker{
|
||||
entries: make(map[string]*cooldownEntry),
|
||||
failureWindow: defaultFailureWindow,
|
||||
nowFunc: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkFailure records a failure for a provider and sets appropriate cooldown.
|
||||
// Resets error counts if last failure was more than failureWindow ago.
|
||||
func (ct *CooldownTracker) MarkFailure(provider string, reason FailoverReason) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
now := ct.nowFunc()
|
||||
entry := ct.getOrCreate(provider)
|
||||
|
||||
// 24h failure window reset: if no failure in failureWindow, reset counters.
|
||||
if !entry.LastFailure.IsZero() && now.Sub(entry.LastFailure) > ct.failureWindow {
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
}
|
||||
|
||||
entry.ErrorCount++
|
||||
entry.FailureCounts[reason]++
|
||||
entry.LastFailure = now
|
||||
|
||||
if reason == FailoverBilling {
|
||||
billingCount := entry.FailureCounts[FailoverBilling]
|
||||
entry.DisabledUntil = now.Add(calculateBillingCooldown(billingCount))
|
||||
entry.DisabledReason = FailoverBilling
|
||||
} else {
|
||||
entry.CooldownEnd = now.Add(calculateStandardCooldown(entry.ErrorCount))
|
||||
}
|
||||
}
|
||||
|
||||
// MarkSuccess resets all counters and cooldowns for a provider.
|
||||
func (ct *CooldownTracker) MarkSuccess(provider string) {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry.ErrorCount = 0
|
||||
entry.FailureCounts = make(map[FailoverReason]int)
|
||||
entry.CooldownEnd = time.Time{}
|
||||
entry.DisabledUntil = time.Time{}
|
||||
entry.DisabledReason = ""
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the provider is not in cooldown or disabled.
|
||||
func (ct *CooldownTracker) IsAvailable(provider string) bool {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
|
||||
// Billing disable takes precedence (longer cooldown).
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Standard cooldown.
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// CooldownRemaining returns how long until the provider becomes available.
|
||||
// Returns 0 if already available.
|
||||
func (ct *CooldownTracker) CooldownRemaining(provider string) time.Duration {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
now := ct.nowFunc()
|
||||
var remaining time.Duration
|
||||
|
||||
if !entry.DisabledUntil.IsZero() && now.Before(entry.DisabledUntil) {
|
||||
d := entry.DisabledUntil.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
if !entry.CooldownEnd.IsZero() && now.Before(entry.CooldownEnd) {
|
||||
d := entry.CooldownEnd.Sub(now)
|
||||
if d > remaining {
|
||||
remaining = d
|
||||
}
|
||||
}
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
// ErrorCount returns the current error count for a provider.
|
||||
func (ct *CooldownTracker) ErrorCount(provider string) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.ErrorCount
|
||||
}
|
||||
|
||||
// FailureCount returns the failure count for a specific reason.
|
||||
func (ct *CooldownTracker) FailureCount(provider string, reason FailoverReason) int {
|
||||
ct.mu.RLock()
|
||||
defer ct.mu.RUnlock()
|
||||
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
return 0
|
||||
}
|
||||
return entry.FailureCounts[reason]
|
||||
}
|
||||
|
||||
func (ct *CooldownTracker) getOrCreate(provider string) *cooldownEntry {
|
||||
entry := ct.entries[provider]
|
||||
if entry == nil {
|
||||
entry = &cooldownEntry{
|
||||
FailureCounts: make(map[FailoverReason]int),
|
||||
}
|
||||
ct.entries[provider] = entry
|
||||
}
|
||||
return entry
|
||||
}
|
||||
|
||||
// calculateStandardCooldown computes standard exponential backoff.
|
||||
// Formula from OpenClaw: min(1h, 1min * 5^min(n-1, 3))
|
||||
//
|
||||
// 1 error → 1 min
|
||||
// 2 errors → 5 min
|
||||
// 3 errors → 25 min
|
||||
// 4+ errors → 1 hour (cap)
|
||||
func calculateStandardCooldown(errorCount int) time.Duration {
|
||||
n := max(1, errorCount)
|
||||
exp := min(n-1, 3)
|
||||
ms := 60_000 * int(math.Pow(5, float64(exp)))
|
||||
ms = min(3_600_000, ms) // cap at 1 hour
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
// calculateBillingCooldown computes billing-specific exponential backoff.
|
||||
// Formula from OpenClaw: min(24h, 5h * 2^min(n-1, 10))
|
||||
//
|
||||
// 1 error → 5 hours
|
||||
// 2 errors → 10 hours
|
||||
// 3 errors → 20 hours
|
||||
// 4+ errors → 24 hours (cap)
|
||||
func calculateBillingCooldown(billingErrorCount int) time.Duration {
|
||||
const baseMs = 5 * 60 * 60 * 1000 // 5 hours
|
||||
const maxMs = 24 * 60 * 60 * 1000 // 24 hours
|
||||
|
||||
n := max(1, billingErrorCount)
|
||||
exp := min(n-1, 10)
|
||||
raw := float64(baseMs) * math.Pow(2, float64(exp))
|
||||
ms := int(math.Min(float64(maxMs), raw))
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTestTracker(now time.Time) (*CooldownTracker, *time.Time) {
|
||||
current := now
|
||||
ct := NewCooldownTracker()
|
||||
ct.nowFunc = func() time.Time { return current }
|
||||
return ct, ¤t
|
||||
}
|
||||
|
||||
func TestCooldown_InitiallyAvailable(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("new provider should be available")
|
||||
}
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Error("new provider should have 0 errors")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st error → 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown after 1st error")
|
||||
}
|
||||
|
||||
// Advance 61 seconds → available
|
||||
*current = now.Add(61 * time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 1 min cooldown")
|
||||
}
|
||||
|
||||
// 2nd error → 5 min cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = now.Add(61*time.Second + 4*time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be in cooldown (5 min) after 2nd error")
|
||||
}
|
||||
*current = now.Add(61*time.Second + 6*time.Minute)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5 min cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_StandardCap(t *testing.T) {
|
||||
// Verify formula: 1m, 5m, 25m, 1h, 1h, 1h...
|
||||
expected := []time.Duration{
|
||||
1 * time.Minute,
|
||||
5 * time.Minute,
|
||||
25 * time.Minute,
|
||||
1 * time.Hour,
|
||||
1 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateStandardCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateStandardCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingEscalation(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 1st billing error → 5h cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should be disabled after billing error")
|
||||
}
|
||||
|
||||
// Advance 4h → still disabled
|
||||
*current = now.Add(4 * time.Hour)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("should still be disabled (5h cooldown)")
|
||||
}
|
||||
|
||||
// Advance 5h + 1s → available
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after 5h billing cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingCap(t *testing.T) {
|
||||
expected := []time.Duration{
|
||||
5 * time.Hour,
|
||||
10 * time.Hour,
|
||||
20 * time.Hour,
|
||||
24 * time.Hour,
|
||||
24 * time.Hour,
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
got := calculateBillingCooldown(i + 1)
|
||||
if got != want {
|
||||
t.Errorf("calculateBillingCooldown(%d) = %v, want %v", i+1, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessReset(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
if ct.ErrorCount("openai") != 2 {
|
||||
t.Errorf("error count = %d, want 2", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
ct.MarkSuccess("openai")
|
||||
if ct.ErrorCount("openai") != 0 {
|
||||
t.Errorf("error count after success = %d, want 0", ct.ErrorCount("openai"))
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 0 {
|
||||
t.Error("failure counts should be reset after success")
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 0 {
|
||||
t.Error("billing failure count should be reset after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_FailureWindowReset(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// 4 errors → 1h cooldown
|
||||
for i := 0; i < 4; i++ {
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
*current = current.Add(2 * time.Second) // small advance between errors
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
|
||||
// Advance 25 hours (past 24h failure window)
|
||||
*current = now.Add(25 * time.Hour)
|
||||
|
||||
// Next error should reset counters first, then increment to 1
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
if ct.ErrorCount("openai") != 1 {
|
||||
t.Errorf("error count after window reset = %d, want 1 (reset + 1)", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_PerReasonTracking(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("openai", FailoverBilling)
|
||||
ct.MarkFailure("openai", FailoverAuth)
|
||||
|
||||
if ct.FailureCount("openai", FailoverRateLimit) != 2 {
|
||||
t.Errorf("rate_limit count = %d, want 2", ct.FailureCount("openai", FailoverRateLimit))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverBilling) != 1 {
|
||||
t.Errorf("billing count = %d, want 1", ct.FailureCount("openai", FailoverBilling))
|
||||
}
|
||||
if ct.FailureCount("openai", FailoverAuth) != 1 {
|
||||
t.Errorf("auth count = %d, want 1", ct.FailureCount("openai", FailoverAuth))
|
||||
}
|
||||
if ct.ErrorCount("openai") != 4 {
|
||||
t.Errorf("total error count = %d, want 4", ct.ErrorCount("openai"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_BillingTakesPrecedence(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// Standard cooldown (1 min) + billing disable (5h)
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // 1 min cooldown
|
||||
ct.MarkFailure("openai", FailoverBilling) // 5h disable
|
||||
|
||||
// After 2 min: standard cooldown expired but billing still active
|
||||
*current = now.Add(2 * time.Minute)
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("billing disable should take precedence over standard cooldown")
|
||||
}
|
||||
|
||||
// After 5h + 1s: both expired
|
||||
*current = now.Add(5*time.Hour + 1*time.Second)
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("should be available after all cooldowns expire")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_CooldownRemaining(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, current := newTestTracker(now)
|
||||
|
||||
// No failures → 0 remaining
|
||||
if ct.CooldownRemaining("openai") != 0 {
|
||||
t.Error("expected 0 remaining for new provider")
|
||||
}
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
*current = now.Add(30 * time.Second)
|
||||
remaining := ct.CooldownRemaining("openai")
|
||||
if remaining <= 0 || remaining > 1*time.Minute {
|
||||
t.Errorf("remaining = %v, expected ~30s", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_SuccessOnUnknownProvider(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
// Should not panic
|
||||
ct.MarkSuccess("nonexistent")
|
||||
if !ct.IsAvailable("nonexistent") {
|
||||
t.Error("nonexistent provider should be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldown_ConcurrentAccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.IsAvailable("openai")
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ct.MarkSuccess("openai")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// If we got here without panic, concurrent access is safe
|
||||
}
|
||||
|
||||
func TestCooldown_MultipleProviders(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
if ct.IsAvailable("openai") {
|
||||
t.Error("openai should be in cooldown")
|
||||
}
|
||||
if ct.IsAvailable("anthropic") {
|
||||
t.Error("anthropic should be in cooldown")
|
||||
}
|
||||
// groq was never touched
|
||||
if !ct.IsAvailable("groq") {
|
||||
t.Error("groq should be available")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// errorPattern defines a single pattern (string or regex) for error classification.
|
||||
type errorPattern struct {
|
||||
substring string
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
func substr(s string) errorPattern { return errorPattern{substring: s} }
|
||||
func rxp(r string) errorPattern { return errorPattern{regex: regexp.MustCompile("(?i)" + r)} }
|
||||
|
||||
// Error patterns organized by FailoverReason, matching OpenClaw production (~40 patterns).
|
||||
var (
|
||||
rateLimitPatterns = []errorPattern{
|
||||
rxp(`rate[_ ]limit`),
|
||||
substr("too many requests"),
|
||||
substr("429"),
|
||||
substr("exceeded your current quota"),
|
||||
rxp(`exceeded.*quota`),
|
||||
rxp(`resource has been exhausted`),
|
||||
rxp(`resource.*exhausted`),
|
||||
substr("resource_exhausted"),
|
||||
substr("quota exceeded"),
|
||||
substr("usage limit"),
|
||||
}
|
||||
|
||||
overloadedPatterns = []errorPattern{
|
||||
rxp(`overloaded_error`),
|
||||
rxp(`"type"\s*:\s*"overloaded_error"`),
|
||||
substr("overloaded"),
|
||||
}
|
||||
|
||||
timeoutPatterns = []errorPattern{
|
||||
substr("timeout"),
|
||||
substr("timed out"),
|
||||
substr("deadline exceeded"),
|
||||
substr("context deadline exceeded"),
|
||||
}
|
||||
|
||||
billingPatterns = []errorPattern{
|
||||
rxp(`\b402\b`),
|
||||
substr("payment required"),
|
||||
substr("insufficient credits"),
|
||||
substr("credit balance"),
|
||||
substr("plans & billing"),
|
||||
substr("insufficient balance"),
|
||||
}
|
||||
|
||||
authPatterns = []errorPattern{
|
||||
rxp(`invalid[_ ]?api[_ ]?key`),
|
||||
substr("incorrect api key"),
|
||||
substr("invalid token"),
|
||||
substr("authentication"),
|
||||
substr("re-authenticate"),
|
||||
substr("oauth token refresh failed"),
|
||||
substr("unauthorized"),
|
||||
substr("forbidden"),
|
||||
substr("access denied"),
|
||||
substr("expired"),
|
||||
substr("token has expired"),
|
||||
rxp(`\b401\b`),
|
||||
rxp(`\b403\b`),
|
||||
substr("no credentials found"),
|
||||
substr("no api key found"),
|
||||
}
|
||||
|
||||
formatPatterns = []errorPattern{
|
||||
substr("string should match pattern"),
|
||||
substr("tool_use.id"),
|
||||
substr("tool_use_id"),
|
||||
substr("messages.1.content.1.tool_use.id"),
|
||||
substr("invalid request format"),
|
||||
}
|
||||
|
||||
imageDimensionPatterns = []errorPattern{
|
||||
rxp(`image dimensions exceed max`),
|
||||
}
|
||||
|
||||
imageSizePatterns = []errorPattern{
|
||||
rxp(`image exceeds.*mb`),
|
||||
}
|
||||
|
||||
// Transient HTTP status codes that map to timeout (server-side failures).
|
||||
transientStatusCodes = map[int]bool{
|
||||
500: true, 502: true, 503: true,
|
||||
521: true, 522: true, 523: true, 524: true,
|
||||
529: true,
|
||||
}
|
||||
)
|
||||
|
||||
// ClassifyError classifies an error into a FailoverError with reason.
|
||||
// Returns nil if the error is not classifiable (unknown errors should not trigger fallback).
|
||||
func ClassifyError(err error, provider, model string) *FailoverError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context cancellation: user abort, never fallback.
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context deadline exceeded: treat as timeout, always fallback.
|
||||
if err == context.DeadlineExceeded {
|
||||
return &FailoverError{
|
||||
Reason: FailoverTimeout,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
msg := strings.ToLower(err.Error())
|
||||
|
||||
// Image dimension/size errors: non-retriable, non-fallback.
|
||||
if IsImageDimensionError(msg) || IsImageSizeError(msg) {
|
||||
return &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Try HTTP status code extraction first.
|
||||
if status := extractHTTPStatus(msg); status > 0 {
|
||||
if reason := classifyByStatus(status); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: status,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Message pattern matching (priority order from OpenClaw).
|
||||
if reason := classifyByMessage(msg); reason != "" {
|
||||
return &FailoverError{
|
||||
Reason: reason,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyByStatus maps HTTP status codes to FailoverReason.
|
||||
func classifyByStatus(status int) FailoverReason {
|
||||
switch {
|
||||
case status == 401 || status == 403:
|
||||
return FailoverAuth
|
||||
case status == 402:
|
||||
return FailoverBilling
|
||||
case status == 408:
|
||||
return FailoverTimeout
|
||||
case status == 429:
|
||||
return FailoverRateLimit
|
||||
case status == 400:
|
||||
return FailoverFormat
|
||||
case transientStatusCodes[status]:
|
||||
return FailoverTimeout
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyByMessage matches error messages against patterns.
|
||||
// Priority order matters (from OpenClaw classifyFailoverReason).
|
||||
func classifyByMessage(msg string) FailoverReason {
|
||||
if matchesAny(msg, rateLimitPatterns) {
|
||||
return FailoverRateLimit
|
||||
}
|
||||
if matchesAny(msg, overloadedPatterns) {
|
||||
return FailoverRateLimit // Overloaded treated as rate_limit
|
||||
}
|
||||
if matchesAny(msg, billingPatterns) {
|
||||
return FailoverBilling
|
||||
}
|
||||
if matchesAny(msg, timeoutPatterns) {
|
||||
return FailoverTimeout
|
||||
}
|
||||
if matchesAny(msg, authPatterns) {
|
||||
return FailoverAuth
|
||||
}
|
||||
if matchesAny(msg, formatPatterns) {
|
||||
return FailoverFormat
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHTTPStatus extracts an HTTP status code from an error message.
|
||||
// Looks for patterns like "status: 429", "status 429", "HTTP 429", or standalone "429".
|
||||
func extractHTTPStatus(msg string) int {
|
||||
// Common patterns in Go HTTP error messages
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`status[:\s]+(\d{3})`),
|
||||
regexp.MustCompile(`HTTP[/\s]+\d*\.?\d*\s+(\d{3})`),
|
||||
}
|
||||
|
||||
for _, p := range patterns {
|
||||
if m := p.FindStringSubmatch(msg); len(m) > 1 {
|
||||
return parseDigits(m[1])
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsImageDimensionError returns true if the message indicates an image dimension error.
|
||||
func IsImageDimensionError(msg string) bool {
|
||||
return matchesAny(msg, imageDimensionPatterns)
|
||||
}
|
||||
|
||||
// IsImageSizeError returns true if the message indicates an image file size error.
|
||||
func IsImageSizeError(msg string) bool {
|
||||
return matchesAny(msg, imageSizePatterns)
|
||||
}
|
||||
|
||||
// matchesAny checks if msg matches any of the patterns.
|
||||
func matchesAny(msg string, patterns []errorPattern) bool {
|
||||
for _, p := range patterns {
|
||||
if p.regex != nil {
|
||||
if p.regex.MatchString(msg) {
|
||||
return true
|
||||
}
|
||||
} else if p.substring != "" {
|
||||
if strings.Contains(msg, p.substring) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseDigits converts a string of digits to an int.
|
||||
func parseDigits(s string) int {
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClassifyError_Nil(t *testing.T) {
|
||||
result := ClassifyError(nil, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for nil error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextCanceled(t *testing.T) {
|
||||
result := ClassifyError(context.Canceled, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for context.Canceled (user abort), got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ContextDeadlineExceeded(t *testing.T) {
|
||||
result := ClassifyError(context.DeadlineExceeded, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for deadline exceeded")
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("reason = %q, want timeout", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_StatusCodes(t *testing.T) {
|
||||
tests := []struct {
|
||||
status int
|
||||
reason FailoverReason
|
||||
}{
|
||||
{401, FailoverAuth},
|
||||
{403, FailoverAuth},
|
||||
{402, FailoverBilling},
|
||||
{408, FailoverTimeout},
|
||||
{429, FailoverRateLimit},
|
||||
{400, FailoverFormat},
|
||||
{500, FailoverTimeout},
|
||||
{502, FailoverTimeout},
|
||||
{503, FailoverTimeout},
|
||||
{521, FailoverTimeout},
|
||||
{522, FailoverTimeout},
|
||||
{523, FailoverTimeout},
|
||||
{524, FailoverTimeout},
|
||||
{529, FailoverTimeout},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := fmt.Errorf("API error: status: %d something went wrong", tt.status)
|
||||
result := ClassifyError(err, "test", "model")
|
||||
if result == nil {
|
||||
t.Errorf("status %d: expected non-nil", tt.status)
|
||||
continue
|
||||
}
|
||||
if result.Reason != tt.reason {
|
||||
t.Errorf("status %d: reason = %q, want %q", tt.status, result.Reason, tt.reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_RateLimitPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"rate limit exceeded",
|
||||
"rate_limit reached",
|
||||
"too many requests",
|
||||
"exceeded your current quota",
|
||||
"resource has been exhausted",
|
||||
"resource_exhausted",
|
||||
"quota exceeded",
|
||||
"usage limit reached",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_OverloadedPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"overloaded_error",
|
||||
`{"type": "overloaded_error"}`,
|
||||
"server is overloaded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
// Overloaded is treated as rate_limit
|
||||
if result.Reason != FailoverRateLimit {
|
||||
t.Errorf("pattern %q: reason = %q, want rate_limit", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_BillingPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"payment required",
|
||||
"insufficient credits",
|
||||
"credit balance too low",
|
||||
"plans & billing page",
|
||||
"insufficient balance",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverBilling {
|
||||
t.Errorf("pattern %q: reason = %q, want billing", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_TimeoutPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"request timeout",
|
||||
"connection timed out",
|
||||
"deadline exceeded",
|
||||
"context deadline exceeded",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverTimeout {
|
||||
t.Errorf("pattern %q: reason = %q, want timeout", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_AuthPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"invalid api key",
|
||||
"invalid_api_key",
|
||||
"incorrect api key",
|
||||
"invalid token",
|
||||
"authentication failed",
|
||||
"re-authenticate",
|
||||
"oauth token refresh failed",
|
||||
"unauthorized access",
|
||||
"forbidden",
|
||||
"access denied",
|
||||
"expired",
|
||||
"token has expired",
|
||||
"no credentials found",
|
||||
"no api key found",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverAuth {
|
||||
t.Errorf("pattern %q: reason = %q, want auth", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_FormatPatterns(t *testing.T) {
|
||||
patterns := []string{
|
||||
"string should match pattern",
|
||||
"tool_use.id is required",
|
||||
"invalid tool_use_id",
|
||||
"messages.1.content.1.tool_use.id must be valid",
|
||||
"invalid request format",
|
||||
}
|
||||
|
||||
for _, msg := range patterns {
|
||||
err := errors.New(msg)
|
||||
result := ClassifyError(err, "anthropic", "claude")
|
||||
if result == nil {
|
||||
t.Errorf("pattern %q: expected non-nil", msg)
|
||||
continue
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("pattern %q: reason = %q, want format", msg, result.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageDimensionError(t *testing.T) {
|
||||
err := errors.New("image dimensions exceed max allowed 2048x2048")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image dimension error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
if result.IsRetriable() {
|
||||
t.Error("image dimension error should not be retriable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ImageSizeError(t *testing.T) {
|
||||
err := errors.New("image exceeds 20 mb limit")
|
||||
result := ClassifyError(err, "openai", "gpt-4o")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil for image size error")
|
||||
}
|
||||
if result.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", result.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_UnknownError(t *testing.T) {
|
||||
err := errors.New("some completely random error")
|
||||
result := ClassifyError(err, "openai", "gpt-4")
|
||||
if result != nil {
|
||||
t.Errorf("expected nil for unknown error, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_ProviderModelPropagation(t *testing.T) {
|
||||
err := errors.New("rate limit exceeded")
|
||||
result := ClassifyError(err, "my-provider", "my-model")
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if result.Provider != "my-provider" {
|
||||
t.Errorf("provider = %q, want my-provider", result.Provider)
|
||||
}
|
||||
if result.Model != "my-model" {
|
||||
t.Errorf("model = %q, want my-model", result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_IsRetriable(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason FailoverReason
|
||||
retriable bool
|
||||
}{
|
||||
{FailoverAuth, true},
|
||||
{FailoverRateLimit, true},
|
||||
{FailoverBilling, true},
|
||||
{FailoverTimeout, true},
|
||||
{FailoverOverloaded, true},
|
||||
{FailoverFormat, false},
|
||||
{FailoverUnknown, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
fe := &FailoverError{Reason: tt.reason}
|
||||
if fe.IsRetriable() != tt.retriable {
|
||||
t.Errorf("IsRetriable(%q) = %v, want %v", tt.reason, fe.IsRetriable(), tt.retriable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_ErrorString(t *testing.T) {
|
||||
fe := &FailoverError{
|
||||
Reason: FailoverRateLimit,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
Status: 429,
|
||||
Wrapped: errors.New("too many requests"),
|
||||
}
|
||||
s := fe.Error()
|
||||
if s == "" {
|
||||
t.Error("expected non-empty error string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailoverError_Unwrap(t *testing.T) {
|
||||
inner := errors.New("inner error")
|
||||
fe := &FailoverError{Reason: FailoverTimeout, Wrapped: inner}
|
||||
if fe.Unwrap() != inner {
|
||||
t.Error("Unwrap should return wrapped error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHTTPStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{"status: 429 rate limited", 429},
|
||||
{"status 401 unauthorized", 401},
|
||||
{"HTTP/1.1 502 Bad Gateway", 502},
|
||||
{"no status code here", 0},
|
||||
{"random number 12345", 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := extractHTTPStatus(tt.msg)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractHTTPStatus(%q) = %d, want %d", tt.msg, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageDimensionError(t *testing.T) {
|
||||
if !IsImageDimensionError("image dimensions exceed max 4096x4096") {
|
||||
t.Error("should match image dimensions exceed max")
|
||||
}
|
||||
if IsImageDimensionError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsImageSizeError(t *testing.T) {
|
||||
if !IsImageSizeError("image exceeds 20 mb") {
|
||||
t.Error("should match image exceeds mb")
|
||||
}
|
||||
if IsImageSizeError("normal error message") {
|
||||
t.Error("should not match normal error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const defaultAnthropicAPIBase = "https://api.anthropic.com/v1"
|
||||
|
||||
var getCredential = auth.GetCredential
|
||||
|
||||
type providerType int
|
||||
|
||||
const (
|
||||
providerTypeHTTPCompat providerType = iota
|
||||
providerTypeClaudeAuth
|
||||
providerTypeCodexAuth
|
||||
providerTypeCodexCLIToken
|
||||
providerTypeClaudeCLI
|
||||
providerTypeCodexCLI
|
||||
providerTypeGitHubCopilot
|
||||
)
|
||||
|
||||
type providerSelection struct {
|
||||
providerType providerType
|
||||
apiKey string
|
||||
apiBase string
|
||||
proxy string
|
||||
model string
|
||||
workspace string
|
||||
connectMode string
|
||||
enableWebSearch bool
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider(apiBase string) (LLMProvider, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
cred, err := getCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) {
|
||||
cred, err := getCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource())
|
||||
p.enableWebSearch = enableWebSearch
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
sel := providerSelection{
|
||||
providerType: providerTypeHTTPCompat,
|
||||
model: model,
|
||||
}
|
||||
|
||||
// First, prefer explicit provider configuration.
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
sel.apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
sel.proxy = cfg.Providers.ShengSuanYun.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "nvidia":
|
||||
if cfg.Providers.Nvidia.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claude-code", "claudecode":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeClaudeCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
sel.providerType = providerTypeCodexCLI
|
||||
sel.workspace = workspace
|
||||
return sel, nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
sel.apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
sel.proxy = cfg.Providers.DeepSeek.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
sel.model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
sel.providerType = providerTypeGitHubCopilot
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
sel.apiBase = "localhost:4321"
|
||||
}
|
||||
sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode
|
||||
return sel, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: infer provider from model and configured keys.
|
||||
if sel.apiKey == "" && sel.apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Moonshot.APIKey
|
||||
sel.apiBase = cfg.Providers.Moonshot.APIBase
|
||||
sel.proxy = cfg.Providers.Moonshot.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
case strings.HasPrefix(model, "openrouter/") ||
|
||||
strings.HasPrefix(model, "anthropic/") ||
|
||||
strings.HasPrefix(model, "openai/") ||
|
||||
strings.HasPrefix(model, "meta-llama/") ||
|
||||
strings.HasPrefix(model, "deepseek/") ||
|
||||
strings.HasPrefix(model, "google/"):
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) &&
|
||||
(cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
sel.providerType = providerTypeClaudeAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.Anthropic.APIKey
|
||||
sel.apiBase = cfg.Providers.Anthropic.APIBase
|
||||
sel.proxy = cfg.Providers.Anthropic.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = defaultAnthropicAPIBase
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) &&
|
||||
(cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
sel.providerType = providerTypeCodexCLIToken
|
||||
return sel, nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
sel.providerType = providerTypeCodexAuth
|
||||
return sel, nil
|
||||
}
|
||||
sel.apiKey = cfg.Providers.OpenAI.APIKey
|
||||
sel.apiBase = cfg.Providers.OpenAI.APIBase
|
||||
sel.proxy = cfg.Providers.OpenAI.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Gemini.APIKey
|
||||
sel.apiBase = cfg.Providers.Gemini.APIBase
|
||||
sel.proxy = cfg.Providers.Gemini.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Zhipu.APIKey
|
||||
sel.apiBase = cfg.Providers.Zhipu.APIBase
|
||||
sel.proxy = cfg.Providers.Zhipu.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Groq.APIKey
|
||||
sel.apiBase = cfg.Providers.Groq.APIBase
|
||||
sel.proxy = cfg.Providers.Groq.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Nvidia.APIKey
|
||||
sel.apiBase = cfg.Providers.Nvidia.APIBase
|
||||
sel.proxy = cfg.Providers.Nvidia.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
sel.apiKey = cfg.Providers.Ollama.APIKey
|
||||
sel.apiBase = cfg.Providers.Ollama.APIBase
|
||||
sel.proxy = cfg.Providers.Ollama.Proxy
|
||||
if sel.apiBase == "" {
|
||||
sel.apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
sel.apiKey = cfg.Providers.VLLM.APIKey
|
||||
sel.apiBase = cfg.Providers.VLLM.APIBase
|
||||
sel.proxy = cfg.Providers.VLLM.Proxy
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
sel.apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
sel.proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
sel.apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
sel.apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sel.providerType == providerTypeHTTPCompat {
|
||||
if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
if sel.apiBase == "" {
|
||||
return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
}
|
||||
|
||||
return sel, nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
sel, err := resolveProviderSelection(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch sel.providerType {
|
||||
case providerTypeClaudeAuth:
|
||||
return createClaudeAuthProvider(sel.apiBase)
|
||||
case providerTypeCodexAuth:
|
||||
return createCodexAuthProvider(sel.enableWebSearch)
|
||||
case providerTypeCodexCLIToken:
|
||||
c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource())
|
||||
c.enableWebSearch = sel.enableWebSearch
|
||||
return c, nil
|
||||
case providerTypeClaudeCLI:
|
||||
return NewClaudeCliProvider(sel.workspace), nil
|
||||
case providerTypeCodexCLI:
|
||||
return NewCodexCliProvider(sel.workspace), nil
|
||||
case providerTypeGitHubCopilot:
|
||||
return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model)
|
||||
default:
|
||||
return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestResolveProviderSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*config.Config)
|
||||
wantType providerType
|
||||
wantAPIBase string
|
||||
wantProxy string
|
||||
wantErrSubstr string
|
||||
}{
|
||||
{
|
||||
name: "explicit claude-cli provider routes to cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "claude-cli"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeClaudeCLI,
|
||||
},
|
||||
{
|
||||
name: "explicit copilot provider routes to github copilot type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "copilot"
|
||||
},
|
||||
wantType: providerTypeGitHubCopilot,
|
||||
wantAPIBase: "localhost:4321",
|
||||
},
|
||||
{
|
||||
name: "explicit deepseek provider uses deepseek defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "deepseek"
|
||||
cfg.Agents.Defaults.Model = "deepseek/deepseek-chat"
|
||||
cfg.Providers.DeepSeek.APIKey = "deepseek-key"
|
||||
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.deepseek.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit shengsuanyun provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "shengsuanyun"
|
||||
cfg.Providers.ShengSuanYun.APIKey = "ssy-key"
|
||||
cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://router.shengsuanyun.com/api/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "explicit nvidia provider uses defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "nvidia"
|
||||
cfg.Providers.Nvidia.APIKey = "nvapi-test"
|
||||
cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://integrate.api.nvidia.com/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "openrouter model uses openrouter defaults",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://openrouter.ai/api/v1",
|
||||
},
|
||||
{
|
||||
name: "anthropic oauth routes to claude auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeClaudeAuth,
|
||||
},
|
||||
{
|
||||
name: "openai oauth routes to codex auth provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
},
|
||||
wantType: providerTypeCodexAuth,
|
||||
},
|
||||
{
|
||||
name: "openai codex-cli auth routes to codex cli token provider",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "gpt-4o"
|
||||
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
|
||||
},
|
||||
wantType: providerTypeCodexCLIToken,
|
||||
},
|
||||
{
|
||||
name: "explicit codex-code provider routes to codex cli provider type",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Provider = "codex-code"
|
||||
cfg.Agents.Defaults.Workspace = "/tmp/ws"
|
||||
},
|
||||
wantType: providerTypeCodexCLI,
|
||||
},
|
||||
{
|
||||
name: "zhipu model uses zhipu base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "glm-4.7"
|
||||
cfg.Providers.Zhipu.APIKey = "zhipu-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://open.bigmodel.cn/api/paas/v4",
|
||||
},
|
||||
{
|
||||
name: "groq model uses groq base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "groq/llama-3.3-70b"
|
||||
cfg.Providers.Groq.APIKey = "gsk-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.groq.com/openai/v1",
|
||||
},
|
||||
{
|
||||
name: "ollama model uses ollama base default",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b"
|
||||
cfg.Providers.Ollama.APIKey = "ollama-key"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "http://localhost:11434/v1",
|
||||
},
|
||||
{
|
||||
name: "moonshot model keeps proxy and default base",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5"
|
||||
cfg.Providers.Moonshot.APIKey = "moonshot-key"
|
||||
cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890"
|
||||
},
|
||||
wantType: providerTypeHTTPCompat,
|
||||
wantAPIBase: "https://api.moonshot.cn/v1",
|
||||
wantProxy: "http://127.0.0.1:7890",
|
||||
},
|
||||
{
|
||||
name: "missing keys returns model config error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "custom-model"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for model",
|
||||
},
|
||||
{
|
||||
name: "openrouter prefix without key returns provider key error",
|
||||
setup: func(cfg *config.Config) {
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
},
|
||||
wantErrSubstr: "no API key configured for provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
tt.setup(cfg)
|
||||
|
||||
got, err := resolveProviderSelection(cfg)
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("resolveProviderSelection() error = %v", err)
|
||||
}
|
||||
if got.providerType != tt.wantType {
|
||||
t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType)
|
||||
}
|
||||
if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase {
|
||||
t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase)
|
||||
}
|
||||
if tt.wantProxy != "" && got.proxy != tt.wantProxy {
|
||||
t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "openrouter/auto"
|
||||
cfg.Providers.OpenRouter.APIKey = "sk-or-test"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*HTTPProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *HTTPProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "codex-code"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexCliProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexCliProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "openai"
|
||||
cfg.Providers.OpenAI.AuthMethod = "codex-cli"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "anthropic" {
|
||||
t.Fatalf("provider = %q, want anthropic", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "anthropic-token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "anthropic"
|
||||
cfg.Providers.Anthropic.AuthMethod = "oauth"
|
||||
cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
claudeProvider, ok := provider.(*ClaudeProvider)
|
||||
if !ok {
|
||||
t.Fatalf("provider type = %T, want *ClaudeProvider", provider)
|
||||
}
|
||||
if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" {
|
||||
t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) {
|
||||
originalGetCredential := getCredential
|
||||
t.Cleanup(func() { getCredential = originalGetCredential })
|
||||
|
||||
getCredential = func(provider string) (*auth.AuthCredential, error) {
|
||||
if provider != "openai" {
|
||||
t.Fatalf("provider = %q, want openai", provider)
|
||||
}
|
||||
return &auth.AuthCredential{
|
||||
AccessToken: "openai-token",
|
||||
AccountID: "acct_123",
|
||||
}, nil
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Provider = "openai"
|
||||
cfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
|
||||
provider, err := CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := provider.(*CodexProvider); !ok {
|
||||
t.Fatalf("provider type = %T, want *CodexProvider", provider)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FallbackChain orchestrates model fallback across multiple candidates.
|
||||
type FallbackChain struct {
|
||||
cooldown *CooldownTracker
|
||||
}
|
||||
|
||||
// FallbackCandidate represents one model/provider to try.
|
||||
type FallbackCandidate struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// FallbackResult contains the successful response and metadata about all attempts.
|
||||
type FallbackResult struct {
|
||||
Response *LLMResponse
|
||||
Provider string
|
||||
Model string
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
// FallbackAttempt records one attempt in the fallback chain.
|
||||
type FallbackAttempt struct {
|
||||
Provider string
|
||||
Model string
|
||||
Error error
|
||||
Reason FailoverReason
|
||||
Duration time.Duration
|
||||
Skipped bool // true if skipped due to cooldown
|
||||
}
|
||||
|
||||
// NewFallbackChain creates a new fallback chain with the given cooldown tracker.
|
||||
func NewFallbackChain(cooldown *CooldownTracker) *FallbackChain {
|
||||
return &FallbackChain{cooldown: cooldown}
|
||||
}
|
||||
|
||||
// ResolveCandidates parses model config into a deduplicated candidate list.
|
||||
func ResolveCandidates(cfg ModelConfig, defaultProvider string) []FallbackCandidate {
|
||||
seen := make(map[string]bool)
|
||||
var candidates []FallbackCandidate
|
||||
|
||||
addCandidate := func(raw string) {
|
||||
ref := ParseModelRef(raw, defaultProvider)
|
||||
if ref == nil {
|
||||
return
|
||||
}
|
||||
key := ModelKey(ref.Provider, ref.Model)
|
||||
if seen[key] {
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
candidates = append(candidates, FallbackCandidate{
|
||||
Provider: ref.Provider,
|
||||
Model: ref.Model,
|
||||
})
|
||||
}
|
||||
|
||||
// Primary first.
|
||||
addCandidate(cfg.Primary)
|
||||
|
||||
// Then fallbacks.
|
||||
for _, fb := range cfg.Fallbacks {
|
||||
addCandidate(fb)
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// Execute runs the fallback chain for text/chat requests.
|
||||
// It tries each candidate in order, respecting cooldowns and error classification.
|
||||
//
|
||||
// Behavior:
|
||||
// - Candidates in cooldown are skipped (logged as skipped attempt).
|
||||
// - context.Canceled aborts immediately (user abort, no fallback).
|
||||
// - Non-retriable errors (format) abort immediately.
|
||||
// - Retriable errors trigger fallback to next candidate.
|
||||
// - Success marks provider as good (resets cooldown).
|
||||
// - If all fail, returns aggregate error with all attempts.
|
||||
func (fc *FallbackChain) Execute(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
// Check context before each attempt.
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Check cooldown.
|
||||
if !fc.cooldown.IsAvailable(candidate.Provider) {
|
||||
remaining := fc.cooldown.CooldownRemaining(candidate.Provider)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Skipped: true,
|
||||
Reason: FailoverRateLimit,
|
||||
Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute the run function.
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
// Success.
|
||||
fc.cooldown.MarkSuccess(candidate.Provider)
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Context cancellation: abort immediately, no fallback.
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Classify the error.
|
||||
failErr := ClassifyError(err, candidate.Provider, candidate.Model)
|
||||
|
||||
if failErr == nil {
|
||||
// Unclassifiable error: do not fallback, return immediately.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, fmt.Errorf("fallback: unclassified error from %s/%s: %w",
|
||||
candidate.Provider, candidate.Model, err)
|
||||
}
|
||||
|
||||
// Non-retriable error: abort immediately.
|
||||
if !failErr.IsRetriable() {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, failErr
|
||||
}
|
||||
|
||||
// Retriable error: mark failure and continue to next candidate.
|
||||
fc.cooldown.MarkFailure(candidate.Provider, failErr.Reason)
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: failErr,
|
||||
Reason: failErr.Reason,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
// If this was the last candidate, return aggregate error.
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
// All candidates were skipped (all in cooldown).
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// ExecuteImage runs the fallback chain for image/vision requests.
|
||||
// Simpler than Execute: no cooldown checks (image endpoints have different rate limits).
|
||||
// Image dimension/size errors abort immediately (non-retriable).
|
||||
func (fc *FallbackChain) ExecuteImage(
|
||||
ctx context.Context,
|
||||
candidates []FallbackCandidate,
|
||||
run func(ctx context.Context, provider, model string) (*LLMResponse, error),
|
||||
) (*FallbackResult, error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, fmt.Errorf("image fallback: no candidates configured")
|
||||
}
|
||||
|
||||
result := &FallbackResult{
|
||||
Attempts: make([]FallbackAttempt, 0, len(candidates)),
|
||||
}
|
||||
|
||||
for i, candidate := range candidates {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := run(ctx, candidate.Provider, candidate.Model)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
result.Response = resp
|
||||
result.Provider = candidate.Provider
|
||||
result.Model = candidate.Model
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
// Image dimension/size errors are non-retriable.
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if IsImageDimensionError(errMsg) || IsImageSizeError(errMsg) {
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Reason: FailoverFormat,
|
||||
Duration: elapsed,
|
||||
})
|
||||
return nil, &FailoverError{
|
||||
Reason: FailoverFormat,
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Wrapped: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Any other error: record and try next.
|
||||
result.Attempts = append(result.Attempts, FallbackAttempt{
|
||||
Provider: candidate.Provider,
|
||||
Model: candidate.Model,
|
||||
Error: err,
|
||||
Duration: elapsed,
|
||||
})
|
||||
|
||||
if i == len(candidates)-1 {
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, &FallbackExhaustedError{Attempts: result.Attempts}
|
||||
}
|
||||
|
||||
// FallbackExhaustedError indicates all fallback candidates were tried and failed.
|
||||
type FallbackExhaustedError struct {
|
||||
Attempts []FallbackAttempt
|
||||
}
|
||||
|
||||
func (e *FallbackExhaustedError) Error() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("fallback: all %d candidates failed:", len(e.Attempts)))
|
||||
for i, a := range e.Attempts {
|
||||
if a.Skipped {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: skipped (cooldown)", i+1, a.Provider, a.Model))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf("\n [%d] %s/%s: %v (reason=%s, %s)",
|
||||
i+1, a.Provider, a.Model, a.Error, a.Reason, a.Duration.Round(time.Millisecond)))
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func makeCandidate(provider, model string) FallbackCandidate {
|
||||
return FallbackCandidate{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
func successRun(content string) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return &LLMResponse{Content: content, FinishReason: "stop"}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SingleCandidate_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "hello" {
|
||||
t.Errorf("content = %q, want hello", result.Response.Content)
|
||||
}
|
||||
if result.Provider != "openai" || result.Model != "gpt-4" {
|
||||
t.Errorf("provider/model = %s/%s, want openai/gpt-4", result.Provider, result.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SecondCandidateSuccess(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude-opus"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "from claude", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
if result.Response.Content != "from claude" {
|
||||
t.Errorf("content = %q, want 'from claude'", result.Response.Content)
|
||||
}
|
||||
if len(result.Attempts) != 1 {
|
||||
t.Errorf("attempts = %d, want 1 (failed attempt recorded)", len(result.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllFail(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
makeCandidate("groq", "llama"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all candidates fail")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Errorf("expected FallbackExhaustedError, got %T: %v", err, err)
|
||||
}
|
||||
if len(exhausted.Attempts) != 3 {
|
||||
t.Errorf("attempts = %d, want 3", len(exhausted.Attempts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_ContextCanceled(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
cancel() // cancel context
|
||||
return nil, context.Canceled
|
||||
}
|
||||
t.Error("should not reach second candidate after cancel")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(ctx, candidates, run)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NonRetriableError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("string should match pattern")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-retriable")
|
||||
}
|
||||
var fe *FailoverError
|
||||
if !errors.As(err, &fe) {
|
||||
t.Fatalf("expected FailoverError, got %T", err)
|
||||
}
|
||||
if fe.Reason != FailoverFormat {
|
||||
t.Errorf("reason = %q, want format", fe.Reason)
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (non-retriable should not try next)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_CooldownSkip(t *testing.T) {
|
||||
now := time.Now()
|
||||
ct, _ := newTestTracker(now)
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put openai in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
if provider == "openai" {
|
||||
t.Error("should not call openai (in cooldown)")
|
||||
}
|
||||
return &LLMResponse{Content: "claude response", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
// Should have 1 skipped attempt
|
||||
skipped := 0
|
||||
for _, a := range result.Attempts {
|
||||
if a.Skipped {
|
||||
skipped++
|
||||
}
|
||||
}
|
||||
if skipped != 1 {
|
||||
t.Errorf("skipped = %d, want 1", skipped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_AllInCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
// Put all providers in cooldown
|
||||
ct.MarkFailure("openai", FailoverRateLimit)
|
||||
ct.MarkFailure("anthropic", FailoverBilling)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates,
|
||||
func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
t.Error("should not call any provider (all in cooldown)")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error when all in cooldown")
|
||||
}
|
||||
var exhausted *FallbackExhaustedError
|
||||
if !errors.As(err, &exhausted) {
|
||||
t.Fatalf("expected FallbackExhaustedError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.Execute(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_EmptyFallbacks(t *testing.T) {
|
||||
// Single primary, no fallbacks: should work like direct call
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
result, err := fc.Execute(context.Background(), candidates, successRun("ok"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "ok" {
|
||||
t.Error("expected success with single candidate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_UnclassifiedError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("completely unknown internal error")
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unclassified error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (should not fallback on unclassified)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallback_SuccessResetsCooldown(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4")}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
ct.MarkFailure("openai", FailoverRateLimit) // simulate failure tracked elsewhere
|
||||
}
|
||||
return &LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
_, err := fc.Execute(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ct.IsAvailable("openai") {
|
||||
t.Error("success should reset cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Image Fallback Tests ---
|
||||
|
||||
func TestImageFallback_Success(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{makeCandidate("openai", "gpt-4o")}
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, successRun("image result"))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Response.Content != "image result" {
|
||||
t.Error("expected image result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_DimensionError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image dimensions exceed max 4096x4096")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image dimension error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image dimension error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_SizeError(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
return nil, errors.New("image exceeds 20 mb")
|
||||
}
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for image size error")
|
||||
}
|
||||
if attempt != 1 {
|
||||
t.Errorf("attempt = %d, want 1 (image size error should not retry)", attempt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_RetryOnOtherErrors(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
candidates := []FallbackCandidate{
|
||||
makeCandidate("openai", "gpt-4o"),
|
||||
makeCandidate("anthropic", "claude-sonnet"),
|
||||
}
|
||||
|
||||
attempt := 0
|
||||
run := func(ctx context.Context, provider, model string) (*LLMResponse, error) {
|
||||
attempt++
|
||||
if attempt == 1 {
|
||||
return nil, errors.New("rate limit exceeded")
|
||||
}
|
||||
return &LLMResponse{Content: "image ok", FinishReason: "stop"}, nil
|
||||
}
|
||||
|
||||
result, err := fc.ExecuteImage(context.Background(), candidates, run)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", result.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageFallback_NoCandidates(t *testing.T) {
|
||||
ct := NewCooldownTracker()
|
||||
fc := NewFallbackChain(ct)
|
||||
|
||||
_, err := fc.ExecuteImage(context.Background(), nil, successRun("ok"))
|
||||
if err == nil {
|
||||
t.Error("expected error for empty candidates")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResolveCandidates Tests ---
|
||||
|
||||
func TestResolveCandidates_Simple(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: []string{"anthropic/claude-opus", "groq/llama-3"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 3 {
|
||||
t.Fatalf("candidates = %d, want 3", len(candidates))
|
||||
}
|
||||
|
||||
if candidates[0].Provider != "openai" || candidates[0].Model != "gpt-4" {
|
||||
t.Errorf("candidate[0] = %s/%s, want openai/gpt-4", candidates[0].Provider, candidates[0].Model)
|
||||
}
|
||||
if candidates[1].Provider != "anthropic" || candidates[1].Model != "claude-opus" {
|
||||
t.Errorf("candidate[1] = %s/%s, want anthropic/claude-opus", candidates[1].Provider, candidates[1].Model)
|
||||
}
|
||||
if candidates[2].Provider != "groq" || candidates[2].Model != "llama-3" {
|
||||
t.Errorf("candidate[2] = %s/%s, want groq/llama-3", candidates[2].Provider, candidates[2].Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_Deduplication(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "openai/gpt-4",
|
||||
Fallbacks: []string{"openai/gpt-4", "anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "default")
|
||||
if len(candidates) != 2 {
|
||||
t.Errorf("candidates = %d, want 2 (duplicate removed)", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyFallbacks(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "gpt-4",
|
||||
Fallbacks: nil,
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCandidates_EmptyPrimary(t *testing.T) {
|
||||
cfg := ModelConfig{
|
||||
Primary: "",
|
||||
Fallbacks: []string{"anthropic/claude"},
|
||||
}
|
||||
|
||||
candidates := ResolveCandidates(cfg, "openai")
|
||||
if len(candidates) != 1 {
|
||||
t.Errorf("candidates = %d, want 1", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackExhaustedError_Message(t *testing.T) {
|
||||
e := &FallbackExhaustedError{
|
||||
Attempts: []FallbackAttempt{
|
||||
{Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
|
||||
{Provider: "anthropic", Model: "claude", Skipped: true},
|
||||
},
|
||||
}
|
||||
msg := e.Error()
|
||||
if msg == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
@@ -7,444 +7,25 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/openai_compat"
|
||||
)
|
||||
|
||||
type HTTPProvider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
delegate *openai_compat.Provider
|
||||
}
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
proxyURL, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
delegate: openai_compat.NewProvider(apiKey, apiBase, proxy),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
|
||||
if idx := strings.Index(model, "/"); idx != -1 {
|
||||
prefix := model[:idx]
|
||||
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
} else {
|
||||
requestBody["max_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if temperature, ok := options["temperature"].(float64); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
requestBody["temperature"] = 1.0
|
||||
} else {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return p.parseResponse(body)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
|
||||
// Handle OpenAI format with nested function object
|
||||
if tc.Type == "function" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else if tc.Function != nil {
|
||||
// Legacy format without type field
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
return p.delegate.Chat(ctx, messages, tools, model, options)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||
|
||||
var apiKey, apiBase, proxy string
|
||||
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// First, try to use explicitly configured provider
|
||||
if providerName != "" {
|
||||
switch providerName {
|
||||
case "groq":
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
}
|
||||
case "anthropic", "claude":
|
||||
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
}
|
||||
case "openrouter":
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
}
|
||||
case "zhipu", "glm":
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
}
|
||||
case "gemini", "google":
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
}
|
||||
case "vllm":
|
||||
if cfg.Providers.VLLM.APIBase != "" {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
if apiKey == "" && apiBase == "" {
|
||||
switch {
|
||||
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
|
||||
apiKey = cfg.Providers.Moonshot.APIKey
|
||||
apiBase = cfg.Providers.Moonshot.APIBase
|
||||
proxy = cfg.Providers.Moonshot.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.moonshot.cn/v1"
|
||||
}
|
||||
|
||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
proxy = cfg.Providers.Anthropic.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
proxy = cfg.Providers.OpenAI.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||
apiKey = cfg.Providers.Gemini.APIKey
|
||||
apiBase = cfg.Providers.Gemini.APIBase
|
||||
proxy = cfg.Providers.Gemini.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||
apiKey = cfg.Providers.Zhipu.APIKey
|
||||
apiBase = cfg.Providers.Zhipu.APIBase
|
||||
proxy = cfg.Providers.Zhipu.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||
apiKey = cfg.Providers.Groq.APIKey
|
||||
apiBase = cfg.Providers.Groq.APIBase
|
||||
proxy = cfg.Providers.Groq.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.groq.com/openai/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
|
||||
apiKey = cfg.Providers.Nvidia.APIKey
|
||||
apiBase = cfg.Providers.Nvidia.APIBase
|
||||
proxy = cfg.Providers.Nvidia.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
proxy = cfg.Providers.VLLM.Proxy
|
||||
|
||||
default:
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||
proxy = cfg.Providers.OpenRouter.Proxy
|
||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||
} else {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("no API key configured for model: %s", model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
if apiBase == "" {
|
||||
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
|
||||
}
|
||||
|
||||
return NewHTTPProvider(apiKey, apiBase, proxy), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package providers
|
||||
|
||||
import "strings"
|
||||
|
||||
// ModelRef represents a parsed model reference with provider and model name.
|
||||
type ModelRef struct {
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// ParseModelRef parses "anthropic/claude-opus" into {Provider: "anthropic", Model: "claude-opus"}.
|
||||
// If no slash present, uses defaultProvider.
|
||||
// Returns nil for empty input.
|
||||
func ParseModelRef(raw string, defaultProvider string) *ModelRef {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if idx := strings.Index(raw, "/"); idx > 0 {
|
||||
provider := NormalizeProvider(raw[:idx])
|
||||
model := strings.TrimSpace(raw[idx+1:])
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
return &ModelRef{Provider: provider, Model: model}
|
||||
}
|
||||
|
||||
return &ModelRef{
|
||||
Provider: NormalizeProvider(defaultProvider),
|
||||
Model: raw,
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeProvider normalizes provider identifiers to canonical form.
|
||||
func NormalizeProvider(provider string) string {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
|
||||
switch p {
|
||||
case "z.ai", "z-ai":
|
||||
return "zai"
|
||||
case "opencode-zen":
|
||||
return "opencode"
|
||||
case "qwen":
|
||||
return "qwen-portal"
|
||||
case "kimi-code":
|
||||
return "kimi-coding"
|
||||
case "gpt":
|
||||
return "openai"
|
||||
case "claude":
|
||||
return "anthropic"
|
||||
case "glm":
|
||||
return "zhipu"
|
||||
case "google":
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ModelKey returns a canonical "provider/model" key for deduplication.
|
||||
func ModelKey(provider, model string) string {
|
||||
return NormalizeProvider(provider) + "/" + strings.ToLower(strings.TrimSpace(model))
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelRef_WithSlash(t *testing.T) {
|
||||
ref := ParseModelRef("anthropic/claude-opus", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WithoutSlash(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai", ref.Provider)
|
||||
}
|
||||
if ref.Model != "gpt-4" {
|
||||
t.Errorf("model = %q, want gpt-4", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_Empty(t *testing.T) {
|
||||
ref := ParseModelRef("", "openai")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty string, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_EmptyModelAfterSlash(t *testing.T) {
|
||||
ref := ParseModelRef("openai/", "default")
|
||||
if ref != nil {
|
||||
t.Errorf("expected nil for empty model, got %+v", ref)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_WhitespaceHandling(t *testing.T) {
|
||||
ref := ParseModelRef(" anthropic / claude-opus ", "openai")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "anthropic" {
|
||||
t.Errorf("provider = %q, want anthropic", ref.Provider)
|
||||
}
|
||||
if ref.Model != "claude-opus" {
|
||||
t.Errorf("model = %q, want claude-opus", ref.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"OpenAI", "openai"},
|
||||
{"ANTHROPIC", "anthropic"},
|
||||
{"z.ai", "zai"},
|
||||
{"z-ai", "zai"},
|
||||
{"Z.AI", "zai"},
|
||||
{"opencode-zen", "opencode"},
|
||||
{"qwen", "qwen-portal"},
|
||||
{"kimi-code", "kimi-coding"},
|
||||
{"gpt", "openai"},
|
||||
{"claude", "anthropic"},
|
||||
{"glm", "zhipu"},
|
||||
{"google", "gemini"},
|
||||
{"groq", "groq"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := NormalizeProvider(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeProvider(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
provider string
|
||||
model string
|
||||
want string
|
||||
}{
|
||||
{"openai", "gpt-4", "openai/gpt-4"},
|
||||
{"Anthropic", "Claude-Opus", "anthropic/claude-opus"},
|
||||
{"claude", "sonnet", "anthropic/sonnet"},
|
||||
{"z.ai", "Model-X", "zai/model-x"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := ModelKey(tt.provider, tt.model)
|
||||
if got != tt.want {
|
||||
t.Errorf("ModelKey(%q, %q) = %q, want %q", tt.provider, tt.model, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_ProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("Z.AI/model-x", "default")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "zai" {
|
||||
t.Errorf("provider = %q, want zai", ref.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseModelRef_DefaultProviderNormalization(t *testing.T) {
|
||||
ref := ParseModelRef("gpt-4o", "GPT")
|
||||
if ref == nil {
|
||||
t.Fatal("expected non-nil ref")
|
||||
}
|
||||
if ref.Provider != "openai" {
|
||||
t.Errorf("provider = %q, want openai (normalized from GPT)", ref.Provider)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
model = normalizeModel(model, p.apiBase)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
} else {
|
||||
requestBody["max_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
requestBody["temperature"] = 1.0
|
||||
} else {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return parseResponse(body)
|
||||
}
|
||||
|
||||
func parseResponse(body []byte) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]interface{})
|
||||
name := ""
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
if tc.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
})
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
idx := strings.Index(model, "/")
|
||||
if idx == -1 {
|
||||
return model
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") {
|
||||
return model
|
||||
}
|
||||
|
||||
prefix := strings.ToLower(model[:idx])
|
||||
switch prefix {
|
||||
case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
|
||||
return model[idx+1:]
|
||||
default:
|
||||
return model
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v interface{}) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v interface{}) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, ok := requestBody["max_completion_tokens"]; !ok {
|
||||
t.Fatalf("expected max_completion_tokens in request body")
|
||||
}
|
||||
if _, ok := requestBody["max_tokens"]; ok {
|
||||
t.Fatalf("did not expect max_tokens key for glm model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_ParsesToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]interface{}{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"SF\"}",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_HTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"moonshot/kimi-k2.5",
|
||||
map[string]interface{}{"temperature": 0.3},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["model"] != "kimi-k2.5" {
|
||||
t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
|
||||
}
|
||||
if requestBody["temperature"] != 1.0 {
|
||||
t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "strips groq prefix and keeps nested model",
|
||||
input: "groq/openai/gpt-oss-120b",
|
||||
wantModel: "openai/gpt-oss-120b",
|
||||
},
|
||||
{
|
||||
name: "strips ollama prefix",
|
||||
input: "ollama/qwen2.5:14b",
|
||||
wantModel: "qwen2.5:14b",
|
||||
},
|
||||
{
|
||||
name: "strips deepseek prefix",
|
||||
input: "deepseek/deepseek-chat",
|
||||
wantModel: "deepseek-chat",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["model"] != tt.wantModel {
|
||||
t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_ProxyConfigured(t *testing.T) {
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
p := NewProvider("key", "https://example.com", proxyURL)
|
||||
|
||||
transport, ok := p.httpClient.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
|
||||
}
|
||||
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function returned error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != proxyURL {
|
||||
t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"message": map[string]interface{}{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if requestBody["max_tokens"] != float64(512) {
|
||||
t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
|
||||
}
|
||||
if requestBody["temperature"] != float64(1) {
|
||||
t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
|
||||
if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
|
||||
t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
|
||||
}
|
||||
if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
|
||||
t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package protocoltypes
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
+51
-39
@@ -1,52 +1,64 @@
|
||||
package providers
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function *FunctionCall `json:"function,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Usage *UsageInfo `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
type ToolCall = protocoltypes.ToolCall
|
||||
type FunctionCall = protocoltypes.FunctionCall
|
||||
type LLMResponse = protocoltypes.LLMResponse
|
||||
type UsageInfo = protocoltypes.UsageInfo
|
||||
type Message = protocoltypes.Message
|
||||
type ToolDefinition = protocoltypes.ToolDefinition
|
||||
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
|
||||
type LLMProvider interface {
|
||||
Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
|
||||
GetDefaultModel() string
|
||||
}
|
||||
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDefinition `json:"function"`
|
||||
// FailoverReason classifies why an LLM request failed for fallback decisions.
|
||||
type FailoverReason string
|
||||
|
||||
const (
|
||||
FailoverAuth FailoverReason = "auth"
|
||||
FailoverRateLimit FailoverReason = "rate_limit"
|
||||
FailoverBilling FailoverReason = "billing"
|
||||
FailoverTimeout FailoverReason = "timeout"
|
||||
FailoverFormat FailoverReason = "format"
|
||||
FailoverOverloaded FailoverReason = "overloaded"
|
||||
FailoverUnknown FailoverReason = "unknown"
|
||||
)
|
||||
|
||||
// FailoverError wraps an LLM provider error with classification metadata.
|
||||
type FailoverError struct {
|
||||
Reason FailoverReason
|
||||
Provider string
|
||||
Model string
|
||||
Status int
|
||||
Wrapped error
|
||||
}
|
||||
|
||||
type ToolFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
func (e *FailoverError) Error() string {
|
||||
return fmt.Sprintf("failover(%s): provider=%s model=%s status=%d: %v",
|
||||
e.Reason, e.Provider, e.Model, e.Status, e.Wrapped)
|
||||
}
|
||||
|
||||
func (e *FailoverError) Unwrap() error {
|
||||
return e.Wrapped
|
||||
}
|
||||
|
||||
// IsRetriable returns true if this error should trigger fallback to next candidate.
|
||||
// Non-retriable: Format errors (bad request structure, image dimension/size).
|
||||
func (e *FailoverError) IsRetriable() bool {
|
||||
return e.Reason != FailoverFormat
|
||||
}
|
||||
|
||||
// ModelConfig holds primary model and fallback list.
|
||||
type ModelConfig struct {
|
||||
Primary string
|
||||
Fallbacks []string
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAgentID = "main"
|
||||
DefaultMainKey = "main"
|
||||
DefaultAccountID = "default"
|
||||
MaxAgentIDLength = 64
|
||||
)
|
||||
|
||||
var (
|
||||
validIDRe = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{0,63}$`)
|
||||
invalidCharsRe = regexp.MustCompile(`[^a-z0-9_-]+`)
|
||||
leadingDashRe = regexp.MustCompile(`^-+`)
|
||||
trailingDashRe = regexp.MustCompile(`-+$`)
|
||||
)
|
||||
|
||||
// NormalizeAgentID sanitizes an agent ID to [a-z0-9][a-z0-9_-]{0,63}.
|
||||
// Invalid characters are collapsed to "-". Leading/trailing dashes stripped.
|
||||
// Empty input returns DefaultAgentID ("main").
|
||||
func NormalizeAgentID(id string) string {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" {
|
||||
return DefaultAgentID
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
if validIDRe.MatchString(lower) {
|
||||
return lower
|
||||
}
|
||||
result := invalidCharsRe.ReplaceAllString(lower, "-")
|
||||
result = leadingDashRe.ReplaceAllString(result, "")
|
||||
result = trailingDashRe.ReplaceAllString(result, "")
|
||||
if len(result) > MaxAgentIDLength {
|
||||
result = result[:MaxAgentIDLength]
|
||||
}
|
||||
if result == "" {
|
||||
return DefaultAgentID
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// NormalizeAccountID sanitizes an account ID. Empty returns DefaultAccountID.
|
||||
func NormalizeAccountID(id string) string {
|
||||
trimmed := strings.TrimSpace(id)
|
||||
if trimmed == "" {
|
||||
return DefaultAccountID
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
if validIDRe.MatchString(lower) {
|
||||
return lower
|
||||
}
|
||||
result := invalidCharsRe.ReplaceAllString(lower, "-")
|
||||
result = leadingDashRe.ReplaceAllString(result, "")
|
||||
result = trailingDashRe.ReplaceAllString(result, "")
|
||||
if len(result) > MaxAgentIDLength {
|
||||
result = result[:MaxAgentIDLength]
|
||||
}
|
||||
if result == "" {
|
||||
return DefaultAccountID
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAgentID_Empty(t *testing.T) {
|
||||
if got := NormalizeAgentID(""); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID('') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_Whitespace(t *testing.T) {
|
||||
if got := NormalizeAgentID(" "); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID(' ') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_Valid(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"main", "main"},
|
||||
{"Main", "main"},
|
||||
{"SALES", "sales"},
|
||||
{"support-bot", "support-bot"},
|
||||
{"agent_1", "agent_1"},
|
||||
{"a", "a"},
|
||||
{"0test", "0test"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := NormalizeAgentID(tt.input); got != tt.want {
|
||||
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_InvalidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
input, want string
|
||||
}{
|
||||
{"Hello World", "hello-world"},
|
||||
{"agent@123", "agent-123"},
|
||||
{"foo.bar.baz", "foo-bar-baz"},
|
||||
{"--leading", "leading"},
|
||||
{"--both--", "both"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := NormalizeAgentID(tt.input); got != tt.want {
|
||||
t.Errorf("NormalizeAgentID(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_AllInvalid(t *testing.T) {
|
||||
if got := NormalizeAgentID("@@@"); got != DefaultAgentID {
|
||||
t.Errorf("NormalizeAgentID('@@@') = %q, want %q", got, DefaultAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAgentID_TruncatesAt64(t *testing.T) {
|
||||
long := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
long += "a"
|
||||
}
|
||||
got := NormalizeAgentID(long)
|
||||
if len(got) > MaxAgentIDLength {
|
||||
t.Errorf("length = %d, want <= %d", len(got), MaxAgentIDLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_Empty(t *testing.T) {
|
||||
if got := NormalizeAccountID(""); got != DefaultAccountID {
|
||||
t.Errorf("NormalizeAccountID('') = %q, want %q", got, DefaultAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_Valid(t *testing.T) {
|
||||
if got := NormalizeAccountID("MyBot"); got != "mybot" {
|
||||
t.Errorf("NormalizeAccountID('MyBot') = %q, want 'mybot'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAccountID_InvalidChars(t *testing.T) {
|
||||
if got := NormalizeAccountID("bot@home"); got != "bot-home" {
|
||||
t.Errorf("NormalizeAccountID('bot@home') = %q, want 'bot-home'", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// RouteInput contains the routing context from an inbound message.
|
||||
type RouteInput struct {
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
ParentPeer *RoutePeer
|
||||
GuildID string
|
||||
TeamID string
|
||||
}
|
||||
|
||||
// ResolvedRoute is the result of agent routing.
|
||||
type ResolvedRoute struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
SessionKey string
|
||||
MainSessionKey string
|
||||
MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default"
|
||||
}
|
||||
|
||||
// RouteResolver determines which agent handles a message based on config bindings.
|
||||
type RouteResolver struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRouteResolver creates a new route resolver.
|
||||
func NewRouteResolver(cfg *config.Config) *RouteResolver {
|
||||
return &RouteResolver{cfg: cfg}
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message and constructs session keys.
|
||||
// Implements the 7-level priority cascade:
|
||||
// peer > parent_peer > guild > team > account > channel_wildcard > default
|
||||
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(input.Channel))
|
||||
accountID := NormalizeAccountID(input.AccountID)
|
||||
peer := input.Peer
|
||||
|
||||
dmScope := DMScope(r.cfg.Session.DMScope)
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
identityLinks := r.cfg.Session.IdentityLinks
|
||||
|
||||
bindings := r.filterBindings(channel, accountID)
|
||||
|
||||
choose := func(agentID string, matchedBy string) ResolvedRoute {
|
||||
resolvedAgentID := r.pickAgentID(agentID)
|
||||
sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: resolvedAgentID,
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
Peer: peer,
|
||||
DMScope: dmScope,
|
||||
IdentityLinks: identityLinks,
|
||||
}))
|
||||
mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID))
|
||||
return ResolvedRoute{
|
||||
AgentID: resolvedAgentID,
|
||||
Channel: channel,
|
||||
AccountID: accountID,
|
||||
SessionKey: sessionKey,
|
||||
MainSessionKey: mainSessionKey,
|
||||
MatchedBy: matchedBy,
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 1: Peer binding
|
||||
if peer != nil && strings.TrimSpace(peer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, peer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Parent peer binding
|
||||
parentPeer := input.ParentPeer
|
||||
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer.parent")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Guild binding
|
||||
guildID := strings.TrimSpace(input.GuildID)
|
||||
if guildID != "" {
|
||||
if match := r.findGuildMatch(bindings, guildID); match != nil {
|
||||
return choose(match.AgentID, "binding.guild")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: Team binding
|
||||
teamID := strings.TrimSpace(input.TeamID)
|
||||
if teamID != "" {
|
||||
if match := r.findTeamMatch(bindings, teamID); match != nil {
|
||||
return choose(match.AgentID, "binding.team")
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 5: Account binding
|
||||
if match := r.findAccountMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.account")
|
||||
}
|
||||
|
||||
// Priority 6: Channel wildcard binding
|
||||
if match := r.findChannelWildcardMatch(bindings); match != nil {
|
||||
return choose(match.AgentID, "binding.channel")
|
||||
}
|
||||
|
||||
// Priority 7: Default agent
|
||||
return choose(r.resolveDefaultAgentID(), "default")
|
||||
}
|
||||
|
||||
func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding {
|
||||
var filtered []config.AgentBinding
|
||||
for _, b := range r.cfg.Bindings {
|
||||
matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel))
|
||||
if matchChannel == "" || matchChannel != channel {
|
||||
continue
|
||||
}
|
||||
if !matchesAccountID(b.Match.AccountID, accountID) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, b)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func matchesAccountID(matchAccountID, actual string) bool {
|
||||
trimmed := strings.TrimSpace(matchAccountID)
|
||||
if trimmed == "" {
|
||||
return actual == DefaultAccountID
|
||||
}
|
||||
if trimmed == "*" {
|
||||
return true
|
||||
}
|
||||
return strings.ToLower(trimmed) == strings.ToLower(actual)
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
if b.Match.Peer == nil {
|
||||
continue
|
||||
}
|
||||
peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind))
|
||||
peerID := strings.TrimSpace(b.Match.Peer.ID)
|
||||
if peerKind == "" || peerID == "" {
|
||||
continue
|
||||
}
|
||||
if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchGuild := strings.TrimSpace(b.Match.GuildID)
|
||||
if matchGuild != "" && matchGuild == guildID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
matchTeam := strings.TrimSpace(b.Match.TeamID)
|
||||
if matchTeam != "" && matchTeam == teamID {
|
||||
return &bindings[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID == "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding {
|
||||
for i := range bindings {
|
||||
b := &bindings[i]
|
||||
accountID := strings.TrimSpace(b.Match.AccountID)
|
||||
if accountID != "*" {
|
||||
continue
|
||||
}
|
||||
if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" {
|
||||
continue
|
||||
}
|
||||
return &bindings[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteResolver) pickAgentID(agentID string) string {
|
||||
trimmed := strings.TrimSpace(agentID)
|
||||
if trimmed == "" {
|
||||
return NormalizeAgentID(r.resolveDefaultAgentID())
|
||||
}
|
||||
normalized := NormalizeAgentID(trimmed)
|
||||
agents := r.cfg.Agents.List
|
||||
if len(agents) == 0 {
|
||||
return normalized
|
||||
}
|
||||
for _, a := range agents {
|
||||
if NormalizeAgentID(a.ID) == normalized {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return NormalizeAgentID(r.resolveDefaultAgentID())
|
||||
}
|
||||
|
||||
func (r *RouteResolver) resolveDefaultAgentID() string {
|
||||
agents := r.cfg.Agents.List
|
||||
if len(agents) == 0 {
|
||||
return DefaultAgentID
|
||||
}
|
||||
for _, a := range agents {
|
||||
if a.Default {
|
||||
id := strings.TrimSpace(a.ID)
|
||||
if id != "" {
|
||||
return NormalizeAgentID(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id := strings.TrimSpace(agents[0].ID); id != "" {
|
||||
return NormalizeAgentID(id)
|
||||
}
|
||||
return DefaultAgentID
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config {
|
||||
return &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: "/tmp/picoclaw-test",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
List: agents,
|
||||
},
|
||||
Bindings: bindings,
|
||||
Session: config.SessionConfig{
|
||||
DMScope: "per-peer",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
|
||||
cfg := testConfig(nil, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != DefaultAgentID {
|
||||
t.Errorf("AgentID = %q, want %q", route.AgentID, DefaultAgentID)
|
||||
}
|
||||
if route.MatchedBy != "default" {
|
||||
t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PeerBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "sales", Default: true},
|
||||
{ID: "support"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "support",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user123"},
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
})
|
||||
|
||||
if route.AgentID != "support" {
|
||||
t.Errorf("AgentID = %q, want 'support'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_GuildBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-abc",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-abc",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "gaming" {
|
||||
t.Errorf("AgentID = %q, want 'gaming'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.guild" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_TeamBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "work"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "work",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "slack",
|
||||
AccountID: "*",
|
||||
TeamID: "T12345",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "slack",
|
||||
TeamID: "T12345",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
|
||||
})
|
||||
|
||||
if route.AgentID != "work" {
|
||||
t.Errorf("AgentID = %q, want 'work'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.team" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_AccountBinding(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "default-agent", Default: true},
|
||||
{ID: "premium"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "premium",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "premium" {
|
||||
t.Errorf("AgentID = %q, want 'premium'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.account" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_ChannelWildcard(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
{ID: "telegram-bot"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "telegram-bot",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if route.AgentID != "telegram-bot" {
|
||||
t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.channel" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "general", Default: true},
|
||||
{ID: "vip"},
|
||||
{ID: "gaming"},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "vip",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"},
|
||||
},
|
||||
},
|
||||
{
|
||||
AgentID: "gaming",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "discord",
|
||||
AccountID: "*",
|
||||
GuildID: "guild-1",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
|
||||
})
|
||||
|
||||
if route.AgentID != "vip" {
|
||||
t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID)
|
||||
}
|
||||
if route.MatchedBy != "binding.peer" {
|
||||
t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "main", Default: true},
|
||||
}
|
||||
bindings := []config.AgentBinding{
|
||||
{
|
||||
AgentID: "nonexistent",
|
||||
Match: config.BindingMatch{
|
||||
Channel: "telegram",
|
||||
AccountID: "*",
|
||||
},
|
||||
},
|
||||
}
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
})
|
||||
|
||||
if route.AgentID != "main" {
|
||||
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta", Default: true},
|
||||
{ID: "gamma"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
|
||||
if route.AgentID != "beta" {
|
||||
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
|
||||
agents := []config.AgentConfig{
|
||||
{ID: "alpha"},
|
||||
{ID: "beta"},
|
||||
}
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
|
||||
if route.AgentID != "alpha" {
|
||||
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DMScope controls DM session isolation granularity.
|
||||
type DMScope string
|
||||
|
||||
const (
|
||||
DMScopeMain DMScope = "main"
|
||||
DMScopePerPeer DMScope = "per-peer"
|
||||
DMScopePerChannelPeer DMScope = "per-channel-peer"
|
||||
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
|
||||
)
|
||||
|
||||
// RoutePeer represents a chat peer with kind and ID.
|
||||
type RoutePeer struct {
|
||||
Kind string // "direct", "group", "channel"
|
||||
ID string
|
||||
}
|
||||
|
||||
// SessionKeyParams holds all inputs for session key construction.
|
||||
type SessionKeyParams struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
DMScope DMScope
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
// ParsedSessionKey is the result of parsing an agent-scoped session key.
|
||||
type ParsedSessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
|
||||
func BuildAgentMainSessionKey(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
|
||||
}
|
||||
|
||||
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
|
||||
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
|
||||
agentID := NormalizeAgentID(params.AgentID)
|
||||
|
||||
peer := params.Peer
|
||||
if peer == nil {
|
||||
peer = &RoutePeer{Kind: "direct"}
|
||||
}
|
||||
peerKind := strings.TrimSpace(peer.Kind)
|
||||
if peerKind == "" {
|
||||
peerKind = "direct"
|
||||
}
|
||||
|
||||
if peerKind == "direct" {
|
||||
dmScope := params.DMScope
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
peerID := strings.TrimSpace(peer.ID)
|
||||
|
||||
// Resolve identity links (cross-platform collapse)
|
||||
if dmScope != DMScopeMain && peerID != "" {
|
||||
if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" {
|
||||
peerID = linked
|
||||
}
|
||||
}
|
||||
peerID = strings.ToLower(peerID)
|
||||
|
||||
switch dmScope {
|
||||
case DMScopePerAccountChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
accountID := NormalizeAccountID(params.AccountID)
|
||||
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
|
||||
}
|
||||
case DMScopePerChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
|
||||
}
|
||||
case DMScopePerPeer:
|
||||
if peerID != "" {
|
||||
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
|
||||
}
|
||||
}
|
||||
return BuildAgentMainSessionKey(agentID)
|
||||
}
|
||||
|
||||
// Group/channel peers always get per-peer sessions
|
||||
channel := normalizeChannel(params.Channel)
|
||||
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
|
||||
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 {
|
||||
return nil
|
||||
}
|
||||
if parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// IsSubagentSessionKey returns true if the session key represents a subagent.
|
||||
func IsSubagentSessionKey(sessionKey string) bool {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
|
||||
return true
|
||||
}
|
||||
parsed := ParseAgentSessionKey(raw)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
|
||||
}
|
||||
|
||||
func normalizeChannel(channel string) string {
|
||||
c := strings.TrimSpace(strings.ToLower(channel))
|
||||
if c == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
|
||||
candidates[scopedCandidate] = true
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildAgentMainSessionKey(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("sales")
|
||||
want := "agent:sales:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("Sales Bot")
|
||||
want := "agent:sales-bot:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopeMain,
|
||||
})
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("DMScopeMain = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
AccountID: "bot1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
|
||||
DMScope: DMScopePerAccountChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:bot1:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:telegram:group:chat456"
|
||||
if got != want {
|
||||
t.Errorf("GroupPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: nil,
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
// nil peer defaults to direct with empty ID, falls to main
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("NilPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:user123", "discord:john#1234"},
|
||||
}
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
IdentityLinks: links,
|
||||
})
|
||||
want := "agent:main:direct:john"
|
||||
if got != want {
|
||||
t.Errorf("IdentityLink = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Valid(t *testing.T) {
|
||||
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Invalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"foo:bar",
|
||||
"notprefix:sales:main",
|
||||
"agent::main",
|
||||
"agent:sales:",
|
||||
}
|
||||
for _, input := range tests {
|
||||
if got := ParseAgentSessionKey(input); got != nil {
|
||||
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubagentSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"subagent:task-1", true},
|
||||
{"agent:main:subagent:task-1", true},
|
||||
{"agent:main:main", false},
|
||||
{"agent:main:telegram:direct:user123", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := IsSubagentSessionKey(tt.input); got != tt.want {
|
||||
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -24,12 +23,6 @@ type AvailableSkill struct {
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
type BuiltinSkill struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func NewSkillInstaller(workspace string) *SkillInstaller {
|
||||
return &SkillInstaller{
|
||||
workspace: workspace,
|
||||
@@ -123,49 +116,3 @@ func (si *SkillInstaller) ListAvailableSkills(ctx context.Context) ([]AvailableS
|
||||
|
||||
return skills, nil
|
||||
}
|
||||
|
||||
func (si *SkillInstaller) ListBuiltinSkills() []BuiltinSkill {
|
||||
builtinSkillsDir := filepath.Join(filepath.Dir(si.workspace), "picoclaw", "skills")
|
||||
|
||||
entries, err := os.ReadDir(builtinSkillsDir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var skills []BuiltinSkill
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
_ = entry
|
||||
skillName := entry.Name()
|
||||
skillFile := filepath.Join(builtinSkillsDir, skillName, "SKILL.md")
|
||||
|
||||
data, err := os.ReadFile(skillFile)
|
||||
description := ""
|
||||
if err == nil {
|
||||
content := string(data)
|
||||
if idx := strings.Index(content, "\n"); idx > 0 {
|
||||
firstLine := content[:idx]
|
||||
if strings.Contains(firstLine, "description:") {
|
||||
descLine := strings.Index(content[idx:], "\n")
|
||||
if descLine > 0 {
|
||||
description = strings.TrimSpace(content[idx+descLine : idx+descLine])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// skill := BuiltinSkill{
|
||||
// Name: skillName,
|
||||
// Path: description,
|
||||
// Enabled: true,
|
||||
// }
|
||||
|
||||
status := "✓"
|
||||
fmt.Printf(" %s %s\n", status, entry.Name())
|
||||
if description != "" {
|
||||
fmt.Printf(" %s\n", description)
|
||||
}
|
||||
}
|
||||
}
|
||||
return skills
|
||||
}
|
||||
|
||||
+22
-5
@@ -9,6 +9,8 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
|
||||
@@ -251,6 +253,11 @@ func (sl *SkillsLoader) BuildSkillsSummary() string {
|
||||
func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
|
||||
content, err := os.ReadFile(skillPath)
|
||||
if err != nil {
|
||||
logger.WarnCF("skills", "Failed to read skill metadata",
|
||||
map[string]interface{}{
|
||||
"skill_path": skillPath,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -283,10 +290,15 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
|
||||
|
||||
// parseSimpleYAML parses simple key: value YAML format
|
||||
// Example: name: github\n description: "..."
|
||||
// Normalizes line endings to handle \n (Unix), \r\n (Windows), and \r (classic Mac)
|
||||
func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
// Normalize line endings: convert \r\n and \r to \n
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalized = strings.ReplaceAll(normalized, "\r", "\n")
|
||||
|
||||
for _, line := range strings.Split(normalized, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
@@ -306,9 +318,10 @@ func (sl *SkillsLoader) parseSimpleYAML(content string) map[string]string {
|
||||
}
|
||||
|
||||
func (sl *SkillsLoader) extractFrontmatter(content string) string {
|
||||
// (?s) enables DOTALL mode so . matches newlines
|
||||
// Match first ---, capture everything until next --- on its own line
|
||||
re := regexp.MustCompile(`(?s)^---\n(.*)\n---`)
|
||||
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
|
||||
// (?s) enables DOTALL so . matches newlines;
|
||||
// ^--- at start, then ... --- at start of line, honoring all three line ending types
|
||||
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---`)
|
||||
match := re.FindStringSubmatch(content)
|
||||
if len(match) > 1 {
|
||||
return match[1]
|
||||
@@ -317,7 +330,11 @@ func (sl *SkillsLoader) extractFrontmatter(content string) string {
|
||||
}
|
||||
|
||||
func (sl *SkillsLoader) stripFrontmatter(content string) string {
|
||||
re := regexp.MustCompile(`^---\n.*?\n---\n`)
|
||||
// Support \n (Unix), \r\n (Windows), and \r (classic Mac) line endings for frontmatter blocks
|
||||
// (?s) enables DOTALL so . matches newlines;
|
||||
// ^--- at start, then ... --- at start of line, honoring all three line ending types
|
||||
// Match zero or more trailing line endings after closing --- (handles both with and without blank lines)
|
||||
re := regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
|
||||
return re.ReplaceAllString(content, "")
|
||||
}
|
||||
|
||||
|
||||
@@ -75,3 +75,105 @@ func TestSkillsInfoValidate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFrontmatter(t *testing.T) {
|
||||
sl := &SkillsLoader{}
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedName string
|
||||
expectedDesc string
|
||||
lineEndingType string
|
||||
}{
|
||||
{
|
||||
name: "unix-line-endings",
|
||||
lineEndingType: "Unix (\\n)",
|
||||
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
|
||||
expectedName: "test-skill",
|
||||
expectedDesc: "A test skill",
|
||||
},
|
||||
{
|
||||
name: "windows-line-endings",
|
||||
lineEndingType: "Windows (\\r\\n)",
|
||||
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
|
||||
expectedName: "test-skill",
|
||||
expectedDesc: "A test skill",
|
||||
},
|
||||
{
|
||||
name: "classic-mac-line-endings",
|
||||
lineEndingType: "Classic Mac (\\r)",
|
||||
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
|
||||
expectedName: "test-skill",
|
||||
expectedDesc: "A test skill",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Extract frontmatter
|
||||
frontmatter := sl.extractFrontmatter(tc.content)
|
||||
assert.NotEmpty(t, frontmatter, "Frontmatter should be extracted for %s line endings", tc.lineEndingType)
|
||||
|
||||
// Parse YAML to get name and description (parseSimpleYAML now handles all line ending types)
|
||||
yamlMeta := sl.parseSimpleYAML(frontmatter)
|
||||
assert.Equal(t, tc.expectedName, yamlMeta["name"], "Name should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
|
||||
assert.Equal(t, tc.expectedDesc, yamlMeta["description"], "Description should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripFrontmatter(t *testing.T) {
|
||||
sl := &SkillsLoader{}
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedContent string
|
||||
lineEndingType string
|
||||
}{
|
||||
{
|
||||
name: "unix-line-endings",
|
||||
lineEndingType: "Unix (\\n)",
|
||||
content: "---\nname: test-skill\ndescription: A test skill\n---\n\n# Skill Content",
|
||||
expectedContent: "# Skill Content",
|
||||
},
|
||||
{
|
||||
name: "windows-line-endings",
|
||||
lineEndingType: "Windows (\\r\\n)",
|
||||
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n\r\n# Skill Content",
|
||||
expectedContent: "# Skill Content",
|
||||
},
|
||||
{
|
||||
name: "classic-mac-line-endings",
|
||||
lineEndingType: "Classic Mac (\\r)",
|
||||
content: "---\rname: test-skill\rdescription: A test skill\r---\r\r# Skill Content",
|
||||
expectedContent: "# Skill Content",
|
||||
},
|
||||
{
|
||||
name: "unix-line-endings-without-trailing-newline",
|
||||
lineEndingType: "Unix (\\n) without trailing newline",
|
||||
content: "---\nname: test-skill\ndescription: A test skill\n---\n# Skill Content",
|
||||
expectedContent: "# Skill Content",
|
||||
},
|
||||
{
|
||||
name: "windows-line-endings-without-trailing-newline",
|
||||
lineEndingType: "Windows (\\r\\n) without trailing newline",
|
||||
content: "---\r\nname: test-skill\r\ndescription: A test skill\r\n---\r\n# Skill Content",
|
||||
expectedContent: "# Skill Content",
|
||||
},
|
||||
{
|
||||
name: "no-frontmatter",
|
||||
lineEndingType: "No frontmatter",
|
||||
content: "# Skill Content\n\nSome content here.",
|
||||
expectedContent: "# Skill Content\n\nSome content here.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := sl.stripFrontmatter(tc.content)
|
||||
assert.Equal(t, tc.expectedContent, result, "Frontmatter should be stripped correctly for %s", tc.lineEndingType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+6
-2
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -28,12 +29,15 @@ type CronTool struct {
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
|
||||
execTool := NewExecToolWithConfig(workspace, restrict, config)
|
||||
execTool.SetTimeout(execTimeout)
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: NewExecTool(workspace, restrict),
|
||||
execTool: execTool,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+85
-10
@@ -11,6 +11,8 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type ExecTool struct {
|
||||
@@ -21,16 +23,82 @@ type ExecTool struct {
|
||||
restrictToWorkspace bool
|
||||
}
|
||||
|
||||
var defaultDenyPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
regexp.MustCompile(`\$\([^)]+\)`),
|
||||
regexp.MustCompile(`\$\{[^}]+\}`),
|
||||
regexp.MustCompile("`[^`]+`"),
|
||||
regexp.MustCompile(`\|\s*sh\b`),
|
||||
regexp.MustCompile(`\|\s*bash\b`),
|
||||
regexp.MustCompile(`;\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`&&\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`\|\|\s*rm\s+-[rf]`),
|
||||
regexp.MustCompile(`>\s*/dev/null\s*>&?\s*\d?`),
|
||||
regexp.MustCompile(`<<\s*EOF`),
|
||||
regexp.MustCompile(`\$\(\s*cat\s+`),
|
||||
regexp.MustCompile(`\$\(\s*curl\s+`),
|
||||
regexp.MustCompile(`\$\(\s*wget\s+`),
|
||||
regexp.MustCompile(`\$\(\s*which\s+`),
|
||||
regexp.MustCompile(`\bsudo\b`),
|
||||
regexp.MustCompile(`\bchmod\s+[0-7]{3,4}\b`),
|
||||
regexp.MustCompile(`\bchown\b`),
|
||||
regexp.MustCompile(`\bpkill\b`),
|
||||
regexp.MustCompile(`\bkillall\b`),
|
||||
regexp.MustCompile(`\bkill\s+-[9]\b`),
|
||||
regexp.MustCompile(`\bcurl\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bwget\b.*\|\s*(sh|bash)`),
|
||||
regexp.MustCompile(`\bnpm\s+install\s+-g\b`),
|
||||
regexp.MustCompile(`\bpip\s+install\s+--user\b`),
|
||||
regexp.MustCompile(`\bapt\s+(install|remove|purge)\b`),
|
||||
regexp.MustCompile(`\byum\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdnf\s+(install|remove)\b`),
|
||||
regexp.MustCompile(`\bdocker\s+run\b`),
|
||||
regexp.MustCompile(`\bdocker\s+exec\b`),
|
||||
regexp.MustCompile(`\bgit\s+push\b`),
|
||||
regexp.MustCompile(`\bgit\s+force\b`),
|
||||
regexp.MustCompile(`\bssh\b.*@`),
|
||||
regexp.MustCompile(`\beval\b`),
|
||||
regexp.MustCompile(`\bsource\s+.*\.sh\b`),
|
||||
}
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) *ExecTool {
|
||||
denyPatterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
|
||||
regexp.MustCompile(`\bdel\s+/[fq]\b`),
|
||||
regexp.MustCompile(`\brmdir\s+/s\b`),
|
||||
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
|
||||
regexp.MustCompile(`\bdd\s+if=`),
|
||||
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
|
||||
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
|
||||
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
|
||||
enableDenyPatterns := true
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
enableDenyPatterns = execConfig.EnableDenyPatterns
|
||||
if enableDenyPatterns {
|
||||
if len(execConfig.CustomDenyPatterns) > 0 {
|
||||
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
|
||||
for _, pattern := range execConfig.CustomDenyPatterns {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid custom deny pattern %q: %v\n", pattern, err)
|
||||
continue
|
||||
}
|
||||
denyPatterns = append(denyPatterns, re)
|
||||
}
|
||||
} else {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
} else {
|
||||
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
|
||||
fmt.Println("Warning: deny patterns are disabled. All commands will be allowed.")
|
||||
}
|
||||
} else {
|
||||
denyPatterns = append(denyPatterns, defaultDenyPatterns...)
|
||||
}
|
||||
|
||||
return &ExecTool{
|
||||
@@ -89,7 +157,14 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
// timeout == 0 means no timeout
|
||||
var cmdCtx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if t.timeout > 0 {
|
||||
cmdCtx, cancel = context.WithTimeout(ctx, t.timeout)
|
||||
} else {
|
||||
cmdCtx, cancel = context.WithCancel(ctx)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
+22
-5
@@ -6,10 +6,11 @@ import (
|
||||
)
|
||||
|
||||
type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
callback AsyncCallback // For async completion notification
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
allowlistCheck func(targetAgentID string) bool
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
@@ -45,6 +46,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
"agent_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Optional target agent ID to delegate the task to",
|
||||
},
|
||||
},
|
||||
"required": []string{"task"},
|
||||
}
|
||||
@@ -55,6 +60,10 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
|
||||
t.allowlistCheck = check
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
@@ -62,13 +71,21 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
agentID, _ := args["agent_id"].(string)
|
||||
|
||||
// Check allowlist if targeting a specific agent
|
||||
if agentID != "" && t.allowlistCheck != nil {
|
||||
if !t.allowlistCheck(agentID) {
|
||||
return ErrorResult(fmt.Sprintf("not allowed to spawn agent '%s'", agentID))
|
||||
}
|
||||
}
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
|
||||
result, err := t.manager.Spawn(ctx, task, label, agentID, t.originChannel, t.originChatID, t.callback)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ type SubagentTask struct {
|
||||
ID string
|
||||
Task string
|
||||
Label string
|
||||
AgentID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
Status string
|
||||
@@ -61,7 +62,7 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
|
||||
sm.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
@@ -72,6 +73,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
ID: taskID,
|
||||
Task: task,
|
||||
Label: label,
|
||||
AgentID: agentID,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
Status: "running",
|
||||
|
||||
+79
-4
@@ -176,6 +176,71 @@ func stripTags(content string) string {
|
||||
return re.ReplaceAllString(content, "")
|
||||
}
|
||||
|
||||
type PerplexitySearchProvider struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
|
||||
searchURL := "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": "sonar",
|
||||
"messages": []map[string]string{
|
||||
{"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
|
||||
{"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
|
||||
},
|
||||
"max_tokens": 1000,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", searchURL, strings.NewReader(string(payloadBytes)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("Perplexity API error: %s", string(body))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(searchResp.Choices) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Results for: %s (via Perplexity)\n%s", query, searchResp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
provider SearchProvider
|
||||
maxResults int
|
||||
@@ -187,14 +252,22 @@ type WebSearchToolOptions struct {
|
||||
BraveEnabled bool
|
||||
DuckDuckGoMaxResults int
|
||||
DuckDuckGoEnabled bool
|
||||
PerplexityAPIKey string
|
||||
PerplexityMaxResults int
|
||||
PerplexityEnabled bool
|
||||
}
|
||||
|
||||
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
|
||||
var provider SearchProvider
|
||||
maxResults := 5
|
||||
|
||||
// Priority: Brave > DuckDuckGo
|
||||
if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
// Priority: Perplexity > Brave > DuckDuckGo
|
||||
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
|
||||
provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
|
||||
if opts.PerplexityMaxResults > 0 {
|
||||
maxResults = opts.PerplexityMaxResults
|
||||
}
|
||||
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
|
||||
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
|
||||
if opts.BraveMaxResults > 0 {
|
||||
maxResults = opts.BraveMaxResults
|
||||
@@ -419,8 +492,10 @@ func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
result = strings.TrimSpace(result)
|
||||
|
||||
re = regexp.MustCompile(`\s+`)
|
||||
result = re.ReplaceAllLiteralString(result, " ")
|
||||
re = regexp.MustCompile(`[^\S\n]+`)
|
||||
result = re.ReplaceAllString(result, " ")
|
||||
re = regexp.MustCompile(`\n{3,}`)
|
||||
result = re.ReplaceAllString(result, "\n\n")
|
||||
|
||||
lines := strings.Split(result, "\n")
|
||||
var cleanLines []string
|
||||
|
||||
@@ -234,6 +234,80 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebFetchTool_extractText verifies text extraction preserves newlines
|
||||
func TestWebFetchTool_extractText(t *testing.T) {
|
||||
tool := &WebFetchTool{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantFunc func(t *testing.T, got string)
|
||||
}{
|
||||
{
|
||||
name: "preserves newlines between block elements",
|
||||
input: "<html><body><h1>Title</h1>\n<p>Paragraph 1</p>\n<p>Paragraph 2</p></body></html>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
lines := strings.Split(got, "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
|
||||
}
|
||||
if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
|
||||
t.Errorf("Missing expected text: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removes script and style tags",
|
||||
input: "<script>alert('x');</script><style>body{}</style><p>Keep this</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, "alert") || strings.Contains(got, "body{}") {
|
||||
t.Errorf("Expected script/style content removed, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Keep this") {
|
||||
t.Errorf("Expected 'Keep this' to remain, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "collapses excessive blank lines",
|
||||
input: "<p>A</p>\n\n\n\n\n<p>B</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, "\n\n\n") {
|
||||
t.Errorf("Expected excessive blank lines collapsed, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "collapses horizontal whitespace",
|
||||
input: "<p>hello world</p>",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if strings.Contains(got, " ") {
|
||||
t.Errorf("Expected spaces collapsed, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "hello world") {
|
||||
t.Errorf("Expected 'hello world', got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
wantFunc: func(t *testing.T, got string) {
|
||||
if got != "" {
|
||||
t.Errorf("Expected empty string, got: %q", got)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tool.extractText(tt.input)
|
||||
tt.wantFunc(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SplitMessage splits long messages into chunks, preserving code block integrity.
|
||||
// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks,
|
||||
// but may extend to maxLen when needed.
|
||||
// Call SplitMessage with the full text content and the maximum allowed length of a single message;
|
||||
// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks.
|
||||
func SplitMessage(content string, maxLen int) []string {
|
||||
var messages []string
|
||||
|
||||
// Dynamic buffer: 10% of maxLen, but at least 50 chars if possible
|
||||
codeBlockBuffer := maxLen / 10
|
||||
if codeBlockBuffer < 50 {
|
||||
codeBlockBuffer = 50
|
||||
}
|
||||
if codeBlockBuffer > maxLen/2 {
|
||||
codeBlockBuffer = maxLen / 2
|
||||
}
|
||||
|
||||
for len(content) > 0 {
|
||||
if len(content) <= maxLen {
|
||||
messages = append(messages, content)
|
||||
break
|
||||
}
|
||||
|
||||
// Effective split point: maxLen minus buffer, to leave room for code blocks
|
||||
effectiveLimit := maxLen - codeBlockBuffer
|
||||
if effectiveLimit < maxLen/2 {
|
||||
effectiveLimit = maxLen / 2
|
||||
}
|
||||
|
||||
// Find natural split point within the effective limit
|
||||
msgEnd := findLastNewline(content[:effectiveLimit], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:effectiveLimit], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = effectiveLimit
|
||||
}
|
||||
|
||||
// Check if this would end with an incomplete code block
|
||||
candidate := content[:msgEnd]
|
||||
unclosedIdx := findLastUnclosedCodeBlock(candidate)
|
||||
|
||||
if unclosedIdx >= 0 {
|
||||
// Message would end with incomplete code block
|
||||
// Try to extend up to maxLen to include the closing ```
|
||||
if len(content) > msgEnd {
|
||||
closingIdx := findNextClosingCodeBlock(content, msgEnd)
|
||||
if closingIdx > 0 && closingIdx <= maxLen {
|
||||
// Extend to include the closing ```
|
||||
msgEnd = closingIdx
|
||||
} else {
|
||||
// Code block is too long to fit in one chunk or missing closing fence.
|
||||
// Try to split inside by injecting closing and reopening fences.
|
||||
headerEnd := strings.Index(content[unclosedIdx:], "\n")
|
||||
if headerEnd == -1 {
|
||||
headerEnd = unclosedIdx + 3
|
||||
} else {
|
||||
headerEnd += unclosedIdx
|
||||
}
|
||||
header := strings.TrimSpace(content[unclosedIdx:headerEnd])
|
||||
|
||||
// If we have a reasonable amount of content after the header, split inside
|
||||
if msgEnd > headerEnd+20 {
|
||||
// Find a better split point closer to maxLen
|
||||
innerLimit := maxLen - 5 // Leave room for "\n```"
|
||||
betterEnd := findLastNewline(content[:innerLimit], 200)
|
||||
if betterEnd > headerEnd {
|
||||
msgEnd = betterEnd
|
||||
} else {
|
||||
msgEnd = innerLimit
|
||||
}
|
||||
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
|
||||
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, try to split before the code block starts
|
||||
newEnd := findLastNewline(content[:unclosedIdx], 200)
|
||||
if newEnd <= 0 {
|
||||
newEnd = findLastSpace(content[:unclosedIdx], 100)
|
||||
}
|
||||
if newEnd > 0 {
|
||||
msgEnd = newEnd
|
||||
} else {
|
||||
// If we can't split before, we MUST split inside (last resort)
|
||||
if unclosedIdx > 20 {
|
||||
msgEnd = unclosedIdx
|
||||
} else {
|
||||
msgEnd = maxLen - 5
|
||||
messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```")
|
||||
content = strings.TrimSpace(header + "\n" + content[msgEnd:])
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = effectiveLimit
|
||||
}
|
||||
|
||||
messages = append(messages, content[:msgEnd])
|
||||
content = strings.TrimSpace(content[msgEnd:])
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
|
||||
// Returns the position of the opening ``` or -1 if all code blocks are complete
|
||||
func findLastUnclosedCodeBlock(text string) int {
|
||||
inCodeBlock := false
|
||||
lastOpenIdx := -1
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
// Toggle code block state on each fence
|
||||
if !inCodeBlock {
|
||||
// Entering a code block: record this opening fence
|
||||
lastOpenIdx = i
|
||||
}
|
||||
inCodeBlock = !inCodeBlock
|
||||
i += 2
|
||||
}
|
||||
}
|
||||
|
||||
if inCodeBlock {
|
||||
return lastOpenIdx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findNextClosingCodeBlock finds the next closing ``` starting from a position
|
||||
// Returns the position after the closing ``` or -1 if not found
|
||||
func findNextClosingCodeBlock(text string, startIdx int) int {
|
||||
for i := startIdx; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
return i + 3
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastNewline finds the last newline character within the last N characters
|
||||
// Returns the position of the newline or -1 if not found
|
||||
func findLastNewline(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == '\n' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastSpace finds the last space character within the last N characters
|
||||
// Returns the position of the space or -1 if not found
|
||||
func findLastSpace(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == ' ' || s[i] == '\t' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitMessage(t *testing.T) {
|
||||
longText := strings.Repeat("a", 2500)
|
||||
longCode := "```go\n" + strings.Repeat("fmt.Println(\"hello\")\n", 100) + "```" // ~2100 chars
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
maxLen int
|
||||
expectChunks int // Check number of chunks
|
||||
checkContent func(t *testing.T, chunks []string) // Custom validation
|
||||
}{
|
||||
{
|
||||
name: "Empty message",
|
||||
content: "",
|
||||
maxLen: 2000,
|
||||
expectChunks: 0,
|
||||
},
|
||||
{
|
||||
name: "Short message fits in one chunk",
|
||||
content: "Hello world",
|
||||
maxLen: 2000,
|
||||
expectChunks: 1,
|
||||
},
|
||||
{
|
||||
name: "Simple split regular text",
|
||||
content: longText,
|
||||
maxLen: 2000,
|
||||
expectChunks: 2,
|
||||
checkContent: func(t *testing.T, chunks []string) {
|
||||
if len(chunks[0]) > 2000 {
|
||||
t.Errorf("Chunk 0 too large: %d", len(chunks[0]))
|
||||
}
|
||||
if len(chunks[0])+len(chunks[1]) != len(longText) {
|
||||
t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Split at newline",
|
||||
// 1750 chars then newline, then more chars.
|
||||
// Dynamic buffer: 2000 / 10 = 200.
|
||||
// Effective limit: 2000 - 200 = 1800.
|
||||
// Split should happen at newline because it's at 1750 (< 1800).
|
||||
// Total length must > 2000 to trigger split. 1750 + 1 + 300 = 2051.
|
||||
content: strings.Repeat("a", 1750) + "\n" + strings.Repeat("b", 300),
|
||||
maxLen: 2000,
|
||||
expectChunks: 2,
|
||||
checkContent: func(t *testing.T, chunks []string) {
|
||||
if len(chunks[0]) != 1750 {
|
||||
t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0]))
|
||||
}
|
||||
if chunks[1] != strings.Repeat("b", 300) {
|
||||
t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1]))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Long code block split",
|
||||
content: "Prefix\n" + longCode,
|
||||
maxLen: 2000,
|
||||
expectChunks: 2,
|
||||
checkContent: func(t *testing.T, chunks []string) {
|
||||
// Check that first chunk ends with closing fence
|
||||
if !strings.HasSuffix(chunks[0], "\n```") {
|
||||
t.Error("First chunk should end with injected closing fence")
|
||||
}
|
||||
// Check that second chunk starts with execution header
|
||||
if !strings.HasPrefix(chunks[1], "```go") {
|
||||
t.Error("Second chunk should start with injected code block header")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Preserve Unicode characters",
|
||||
content: strings.Repeat("\u4e16", 1000), // 3000 bytes
|
||||
maxLen: 2000,
|
||||
expectChunks: 2,
|
||||
checkContent: func(t *testing.T, chunks []string) {
|
||||
// Just verify we didn't panic and got valid strings.
|
||||
// Go strings are UTF-8, if we split mid-rune it would be bad,
|
||||
// but standard slicing might do that.
|
||||
// Let's assume standard behavior is acceptable or check if it produces invalid rune?
|
||||
if !strings.Contains(chunks[0], "\u4e16") {
|
||||
t.Error("Chunk should contain unicode characters")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := SplitMessage(tc.content, tc.maxLen)
|
||||
|
||||
if tc.expectChunks == 0 {
|
||||
if len(got) != 0 {
|
||||
t.Errorf("Expected 0 chunks, got %d", len(got))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(got) != tc.expectChunks {
|
||||
t.Errorf("Expected %d chunks, got %d", tc.expectChunks, len(got))
|
||||
// Log sizes for debugging
|
||||
for i, c := range got {
|
||||
t.Logf("Chunk %d length: %d", i, len(c))
|
||||
}
|
||||
return // Stop further checks if count assumes specific split
|
||||
}
|
||||
|
||||
if tc.checkContent != nil {
|
||||
tc.checkContent(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessage_CodeBlockIntegrity(t *testing.T) {
|
||||
// Focused test for the core requirement: splitting inside a code block preserves syntax highlighting
|
||||
|
||||
// 60 chars total approximately
|
||||
content := "```go\npackage main\n\nfunc main() {\n\tprintln(\"Hello\")\n}\n```"
|
||||
maxLen := 40
|
||||
|
||||
chunks := SplitMessage(content, maxLen)
|
||||
|
||||
if len(chunks) != 2 {
|
||||
t.Fatalf("Expected 2 chunks, got %d: %q", len(chunks), chunks)
|
||||
}
|
||||
|
||||
// First chunk must end with "\n```"
|
||||
if !strings.HasSuffix(chunks[0], "\n```") {
|
||||
t.Errorf("First chunk should end with closing fence. Got: %q", chunks[0])
|
||||
}
|
||||
|
||||
// Second chunk must start with the header "```go"
|
||||
if !strings.HasPrefix(chunks[1], "```go") {
|
||||
t.Errorf("Second chunk should start with code block header. Got: %q", chunks[1])
|
||||
}
|
||||
|
||||
// First chunk should contain meaningful content
|
||||
if len(chunks[0]) > 40 {
|
||||
t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0]))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user